Skip to content
Snippets Groups Projects
Commit 88413d8f authored by piotrmp's avatar piotrmp
Browse files

Added larger models and GPU support for training.

parent 7fffa595
No related branches found
No related tags found
1 merge request: !1 "Migration to UD 2.11"
""" """
Script for training LAMBO models using UD data Script for training LAMBO models using UD data
""" """
import sys
import importlib_resources as resources import importlib_resources as resources
from pathlib import Path from pathlib import Path
import torch
from lambo.learning.train import train_new_and_save from lambo.learning.train import train_new_and_save
if __name__=='__main__': if __name__=='__main__':
treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/' treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
outpath = Path.home() / 'PATH-TO/models/' outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read available languages # Read available languages
languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict') languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
languages = [line.split(' ')[0] for line in languages_file_str.split('\n')] languages = [line.split(' ')[0] for line in languages_file_str.split('\n')]
...@@ -19,4 +25,4 @@ if __name__=='__main__': ...@@ -19,4 +25,4 @@ if __name__=='__main__':
continue continue
print(str(i) + '/' + str(len(languages)) + '========== ' + language + ' ==========') print(str(i) + '/' + str(len(languages)) + '========== ' + language + ' ==========')
inpath = treebanks / language inpath = treebanks / language
train_new_and_save('LAMBO-BILSTM', inpath, outpath) train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)
""" """
Script for training LAMBO models using UD data from pretrained Script for training LAMBO models using UD data from pretrained
""" """
import sys
from pathlib import Path from pathlib import Path
import importlib_resources as resources import importlib_resources as resources
import torch
from lambo.learning.train import train_new_and_save, train_pretrained_and_save from lambo.learning.train import train_new_and_save, train_pretrained_and_save
if __name__=='__main__': if __name__=='__main__':
treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/' treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
outpath = Path.home() / 'PATH-TO/models/full/' outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/full/'
pretrained_path = Path.home() / 'PATH-TO/models/pretrained/' pretrained_path = Path(sys.argv[3]) #Path.home() / 'PATH-TO/models/pretrained/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict') languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
lines = [line.strip() for line in languages_file_str.split('\n') if not line[0] == '#'] lines = [line.strip() for line in languages_file_str.split('\n') if not line[0] == '#']
for i, line in enumerate(lines): for i, line in enumerate(lines):
if i % 5 != int(sys.argv[4]):
continue
parts = line.split() parts = line.split()
model = parts[0] model = parts[0]
language = parts[1] language = parts[1]
...@@ -25,6 +30,6 @@ if __name__=='__main__': ...@@ -25,6 +30,6 @@ if __name__=='__main__':
print(str(i) + '/' + str(len(lines)) + '========== ' + model + ' ==========') print(str(i) + '/' + str(len(lines)) + '========== ' + model + ' ==========')
inpath = treebanks / model inpath = treebanks / model
if language != '?': if language != '?':
train_pretrained_and_save(language, inpath, outpath, pretrained_path) train_pretrained_and_save(language, inpath, outpath, pretrained_path, device)
else: else:
train_new_and_save('LAMBO-BILSTM', inpath, outpath) train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)
...@@ -199,7 +199,7 @@ def tune(model, train_dataloader, test_dataloader, epochs, device='cpu'): ...@@ -199,7 +199,7 @@ def tune(model, train_dataloader, test_dataloader, epochs, device='cpu'):
optimizer = Adam(model.parameters(), lr=learning_rate) optimizer = Adam(model.parameters(), lr=learning_rate)
print("Training") print("Training")
test_loop(test_dataloader, model) test_loop(test_dataloader, model, device)
for t in range(epochs): for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------") print(f"Epoch {t + 1}\n-------------------------------")
train_loop(train_dataloader, model, optimizer, device) train_loop(train_dataloader, model, optimizer, device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment