From 88413d8f156e6ee961342ddfcf40b6dbed752927 Mon Sep 17 00:00:00 2001
From: piotrmp <piotr.m.przybyla@gmail.com>
Date: Thu, 24 Nov 2022 08:52:56 +0100
Subject: [PATCH] Added larger models and GPU support for training.

---
 src/lambo/examples/run_training.py            | 12 +++++++++---
 src/lambo/examples/run_training_pretrained.py | 17 +++++++++++------
 src/lambo/learning/train.py                   |  2 +-
 3 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/src/lambo/examples/run_training.py b/src/lambo/examples/run_training.py
index 4376970..7bac54d 100644
--- a/src/lambo/examples/run_training.py
+++ b/src/lambo/examples/run_training.py
@@ -1,14 +1,20 @@
 """
 Script for training LAMBO models using UD data
 """
+import sys
+
 import importlib_resources as resources
 from pathlib import Path
 
+import torch
 from lambo.learning.train import train_new_and_save
 
 if __name__=='__main__':
-    treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
-    outpath = Path.home() / 'PATH-TO/models/'
+    treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
+    outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/'
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    
     # Read available languages
     languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
     languages = [line.split(' ')[0] for line in languages_file_str.split('\n')]
@@ -19,4 +25,4 @@ if __name__=='__main__':
             continue
         print(str(i) + '/' + str(len(languages)) + '========== ' + language + ' ==========')
         inpath = treebanks / language
-        train_new_and_save('LAMBO-BILSTM', inpath, outpath)
+        train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)
diff --git a/src/lambo/examples/run_training_pretrained.py b/src/lambo/examples/run_training_pretrained.py
index f2dc8f2..33c3ea3 100644
--- a/src/lambo/examples/run_training_pretrained.py
+++ b/src/lambo/examples/run_training_pretrained.py
@@ -1,22 +1,27 @@
 """
 Script for training LAMBO models using UD data from pretrained
 """
-
+import sys
 from pathlib import Path
 
 import importlib_resources as resources
+import torch
 
 from lambo.learning.train import train_new_and_save, train_pretrained_and_save
 
 if __name__=='__main__':
-    treebanks = Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
-    outpath = Path.home() / 'PATH-TO/models/full/'
-    pretrained_path = Path.home() / 'PATH-TO/models/pretrained/'
+    treebanks = Path(sys.argv[1]) #Path.home() / 'PATH-TO/ud-treebanks-v2.9/'
+    outpath = Path(sys.argv[2]) #Path.home() / 'PATH-TO/models/full/'
+    pretrained_path = Path(sys.argv[3]) #Path.home() / 'PATH-TO/models/pretrained/'
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     
     languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
     lines = [line.strip() for line in languages_file_str.split('\n') if not line[0] == '#']
     
     for i, line in enumerate(lines):
+        if i % 5 != int(sys.argv[4]):
+            continue
         parts = line.split()
         model = parts[0]
         language = parts[1]
@@ -25,6 +30,6 @@ if __name__=='__main__':
         print(str(i) + '/' + str(len(lines)) + '========== ' + model + ' ==========')
         inpath = treebanks / model
         if language != '?':
-            train_pretrained_and_save(language, inpath, outpath, pretrained_path)
+            train_pretrained_and_save(language, inpath, outpath, pretrained_path, device)
         else:
-            train_new_and_save('LAMBO-BILSTM', inpath, outpath)
+            train_new_and_save('LAMBO-BILSTM', inpath, outpath, device)
diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py
index dfab088..203eafb 100644
--- a/src/lambo/learning/train.py
+++ b/src/lambo/learning/train.py
@@ -199,7 +199,7 @@ def tune(model, train_dataloader, test_dataloader, epochs, device='cpu'):
     optimizer = Adam(model.parameters(), lr=learning_rate)
     
     print("Training")
-    test_loop(test_dataloader, model)
+    test_loop(test_dataloader, model, device)
     for t in range(epochs):
         print(f"Epoch {t + 1}\n-------------------------------")
         train_loop(train_dataloader, model, optimizer, device)
-- 
GitLab