From 88413d8f156e6ee961342ddfcf40b6dbed752927 Mon Sep 17 00:00:00 2001
From: piotrmp <piotr.m.przybyla@gmail.com>
Date: Thu, 24 Nov 2022 08:52:56 +0100
Subject: [PATCH] Added larger models and GPU support for training.

---
 src/lambo/examples/run_training.py            | 12 +++++++++---
 src/lambo/examples/run_training_pretrained.py | 17 +++++++++++------
 src/lambo/learning/train.py                   |  2 +-
 3 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/src/lambo/examples/run_training.py b/src/lambo/examples/run_training.py
index 4376970..7bac54d 100644
--- a/src/lambo/examples/run_training.py
+++ b/src/lambo/examples/run_training.py
@@ -1,14 +1,20 @@
 """
 Script for training LAMBO models using UD data
 """
+import sys
+
 import importlib_resources as resources
 from pathlib import Path
+import torch
 
 from lambo.learning.train import train_new_and_save
 
 if __name__=='__main__':
-    treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
-    outpath = Path.home() / 'PATH-TO/models/'
+    treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
+    outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/'
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
     # Read available languages
     languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
     languages = [line.split(' ')[0] for line in languages_file_str.split('\n')]
@@ -19,4 +25,4 @@ if __name__=='__main__':
             continue
         print(str(i) + '/' + str(len(languages)) + '========== ' + language + ' ==========')
         inpath = treebanks / language
-        train_new_and_save('LAMBO-BILSTM', inpath, outpath)
+        train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)
diff --git a/src/lambo/examples/run_training_pretrained.py b/src/lambo/examples/run_training_pretrained.py
index f2dc8f2..33c3ea3 100644
--- a/src/lambo/examples/run_training_pretrained.py
+++ b/src/lambo/examples/run_training_pretrained.py
@@ -1,22 +1,27 @@
 """
 Script for training LAMBO models using UD data from pretrained
 """
-
+import sys
 from pathlib import Path
 
 import importlib_resources as resources
+import torch
 
 from lambo.learning.train import train_new_and_save, train_pretrained_and_save
 
 
 if __name__=='__main__':
-    treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
-    outpath = Path.home() / 'PATH-TO/models/full/'
-    pretrained_path = Path.home() / 'PATH-TO/models/pretrained/'
+    treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
+    outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/full/'
+    pretrained_path = Path(sys.argv[3]) #Path.home() / 'PATH-TO/models/pretrained/'
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
     languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
     lines = [line.strip() for line in languages_file_str.split('\n') if not line[0] == '#']
     for i, line in enumerate(lines):
+        if i % 5 != int(sys.argv[4]):
+            continue
         parts = line.split()
         model = parts[0]
         language = parts[1]
@@ -25,6 +30,6 @@ if __name__=='__main__':
         print(str(i) + '/' + str(len(lines)) + '========== ' + model + ' ==========')
         inpath = treebanks / model
         if language != '?':
-            train_pretrained_and_save(language, inpath, outpath, pretrained_path)
+            train_pretrained_and_save(language, inpath, outpath, pretrained_path, device)
         else:
-            train_new_and_save('LAMBO-BILSTM', inpath, outpath)
+            train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)
diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py
index dfab088..203eafb 100644
--- a/src/lambo/learning/train.py
+++ b/src/lambo/learning/train.py
@@ -199,7 +199,7 @@ def tune(model, train_dataloader, test_dataloader, epochs, device='cpu'):
     optimizer = Adam(model.parameters(), lr=learning_rate)
 
     print("Training")
-    test_loop(test_dataloader, model)
+    test_loop(test_dataloader, model, device)
     for t in range(epochs):
         print(f"Epoch {t + 1}\n-------------------------------")
         train_loop(train_dataloader, model, optimizer, device)
--
GitLab
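
For reference, a minimal usage sketch of the four-argument train_new_and_save call introduced above, assuming the signature shown in the diff; the treebank directory name ('UD_Polish-PDB') and the two command-line paths are illustrative placeholders:

    # Minimal sketch: train a single treebank with the patched call
    # train_new_and_save(model_name, inpath, outpath, device).
    # 'UD_Polish-PDB' and the two argv paths are placeholders, not repository defaults.
    import sys
    from pathlib import Path

    import torch

    from lambo.learning.train import train_new_and_save

    if __name__ == '__main__':
        treebanks = Path(sys.argv[1])  # root of the UD treebank collection
        outpath = Path(sys.argv[2])    # directory where the trained model is written
        # Same device selection as in the patched example scripts: prefer GPU when available.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        train_new_and_save('LAMBO-BILSTM', treebanks / 'UD_Polish-PDB', outpath, device)

The patched run_training_pretrained.py additionally expects a fourth positional argument in the range 0-4, used by the i % 5 filter above to select which fifth of the language list the given process handles.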