Commit 88413d8f authored by piotrmp

Added larger models and GPU support for training.

parent 7fffa595
1 merge request: !1 Migration to UD 2.11
"""
Script for training LAMBO models using UD data
"""
import sys
import importlib_resources as resources
from pathlib import Path
import torch
from lambo.learning.train import train_new_and_save
if __name__=='__main__':
-treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
-outpath = Path.home() / 'PATH-TO/models/'
+treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
+outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/'
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Read available languages
languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
languages = [line.split(' ')[0] for line in languages_file_str.split('\n')]
@@ -19,4 +25,4 @@ if __name__=='__main__':
continue
print(str(i) + '/' + str(len(languages)) + '========== ' + language + ' ==========')
inpath = treebanks / language
-train_new_and_save('LAMBO-BILSTM', inpath, outpath)
+train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)
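Note on the updated call: train_new_and_save now receives the selected torch.device as a fourth argument. A minimal sketch of invoking it directly for a single treebank, assuming the (model_name, inpath, outpath, device) signature shown above; the paths below are placeholders, not part of this commit:

from pathlib import Path
import torch
from lambo.learning.train import train_new_and_save

# Prefer the GPU when one is available, otherwise fall back to CPU
# (the same selection pattern the training script uses).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Placeholder paths: point these at a real UD treebank directory and an output directory.
inpath = Path('/data/ud-treebanks/UD_Polish-PDB')
outpath = Path('/data/models')

# Train a fresh LAMBO-BILSTM model for this treebank on the chosen device.
train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)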
"""
Script for training LAMBO models using UD data, starting from pretrained models
"""
import sys
from pathlib import Path
import importlib_resources as resources
import torch
from lambo.learning.train import train_new_and_save, train_pretrained_and_save
if __name__=='__main__':
-treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
-outpath = Path.home() / 'PATH-TO/models/full/'
-pretrained_path = Path.home() / 'PATH-TO/models/pretrained/'
+treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
+outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/full/'
+pretrained_path = Path(sys.argv[3]) #Path.home() / 'PATH-TO/models/pretrained/'
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
lines = [line.strip() for line in languages_file_str.split('\n') if not line[0] == '#']
for i, line in enumerate(lines):
if i % 5 != int(sys.argv[4]):
continue
parts = line.split()
model = parts[0]
language = parts[1]
@@ -25,6 +30,6 @@ if __name__=='__main__':
print(str(i) + '/' + str(len(lines)) + '========== ' + model + ' ==========')
inpath = treebanks / model
if language != '?':
-train_pretrained_and_save(language, inpath, outpath, pretrained_path)
+train_pretrained_and_save(language, inpath, outpath, pretrained_path, device)
else:
-train_new_and_save('LAMBO-BILSTM', inpath, outpath)
+train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)
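Side note on the i % 5 != int(sys.argv[4]) check above: it shards the language list, presumably so that five copies of the script, started with shard indices 0 through 4, each train a disjoint fifth of the models. A stand-alone sketch of the same idea, with a made-up entry list and the shard index taken from the first argument instead of sys.argv[4]:

import sys

SHARDS = 5
shard = int(sys.argv[1])  # 0..4; in the training script this comes from sys.argv[4]

entries = ['UD_Polish-PDB', 'UD_Czech-PDT', 'UD_German-GSD', 'UD_French-GSD', 'UD_Finnish-TDT']
for i, entry in enumerate(entries):
    if i % SHARDS != shard:
        continue  # another shard is responsible for this entry
    print('shard', shard, 'handles', entry)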
@@ -199,7 +199,7 @@ def tune(model, train_dataloader, test_dataloader, epochs, device='cpu'):
optimizer = Adam(model.parameters(), lr=learning_rate)
print("Training")
-test_loop(test_dataloader, model)
+test_loop(test_dataloader, model, device)
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train_loop(train_dataloader, model, optimizer, device)
......
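The loops called from tune() now receive the device explicitly as well. Their implementations are not part of this diff; for orientation, a device-aware evaluation loop generally follows the sketch below (the function name and batch structure are assumptions, not LAMBO's actual test_loop):

import torch

def test_loop_sketch(dataloader, model, device='cpu'):
    # Illustrative evaluation pass: move the model and every batch to the chosen device.
    model.to(device)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            correct += (pred.argmax(dim=-1) == y).sum().item()
            total += y.numel()
    print(f'Accuracy: {correct / max(total, 1):.3f}')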