diff --git a/src/lambo/examples/run_training.py b/src/lambo/examples/run_training.py index 437697088ee16c13d73084f16547208394f8f7c8..7bac54dee19906cbad72e25b79acb3b6b710edcb 100644 --- a/src/lambo/examples/run_training.py +++ b/src/lambo/examples/run_training.py @@ -1,14 +1,20 @@ """ Script for training LAMBO models using UD data """ +import sys + import importlib_resources as resources from pathlib import Path +import torch from lambo.learning.train import train_new_and_save if __name__=='__main__': - treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/' - outpath = Path.home() / 'PATH-TO/models/' + treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/' + outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/' + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # Read available languages languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict') languages = [line.split(' ')[0] for line in languages_file_str.split('\n')] @@ -19,4 +25,4 @@ if __name__=='__main__': continue print(str(i) + '/' + str(len(languages)) + '========== ' + language + ' ==========') inpath = treebanks / language - train_new_and_save('LAMBO-BILSTM', inpath, outpath) + train_new_and_save('LAMBO-BILSTM', inpath, outpath, device) diff --git a/src/lambo/examples/run_training_pretrained.py b/src/lambo/examples/run_training_pretrained.py index f2dc8f2d5723cf4eaafc971fdee1c83c5559ade6..33c3ea3078a965a7d03ac99b40e5e5c5df1a775c 100644 --- a/src/lambo/examples/run_training_pretrained.py +++ b/src/lambo/examples/run_training_pretrained.py @@ -1,22 +1,27 @@ """ Script for training LAMBO models using UD data from pretrained """ - +import sys from pathlib import Path import importlib_resources as resources +import torch from lambo.learning.train import train_new_and_save, train_pretrained_and_save if __name__=='__main__': - treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/' - outpath = Path.home() / 'PATH-TO/models/full/' - pretrained_path = Path.home() / 'PATH-TO/models/pretrained/' + treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/' + outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/full/' + pretrained_path = Path(sys.argv[3]) #Path.home() / 'PATH-TO/models/pretrained/' + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict') lines = [line.strip() for line in languages_file_str.split('\n') if not line[0] == '#'] for i, line in enumerate(lines): + if i % 5 != int(sys.argv[4]): + continue parts = line.split() model = parts[0] language = parts[1] @@ -25,6 +30,6 @@ if __name__=='__main__': print(str(i) + '/' + str(len(lines)) + '========== ' + model + ' ==========') inpath = treebanks / model if language != '?': - train_pretrained_and_save(language, inpath, outpath, pretrained_path) + train_pretrained_and_save(language, inpath, outpath, pretrained_path, device) else: - train_new_and_save('LAMBO-BILSTM', inpath, outpath) + train_new_and_save('LAMBO-BILSTM', inpath, outpath, device) diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py index dfab088181bf50077974cd13057d7dcd4db5c8fa..203eafb8dfe8b4bdf6c1cc868241133b6d977b30 100644 --- a/src/lambo/learning/train.py +++ b/src/lambo/learning/train.py @@ -199,7 +199,7 @@ def tune(model, train_dataloader, test_dataloader, epochs, device='cpu'): optimizer = Adam(model.parameters(), lr=learning_rate) print("Training") - test_loop(test_dataloader, model) + test_loop(test_dataloader, model, device) for t in range(epochs): print(f"Epoch {t + 1}\n-------------------------------") train_loop(train_dataloader, model, optimizer, device)