From 9fcf153282a63130bebd940724154a6c7bb2b088 Mon Sep 17 00:00:00 2001
From: piotrmp <piotr.m.przybyla@gmail.com>
Date: Tue, 13 Dec 2022 15:01:32 +0100
Subject: [PATCH] Subword splitting implementation.

---
 src/lambo/examples/run_training_splitting.py | 34 ++++++++++++++++++++
 1 file changed, 34 insertions(+)
 create mode 100644 src/lambo/examples/run_training_splitting.py

diff --git a/src/lambo/examples/run_training_splitting.py b/src/lambo/examples/run_training_splitting.py
new file mode 100644
index 0000000..1e60319
--- /dev/null
+++ b/src/lambo/examples/run_training_splitting.py
@@ -0,0 +1,34 @@
+"""
+Script for training LAMBO subword splitting models using UD data from pretrained
+"""
+import sys
+from pathlib import Path
+
+import importlib_resources as resources
+import torch
+
+from lambo.learning.train import train_new_and_save, train_pretrained_and_save
+from lambo.segmenter.lambo import Lambo
+from lambo.subwords.train import train_subwords_and_save
+
+if __name__=='__main__':
+    treebanks = Path.home() / 'data/lambo/ud-treebanks-v2.11/'
+    outpath = Path.home() / 'data/lambo/models/subword/'
+    segmenting_path = Path.home() / 'data/lambo/models/full211-s/'
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    
+    languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
+    lines = [line.strip() for line in languages_file_str.split('\n') if not line[0] == '#']
+    
+    for i, line in enumerate(lines):
+        #if len(sys.argv)>4 and i % 5 != int(sys.argv[4]):
+        #    continue
+        parts = line.split()
+        model = parts[0]
+        if (outpath / (model + '_subwords.pth')).exists():
+            continue
+        print(str(i) + '/' + str(len(lines)) + '========== ' + model + ' ==========')
+        inpath = treebanks / model
+        segmenter = Lambo.from_path(segmenting_path, model)
+        train_subwords_and_save('LAMBO-BILSTM', treebanks/ model, outpath, segmenter, epochs=20)
-- 
GitLab