diff --git a/src/lambo/examples/run_pretraining.py b/src/lambo/examples/run_pretraining.py
index a1a14c51989a7be0cbb15dce6420a4397c7b26fc..45531bcc32b48e4b5ce261834522b38271b9d878 100644
--- a/src/lambo/examples/run_pretraining.py
+++ b/src/lambo/examples/run_pretraining.py
@@ -6,6 +6,7 @@ import importlib_resources as resources
 from pathlib import Path
 
 import torch
+import sys
 
 from lambo.learning.dictionary import create_dictionary
 from lambo.learning.model_pretraining import LamboPretrainingNetwork
@@ -15,11 +16,13 @@ from lambo.learning.train import pretrain
 from lambo.utils.oscar import read_jsonl_to_documents, download_archive1_from_oscar
 
 if __name__=='__main__':
-    outpath = Path.home() / 'PATH-TO/models/pretrained/'
-    tmppath = Path.home() / 'PATH-TO/tmp/tmp.jsonl.gz'
-    # These need to be filled ine before running. OSCAR is avaialable on request.
-    OSCAR_LOGIN = ''
-    OSCAR_PASSWORD = ''
+    outpath = Path(sys.argv[1])  # previously Path.home() / 'PATH-TO/models/pretrained/'
+    tmppath = Path(sys.argv[2])  # previously Path.home() / 'PATH-TO/tmp/tmp.jsonl.gz'
+    # These are now read from the command line. OSCAR access is available on request.
+    OSCAR_LOGIN = sys.argv[3]
+    OSCAR_PASSWORD = sys.argv[4]
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     
     languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
     languages = [line.split(' ')[1] for line in languages_file_str.split('\n') if
@@ -50,7 +53,7 @@ if __name__=='__main__':
             Xchars, Xutfs, Xmasks, Yvecs = encode_pretraining([document_train], dict, CONTEXT_LEN)
             _, train_dataloader, test_dataloader = prepare_dataloaders_pretraining([document_train],
                                                                                    [document_test], CONTEXT_LEN, 32, dict)
-            pretrain(model, train_dataloader, test_dataloader, 1)
+            pretrain(model, train_dataloader, test_dataloader, 1, device)
         torch.save(model, outpath / ('oscar_' + language + '.pth'))
         with open(outpath / ('oscar_' + language + '.dict'), "w") as file1:
             file1.writelines([x + '\t' + str(dict[x]) + '\n' for x in dict])
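
With this change, run_pretraining.py takes its output directory, temporary file, and OSCAR credentials as positional command-line arguments instead of hard-coded paths, and it selects CUDA automatically when available. A minimal sketch of an invocation from the repository root follows; all paths and credentials below are placeholders, not values from the patch:

    import subprocess
    import sys

    # Hypothetical driver: the four positional arguments map to outpath,
    # tmppath, OSCAR_LOGIN and OSCAR_PASSWORD in run_pretraining.py.
    subprocess.run(
        [
            sys.executable, "src/lambo/examples/run_pretraining.py",
            "/data/models/pretrained/",  # outpath  (sys.argv[1])
            "/data/tmp/tmp.jsonl.gz",    # tmppath  (sys.argv[2])
            "my-oscar-login",            # OSCAR_LOGIN (sys.argv[3])
            "my-oscar-password",         # OSCAR_PASSWORD (sys.argv[4])
        ],
        check=True,
    )
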
diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py
index f33c88bfb1c39504963e00b97d96e5339e1fdc05..a2cfd2c77418a517a20635501ffe6906e027063d 100644
--- a/src/lambo/learning/train.py
+++ b/src/lambo/learning/train.py
@@ -9,17 +9,19 @@ from lambo.learning.preprocessing_dict import utf_category_dictionary, prepare_d
 from lambo.utils.ud_reader import read_treebank
 
 
-def train_loop(dataloader, model, optimizer):
+def train_loop(dataloader, model, optimizer, device='cpu'):
     """
     Training loop.
     
     :param dataloader: dataloader with training data
     :param model: model to be optimised
     :param optimizer: optimiser used
+    :param device: the device to use for computation
     :return: no value returned
     """
     size = len(dataloader.dataset)
     for batch, XY in enumerate(dataloader):
+        XY = [xy.to(device) for xy in XY]
         Xs = XY[:-1]
         Y = XY[-1]
         pred = model(*Xs)
@@ -32,12 +34,13 @@ def train_loop(dataloader, model, optimizer):
             print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
 
 
-def test_loop(dataloader, model):
+def test_loop(dataloader, model, device='cpu'):
     """
     Test loop.
     
     :param dataloader: dataloader with test data
     :param model: model to be tested
+    :param device: the device to use for computation
     :return: no value returned
     """
     num_batches = len(dataloader)
@@ -46,6 +49,7 @@ def test_loop(dataloader, model):
     size = [0, 0, 0, 0]
     with torch.no_grad():
         for XY in dataloader:
+            XY = [xy.to(device) for xy in XY]
             Xs = XY[:-1]
             Y = XY[-1]
             pred = model(*Xs)
@@ -63,12 +67,13 @@ def test_loop(dataloader, model):
         f"Test Error: \n Accuracy chars: {(100 * (correct[0] / size[0])):>5f}%, tokens: {(100 * (correct[1] / size[1])):>5f}%, mwtokens: {(100 * (correct[2] / size[2])):>5f}%, sentences: {(100 * (correct[3] / size[3])):>5f}%, Avg loss: {test_loss:>8f} \n")
 
 
-def test_loop_pretraining(dataloader, model):
+def test_loop_pretraining(dataloader, model, device='cpu'):
     """
     Test loop for pretraining.
 
     :param dataloader: dataloader with test data
     :param model: model to be tested
+    :param device: the device to use for computation
     :return: no value returned
     """
     num_batches = len(dataloader)
@@ -77,6 +82,7 @@ def test_loop_pretraining(dataloader, model):
     size = [0, 0]
     with torch.no_grad():
         for XY in dataloader:
+            XY = [xy.to(device) for xy in XY]
             Xs = XY[:-1]
             Y = XY[-1]
             pred = model(*Xs)
@@ -93,7 +99,7 @@ def test_loop_pretraining(dataloader, model):
         f"Test Error: \n Accuracy nontrivial: {(100 * (correct[0] / size[0])):>5f}%, trivial: {(100 * (correct[1] / size[1])):>5f}%, Avg loss: {test_loss:>8f} \n")
 
 
-def train_new_and_save(model_name, treebank_path, save_path, epochs=10):
+def train_new_and_save(model_name, treebank_path, save_path, epochs=10, device='cpu'):
     """
-    Train a new LAMBO model and save it in filesystem.
+    Train a new LAMBO model and save it in the filesystem.
     
@@ -101,6 +107,7 @@ def train_new_and_save(model_name, treebank_path, save_path, epochs=10):
     :param treebank_path: path to the treebank training data
     :param save_path: path to save the generated model
     :param epochs: number of epochs to run for (default: 10)
+    :param device: the device to use for computation
     :return: no value returned
     """
     if model_name not in ['LAMBO-BILSTM']:
@@ -120,7 +127,7 @@ def train_new_and_save(model_name, treebank_path, save_path, epochs=10):
                                                                            BATCH_SIZE)
     model = LamboNetwork(MAX_LEN, dict, len(utf_category_dictionary))
     
-    tune(model, train_dataloader, test_dataloader, epochs)
+    tune(model, train_dataloader, test_dataloader, epochs, device)
     
     print("Saving")
     torch.save(model, save_path / (treebank_path.name + '.pth'))
@@ -128,7 +135,7 @@ def train_new_and_save(model_name, treebank_path, save_path, epochs=10):
         file1.writelines([x + '\t' + str(dict[x]) + '\n' for x in dict])
 
 
-def train_pretrained_and_save(language, treebank_path, save_path, pretrained_path, epochs=10):
+def train_pretrained_and_save(language, treebank_path, save_path, pretrained_path, epochs=10, device='cpu'):
     """
-    Train a new LAMBO model, staring from pretrained, and save it in filesystem.
+    Train a new LAMBO model, starting from a pretrained one, and save it in the filesystem.
 
@@ -137,6 +144,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat
     :param save_path: path to save the generated model
     :param pretrained_path: path to the pretraining models
     :param epochs: number of epochs to run for (default: 10)
+    :param device: the device to use for computation
     :return: no value returned
     """
     print("Loading pretrained model")
@@ -166,7 +174,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat
                                                                            MAX_LEN,
                                                                            BATCH_SIZE, dict=dict)
     
-    tune(model, train_dataloader, test_dataloader, epochs)
+    tune(model, train_dataloader, test_dataloader, epochs, device)
     
     print("Saving")
     torch.save(model, save_path / (treebank_path.name + '.pth'))
@@ -174,7 +182,7 @@ def train_pretrained_and_save(language, treebank_path, save_path, pretrained_pat
         file1.writelines([x + '\t' + str(dict[x]) + '\n' for x in dict])
 
 
-def tune(model, train_dataloader, test_dataloader, epochs):
+def tune(model, train_dataloader, test_dataloader, epochs, device='cpu'):
     """
     Tune an existing LAMBO model with the provided data
     
@@ -182,9 +190,11 @@ def tune(model, train_dataloader, test_dataloader, epochs):
     :param train_dataloader: dataloader for training data
     :param test_dataloader: dataloader for test data
     :param epochs: number of epochs to run for
+    :param device: the device to use for computation
     :return: no value returned
     """
     print("Preparing training")
+    model.to(device)
     learning_rate = 1e-3
     optimizer = Adam(model.parameters(), lr=learning_rate)
     
@@ -192,11 +202,11 @@ def tune(model, train_dataloader, test_dataloader, epochs):
-    test_loop(test_dataloader, model)
+    test_loop(test_dataloader, model, device)
     for t in range(epochs):
         print(f"Epoch {t + 1}\n-------------------------------")
-        train_loop(train_dataloader, model, optimizer)
-        test_loop(test_dataloader, model)
+        train_loop(train_dataloader, model, optimizer, device)
+        test_loop(test_dataloader, model, device)
 
 
-def pretrain(model, train_dataloader, test_dataloader, epochs):
+def pretrain(model, train_dataloader, test_dataloader, epochs, device='cpu'):
     """
     Tune an existing LAMBO pretraining model with the provided data
     
@@ -204,9 +214,11 @@ def pretrain(model, train_dataloader, test_dataloader, epochs):
     :param train_dataloader: dataloader for training data
     :param test_dataloader: dataloader for test data
     :param epochs: number of epochs to run for
+    :param device: the device to use for computation
     :return: no value returned
     """
     print("Preparing pretraining")
+    model.to(device)
     learning_rate = 1e-3
     optimizer = Adam(model.parameters(), lr=learning_rate)
     
@@ -214,5 +226,5 @@ def pretrain(model, train_dataloader, test_dataloader, epochs):
-    test_loop_pretraining(test_dataloader, model)
+    test_loop_pretraining(test_dataloader, model, device)
     for t in range(epochs):
         print(f"Epoch {t + 1}\n-------------------------------")
-        train_loop(train_dataloader, model, optimizer)
-        test_loop_pretraining(test_dataloader, model)
+        train_loop(train_dataloader, model, optimizer, device)
+        test_loop_pretraining(test_dataloader, model, device)
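
After these changes, every training entry point in lambo.learning.train accepts an optional device argument (default 'cpu'): tune() and pretrain() move the model to it, and the train/test loops move each batch to it before the forward pass. A minimal sketch of GPU-aware training, under the assumption that a UD treebank directory and an output directory already exist (both paths below are placeholders):

    from pathlib import Path

    import torch

    from lambo.learning.train import train_new_and_save

    # Pick a GPU when one is available; otherwise stay on the CPU default.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Placeholder paths: point these at a real treebank and output directory.
    treebank_path = Path('/data/ud/UD_English-EWT')
    save_path = Path('/data/models')

    # 'LAMBO-BILSTM' is the only model name accepted by train_new_and_save;
    # the device is forwarded to tune(), which calls model.to(device) and
    # passes it on to train_loop() and test_loop().
    train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs=10, device=device)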