diff --git a/README.md b/README.md
index 615e1aaaec75d4618738c0d2a2f475e800809ac8..2673c09ca91174a3a2a3e5d1dec2b4d7703cc2cf 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@
 LAMBO (Layered Approach to Multi-level BOundary identification) is a segmentation tool that is able to divide text on several levels:
 1. Dividing the original text into *turns* according to the provided list of separators. Turns can correspond to seperate utterences in a dialogue, paragraphs in a continuous text, etc.
 2. Splitting each turn into *sentences*.
-3. Finding *tokens* in sentences. Most tokens correspond to words, but multi-word tokens are also detected.
-LAMBO also supports special tokens that should be kept separate regardless of context, such as emojis and pause markers.
+3. Finding *tokens* in sentences. Most tokens correspond to words. LAMBO also supports special tokens that should be kept separate regardless of context, such as emojis and pause markers.
+4. Splitting tokens that are detected to be *multi-word* into *sub-words* (for selected languages).
 
 LAMBO is a machine learning model, which means it was trained to recognise boundaries of tokens and sentences from real-world text. It is implemented as a [PyTorch](https://pytorch.org/) deep neural network, including embeddings and recurrent layers operating at the character level. At the same time, LAMBO contains rule-based elements to allow a user to easily adjust it to one's needs, e.g. by adding custom special tokens or turn division markers.
 
@@ -14,6 +14,12 @@ LAMBO currently includes models trained on 98 corpora in 53 languages. The full
 - simple LAMBO, trained on the UD corpus
 - pretrained LAMBO, same as above, but starting from weights pre-trained on unsupervised masked character prediction using multilingual corpora from [OSCAR](https://oscar-corpus.com/).
 
+For 49 of the corpora, a subword splitting model is available. Note that different types of multi-word tokens exist in different languages:
+- those that are a concatenation of their subwords, as in English: *don't* = *do* + *n't*
+- those that differ from their subwords, as in Spanish: *al* = *a* + *el*
+
+The availability and type of the subword splitting model depend on the training data (i.e. the UD treebank).
+
 ## Installation
 
 Installation of LAMBO is easy.
@@ -50,9 +56,13 @@ Alternatively, you can select a specific model by defining LAMBO variant (`LAMBO
 lambo = Lambo.get('LAMBO-UD_Polish-PDB')
 ```
 
+There are two optional arguments to the `get()` function:
+- You can opt out of using the subword splitter by providing `with_splitter=False`.
+- You can point to a specific PyTorch device by providing the `device` parameter, for example `device=torch.device('cuda')` to enable GPU acceleration.
+
 Once the model is ready, you can perform segmentation of a given text:
 ```
-text = "A simple sentence might be enough... But some of us just ❤️ emojis. They should be tokens even when (yy) containing many characters, such as 👍🏿."
+text = "Simple sentences can't be enough... Some of us just ❤️ emojis. They should be tokens even when (yy) containing many characters, such as 👍🏿."
 document = lambo.segment(text)
 ```
@@ -67,12 +77,26 @@ for turn in document.turns:
         formatted = ''
         for token in sentence.tokens:
             if token.is_multi_word:
-                formatted += '((' + token.text + '))'
+                formatted += '(' + token.text + '=' + '-'.join(token.subwords) + ')'
             else:
                 formatted += '(' + token.text + ')'
         print('TOKENS: ' + formatted)
 ```
-Chech out if the special tokens, i.e. emojis and pause (`(yy)`) were properly recognised.
+This should produce the following output:
+```
+======= TURN =======
+TEXT: Simple sentences can't be enough... Some of us just ❤️ emojis. They should be tokens even when (yy) ...
+======= SENTENCE =======
+TEXT: "Simple sentences can't be enough... "
+TOKENS: (Simple)(sentences)(can't=ca-n't)(be)(enough)(...)
+======= SENTENCE =======
+TEXT: "Some of us just ❤️ emojis. "
+TOKENS: (Some)(of)(us)(just)(❤️)(emojis)(.)
+======= SENTENCE =======
+TEXT: "They should be tokens even when (yy) containing many characters, such as 👍🏿."
+TOKENS: (They)(should)(be)(tokens)(even)(when)((yy))(containing)(many)(characters)(,)(such)(as)(👍🏿)(.)
+```
+Note how *can't* was split and how the special tokens, i.e. the emojis and the pause marker (`(yy)`), were properly recognised.
 
 ## Using LAMBO with COMBO
 
@@ -117,6 +141,7 @@ You don't have to rely on the models trained so far in COMBO. You can use the in
 - `run_training.py` -- train simple LAMBO models. This script was used with [UD treebanks](https://universaldependencies.org/#language-) to generate `LAMBO_no_pretraining` models.
 - `run_pretraining.py` -- pretrain unsupervised LAMBO models. This script was used with [OSCAR](https://oscar-corpus.com/).
 - `run_training_pretrained.py` -- train LAMBO models on UD training data, starting from pretrained models. This script was used to generate `LAMBO` models.
+- `run_training_splitting.py` -- train LAMBO subword splitting models on UD training data.
 - `run_tuning.py` -- tune existing LAMBO model to fit new data.
 - `run_evaluation.py` -- evaluate existing models using UD gold standard.
 
@@ -134,7 +159,7 @@ If you use LAMBO in your research, please cite it as software:
   author = {{Przyby{\l}a, Piotr}},
   title = {LAMBO: Layered Approach to Multi-level BOundary identification},
   url = {https://gitlab.clarin-pl.eu/syntactic-tools/lambo},
-  version = {1.0.0},
+  version = {2.0.0},
   year = {2022},
 }
 ```
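The README changes above can be exercised end to end with a few lines. The following is a rough sketch, assuming the `Polish` model can be downloaded; the sample sentence is arbitrary and the exact tokens and sub-words depend on the trained model.
```
import torch
from lambo.segmenter.lambo import Lambo

# Load a segmenter together with its subword splitter (the default);
# pass with_splitter=False to skip it, or device=torch.device('cuda') for GPU inference.
lambo = Lambo.get('Polish', with_splitter=True, device=torch.device('cpu'))

document = lambo.segment("Nie zrobiłbym tego bez pomocy.")
for turn in document.turns:
    for sentence in turn.sentences:
        for token in sentence.tokens:
            if token.is_multi_word:
                # Multi-word tokens now carry the sub-words predicted by the splitter
                print(token.text, '=', '+'.join(token.subwords))
            else:
                print(token.text)
```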
diff --git a/pyproject.toml b/pyproject.toml
index bec18e9855e36514194852d5b9c20b12855b0a8c..f51575c345724adad021d889b42f4e7a9cafeafc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "lambo"
-version = "1.0.0"
+version = "2.0.0"
 authors = [
   { name="Piotr Przybyła", email="piotr.przybyla@ipipan.waw.pl" },
 ]
diff --git a/src/lambo/data/token.py b/src/lambo/data/token.py
index d9304b5e275ead2b4d1653a54284f7e8751d8ec6..79d1ef281b9fa3bec285eb8fed170fc416251232 100644
--- a/src/lambo/data/token.py
+++ b/src/lambo/data/token.py
@@ -18,3 +18,12 @@
         self.end = end
         self.text = text
         self.is_multi_word = is_multi_word
+        self.subwords = []
+
+    def addSubword(self, subword):
+        """
+        Add a subword to the token.
+
+        :param subword: the text of the subword to add
+        """
+        self.subwords.append(subword)
diff --git a/src/lambo/examples/__init__.py b/src/lambo/examples/__init__.py
index 9d2ced9dbbfdaacf3098215b96a3a6fcda4530b3..c80e8a3c6c7da7f6c08ef36ce771709cc97ab774 100644
--- a/src/lambo/examples/__init__.py
+++ b/src/lambo/examples/__init__.py
@@ -6,6 +6,7 @@ This package contains examples of how to use LAMBO in different scenarios:
 * ``run_training.py`` -- train LAMBO models from UD training data
 * ``run_pretraining.py`` -- train pretraining LAMBO models from OSCAR corpora
 * ``run_training_pretrained.py`` -- train LAMBO models on UD training data, starting from pretrained models
+* ``run_training_splitting.py`` -- train LAMBO models for splitting multi-word tokens
 * ``run_tuning.py`` -- tune existing LAMBO model to fit new data
 * ``run_evaluation.py`` -- evaluate existing models using UD gold standard
 """
\ No newline at end of file
diff --git a/src/lambo/examples/run_training_splitting.py b/src/lambo/examples/run_training_splitting.py
new file mode 100644
index 0000000000000000000000000000000000000000..038c17d3477d736c9f6756d3f104454cd037f90c
--- /dev/null
+++ b/src/lambo/examples/run_training_splitting.py
@@ -0,0 +1,33 @@
+"""
+Script for training LAMBO subword splitting models on UD data, starting from trained segmentation models.
+"""
+import time, sys
+from pathlib import Path
+
+import importlib_resources as resources
+import torch
+
+from lambo.segmenter.lambo import Lambo
+from lambo.subwords.train import train_subwords_and_save
+
+if __name__ == '__main__':
+    treebanks = Path(sys.argv[1])  # Path.home() / 'PATH-TO/ud-treebanks-v2.11/'
+    outpath = Path(sys.argv[2])  # Path.home() / 'PATH-TO/models/full-subwords/'
+    segmenting_path = Path(sys.argv[3])  # Path.home() / 'PATH-TO/models/full/'
+
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    languages_file_str = resources.read_text('lambo.resources', 'languages.txt', encoding='utf-8', errors='strict')
+    lines = [line.strip() for line in languages_file_str.split('\n') if line.strip() and not line.startswith('#')]
+
+    start = time.time()
+    for i, line in enumerate(lines):
+        parts = line.split()
+        model = parts[0]
+        if (outpath / (model + '_subwords.pth')).exists():
+            continue
+        print(str(i) + '/' + str(len(lines)) + '========== ' + model + ' ==========')
+        inpath = treebanks / model
+        segmenter = Lambo.from_path(segmenting_path, model)
+        train_subwords_and_save('LAMBO-BILSTM', treebanks / model, outpath, segmenter, epochs=20, device=device)
+    print(str(time.time()-start)+' s.')
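For a single treebank, the new training entry point boils down to a couple of calls. Below is a condensed sketch of what `run_training_splitting.py` does; the directory layout and the treebank name are placeholders.
```
from pathlib import Path

import torch

from lambo.segmenter.lambo import Lambo
from lambo.subwords.train import train_subwords_and_save

# Placeholder paths -- point these at a UD treebank and at a directory with trained segmentation models.
treebank = Path('ud-treebanks-v2.11/UD_Polish-PDB')
segmenter_dir = Path('models/full')
out_dir = Path('models/full-subwords')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The subword splitter is trained on top of an existing segmentation model (its dictionary and embeddings).
segmenter = Lambo.from_path(segmenter_dir, 'UD_Polish-PDB')
train_subwords_and_save('LAMBO-BILSTM', treebank, out_dir, segmenter, epochs=20, device=device)
```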
diff --git a/src/lambo/examples/run_usage.py b/src/lambo/examples/run_usage.py
index 807d09a55b69a469d16bd213ba4385d168645795..905f032163d0a2d0a705d53937715693db1a6556 100644
--- a/src/lambo/examples/run_usage.py
+++ b/src/lambo/examples/run_usage.py
@@ -9,22 +9,10 @@
     lambo = Lambo.get('Polish')
 
     # Provide text, including pauses (``(yy)``), emojis and turn markers (``<turn>``).
-    text = "Ciemny i jasny (yy) pies biegają 🏴w płytkiej w🅾️dzie... obok 🏴kamienistej😂 plaży.\n\n 😆 To jest następne zdanie <turn>to następna tura."
+    text = "Ciemny i jasny (yy) pies biegają 🏴w płytkiej w🅾️dzie... obok 🏴kamienistej😂 plaży.\n\n 😆 To jest następne zdanie <turn>to byłaby następna tura."
 
     # Perform segmentation
     document = lambo.segment(text)
 
     # Display the results
-    for turn in document.turns:
-        print('======= TURN =======')
-        print('TEXT: ' + turn.text[:100] + '...')
-        for sentence in turn.sentences:
-            print('======= SENTENCE =======')
-            print('TEXT: "' + sentence.text + '"')
-            formatted = ''
-            for token in sentence.tokens:
-                if token.is_multi_word:
-                    formatted += '((' + token.text + '))'
-                else:
-                    formatted += '(' + token.text + ')'
-            print('TOKENS: ' + formatted)
+    document.print()
diff --git a/src/lambo/learning/model.py b/src/lambo/learning/model.py
index 9e69067e5d6f234ea9766b76488284c2faa411a6..de54b29a65ccf9a60cf66f9976c8e6b2a22c4e4e 100644
--- a/src/lambo/learning/model.py
+++ b/src/lambo/learning/model.py
@@ -25,7 +25,7 @@ class LamboNetwork(Module):
         :param max_len: maximum length of an input sequence,
         :param dict: character dictionary
         :param utf_categories_num: number of UTF categories
-        :param pretrained: either ``None`` (for new models) or an instance of ``LamboPretrainingModel`` (if using pretraining data)
+        :param pretrained: either ``None`` (for new models) or an instance of ``LamboPretrainingNetwork`` (if using pretraining data)
         """
         super(LamboNetwork, self).__init__()
         self.max_len = max_len
diff --git a/src/lambo/learning/train.py b/src/lambo/learning/train.py
index a0efdfcda1515a67c67d6e19b5922fd816694c43..acde36fb3acc9d5070c8bcfcdc0a38285f4a517c 100644
--- a/src/lambo/learning/train.py
+++ b/src/lambo/learning/train.py
@@ -6,10 +6,11 @@
 from torch.optim import Adam
 
 from lambo.learning.model import LamboNetwork
 from lambo.learning.preprocessing_dict import utf_category_dictionary, prepare_dataloaders_withdict
+from lambo.segmenter.lambo import Lambo
 from lambo.utils.ud_reader import read_treebank
 
 
-def train_loop(dataloader, model, optimizer, device='cpu'):
+def train_loop(dataloader, model, optimizer, device=torch.device('cpu')):
     """
     Training loop.
@@ -34,7 +35,7 @@
             print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
 
 
-def test_loop(dataloader, model, device='cpu'):
+def test_loop(dataloader, model, device=torch.device('cpu')):
     """
     Test loop.
 
@@ -54,20 +55,33 @@
             Y = XY[-1]
             pred = model(*Xs)
             test_loss += model.compute_loss(pred, Y, Xs).item()
-            for i in range(pred.shape[2]):
-                A = pred[:, :, i, :].argmax(2)
-                B = Y[:, :, i]
-                nontrivial = torch.nonzero(A + B, as_tuple=True)
+            if len(pred.shape) == 4:
+                # Predicting character types (segmentation)
+                for i in range(pred.shape[2]):
+                    A = pred[:, :, i, :].argmax(2)
+                    B = Y[:, :, i]
+                    nontrivial = torch.nonzero(A + B, as_tuple=True)
+                    equals = (A == B)[nontrivial].type(torch.float)
+                    correct[i] += equals.sum().item()
+                    size[i] += torch.numel(equals)
+            elif len(pred.shape) == 3:
+                # Predicting characters (subword prediction)
+                A = pred.argmax(2)
+                B = Y
+                nontrivial = torch.nonzero(Y, as_tuple=True)
                 equals = (A == B)[nontrivial].type(torch.float)
-                correct[i] += equals.sum().item()
-                size[i] += torch.numel(equals)
+                #equals = (A==B).type(torch.float)
+                correct[0] += equals.sum().item()
+                size[0] += torch.numel(equals)
+            pass
+
 
     test_loss /= num_batches
     size = [s if s > 0 else 1 for s in size]
     print(
         f"Test Error: \n Accuracy chars: {(100 * (correct[0] / size[0])):>5f}%, tokens: {(100 * (correct[1] / size[1])):>5f}%, mwtokens: {(100 * (correct[2] / size[2])):>5f}%, sentences: {(100 * (correct[3] / size[3])):>5f}%, Avg loss: {test_loss:>8f} \n")
 
 
-def test_loop_pretraining(dataloader, model, device='cpu'):
+def test_loop_pretraining(dataloader, model, device=torch.device('cpu')):
     """
     Test loop for pretraining.
@@ -99,7 +113,7 @@
         f"Test Error: \n Accuracy nontrivial: {(100 * (correct[0] / size[0])):>5f}%, trivial: {(100 * (correct[1] / size[1])):>5f}%, Avg loss: {test_loss:>8f} \n")
 
 
-def train_new_and_save(model_name, treebank_path, save_path, epochs=10, device='cpu'):
+def train_new_and_save(model_name, treebank_path, save_path, epochs=10, device=torch.device('cpu')):
     """
     Train a new LAMBO model and save it in filesystem.
 
@@ -135,7 +149,7 @@
         file1.writelines([x + '\t' + str(dict[x]) + '\n' for x in dict])
 
 
-def train_pretrained_and_save(language, treebank_path, save_path, pretrained_path, epochs=10, device='cpu'):
+def train_pretrained_and_save(language, treebank_path, save_path, pretrained_path, epochs=10, device=torch.device('cpu')):
     """
     Train a new LAMBO model, staring from pretrained, and save it in filesystem.
 
@@ -154,15 +168,7 @@
         print("Pretrained model not found, falling back to training from scratch.")
         return train_new_and_save('LAMBO-BILSTM', treebank_path, save_path, epochs, device)
     pretrained_model = torch.load(file_path, map_location=torch.device('cpu'))
-    dict = {}
-    for line in open(pretrained_path / (pretrained_name + '.dict')):
-        if line.strip() == '':
-            continue
-        parts = line.split('\t')
-        if len(parts) == 3 and parts[0] == '' and parts[1] == '':
-            # TAB character
-            parts = ['\t', parts[2]]
-        dict[parts[0]] = int(parts[1])
+    dict = Lambo.read_dict(pretrained_path / (pretrained_name + '.dict'))
 
     print("Reading data.")
     train_doc, dev_doc, test_doc = read_treebank(treebank_path, True)
@@ -186,7 +192,7 @@
         file1.writelines([x + '\t' + str(dict[x]) + '\n' for x in dict])
 
 
-def tune(model, train_dataloader, test_dataloader, epochs, device='cpu'):
+def tune(model, train_dataloader, test_dataloader, epochs, device=torch.device('cpu')):
     """
     Tune an existing LAMBO model with the provided data
 
@@ -210,7 +216,7 @@
         test_loop(test_dataloader, model, device)
 
 
-def pretrain(model, train_dataloader, test_dataloader, epochs, device='cpu'):
+def pretrain(model, train_dataloader, test_dataloader, epochs, device=torch.device('cpu')):
     """
     Tune an existing LAMBO pretraining model with the provided data
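The training helpers above now take a `torch.device` rather than a plain string. A brief sketch of passing a GPU device explicitly; the paths are placeholders.
```
import torch
from pathlib import Path

from lambo.learning.train import train_new_and_save

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 'LAMBO-BILSTM' is the model type used throughout the training scripts; paths are placeholders.
train_new_and_save('LAMBO-BILSTM',
                   Path('ud-treebanks-v2.11/UD_Polish-PDB'),
                   Path('models/out'),
                   epochs=10,
                   device=device)
```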
diff --git a/src/lambo/resources/languages.txt b/src/lambo/resources/languages.txt
index 9e213ff758500adbb9cc64f1850feed0dc739dc6..9fbf954031526e7508b3edf029f2eacaac64791e 100644
--- a/src/lambo/resources/languages.txt
+++ b/src/lambo/resources/languages.txt
@@ -3,7 +3,7 @@ UD_Afrikaans-AfriBooms af Afrikaans
 UD_Ancient_Greek-PROIEL ? Ancient_Greek *
 UD_Ancient_Greek-Perseus ? Ancient_Greek
 UD_Ancient_Hebrew-PTNK ? Ancient_Hebrew
-UD_Arabic-NYUAD ar Arabic
+#UD_Arabic-NYUAD ar Arabic
 UD_Arabic-PADT ar Arabic *
 UD_Armenian-ArmTDP hy Armenian *
 UD_Armenian-BSUT hy Armenian
diff --git a/src/lambo/segmenter/lambo.py b/src/lambo/segmenter/lambo.py
index a5e46d80fb95182f972d877d1ed4e903aacc79cd..e481fb973f0920b74508fd34d27014c8f55126af 100644
--- a/src/lambo/segmenter/lambo.py
+++ b/src/lambo/segmenter/lambo.py
@@ -23,11 +23,13 @@ class Lambo():
     """
 
     @classmethod
-    def get(cls, provided_name):
+    def get(cls, provided_name, with_splitter=True, device=torch.device('cpu')):
         """
         Obtain a LAMBO segmenter based on the name of the model.
 
         :param provided_name: either a full model name (``LAMBO_no_pretraining-UD_Polish-PDB``), or language name (``Polish``) or ISO 639-1 code (``pl``)
+        :param with_splitter: should a subword splitter be loaded as well
+        :param device: PyTorch device to use for inference
         :return: LAMBO segmenter based on the expected model
         """
         if '-' in provided_name:
@@ -36,10 +38,14 @@
         else:
             # It's an alias -- language name or code
             model_name = Lambo.getDefaultModel(provided_name)
-        dict_path, model_path = download_model(model_name)
+        dict_path, model_path, splitter_path = download_model(model_name)
         dict = Lambo.read_dict(dict_path)
         model = torch.load(model_path, map_location=torch.device('cpu'))
-        return cls(model, dict)
+        splitter = None
+        if with_splitter and splitter_path:
+            from lambo.subwords.splitter import LamboSplitter
+            splitter = LamboSplitter.from_path(splitter_path.parent, model_name)
+        return cls(model, dict, splitter, device)
 
     @staticmethod
     def getDefaultModel(provided_name):
@@ -67,27 +73,36 @@
         return model_name
 
     @classmethod
-    def from_path(cls, model_path, model_name):
+    def from_path(cls, model_path, model_name, with_splitter=True, device=torch.device('cpu')):
         """
         Obtain a LAMBO segmenter by reading a model from a given path.
 
         :param model_path: directory including the model files
         :param model_name: model name
+        :param device: PyTorch device to use for inference
         :return:
         """
         model = torch.load(model_path / (model_name + '.pth'), map_location=torch.device('cpu'))
         dict = Lambo.read_dict(model_path / (model_name + '.dict'))
-        return cls(model, dict)
+        splitter = None
+        if with_splitter and (model_path / (model_name + '_subwords.pth')).exists():
+            from lambo.subwords.splitter import LamboSplitter
+            splitter = LamboSplitter.from_path(model_path, model_name)
+        return cls(model, dict, splitter, device)
 
-    def __init__(self, model, dict):
+    def __init__(self, model, dict, splitter=None, device=torch.device('cpu')):
         """
         Create a new LAMBO segmenter from a given model and dictionary.
 
         :param model: prediction Pytorch model
         :param dict: dictionary
+        :param device: PyTorch device to use for inference
         """
         self.model = model
         self.dict = dict
+        self.splitter = splitter
+        self.device = device
+        self.model.to(self.device)
 
     @staticmethod
     def read_dict(dict_path):
@@ -98,14 +113,24 @@
         :return: character dictionary
         """
         dict = {}
-        for line in open(dict_path):
-            if line.strip() == '':
+        prevEmpty = False
+        # to properly handle non-standard newline characters
+        chunks = dict_path.read_bytes().decode('utf-8').split('\n')
+        for chunk in chunks:
+            if chunk == '':
+                prevEmpty = True
                 continue
-            parts = line.split('\t')
+            parts = chunk.split('\t')
             if len(parts) == 3 and parts[0] == '' and parts[1] == '':
                 # TAB character
                 parts = ['\t', parts[2]]
+            # to properly handle newline characters in the dictionary
+            if parts[0] == '' and prevEmpty:
+                parts[0] = '\n'
+            if parts[0] in dict:
+                print("WARNING: duplicated key in dictionary")
             dict[parts[0]] = int(parts[1])
+            prevEmpty = False
         return dict
 
     @staticmethod
@@ -142,7 +167,7 @@
         Perform the segmentation of the text.
         This involves:
         * splitting the document into turns using turn markers from ``turn_regexp.txt``
-        * splitting the turns into sentences and tokens according to the model's predictions
+        * splitting the turns into sentences and tokens according to the model's predictions (including splitting into subwords)
         * modifying the output to account for special tokens (emojis and pauses)
 
         :param text: input text
@@ -175,7 +200,9 @@
 
             # compute neural network output
             with torch.no_grad():
+                X = [x.to(self.device) for x in X]
                 Y = self.model(*X)
+                Y = Y.to('cpu')
 
             # perform postprocessing
             decisions = self.model.postprocessing(Y, text)
@@ -215,6 +242,14 @@
                 if token_end:
                     # End of token
                     token = Token(turn_offset + token_begin, turn_offset + i + 1, text[token_begin:(i + 1)], mwtoken_end)
+                    if mwtoken_end and self.splitter:
+                        # If the token looks multi-word and a splitter is available, use it
+                        subwords = self.splitter.split(token.text)
+                        if len(subwords) == 1:
+                            # If it was not split in the end, ignore
+                            token.is_multi_word = False
+                        else:
+                            token.subwords = subwords
                     sentence.add_token(token)
                     token_begin = -1
                 if sentence_end:
diff --git a/src/lambo/subwords/__init__.py b/src/lambo/subwords/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f3e72ae0349d0e1fac242354307126d90509fb
--- /dev/null
+++ b/src/lambo/subwords/__init__.py
@@ -0,0 +1,3 @@
+"""
+The package includes code for splitting multi-word tokens
+"""
\ No newline at end of file
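With the changes to `Lambo.from_path()` above, a locally stored model directory is enough to get segmentation with subword splitting, provided a `<model>_subwords.pth` file sits next to the `.pth` and `.dict` files. A small sketch; the directory and model name are placeholders.
```
import torch
from pathlib import Path

from lambo.segmenter.lambo import Lambo

lambo = Lambo.from_path(Path('models/full'), 'UD_Polish-PDB',
                        with_splitter=True, device=torch.device('cpu'))
document = lambo.segment("To byłaby następna tura.")
document.print()
```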
diff --git a/src/lambo/subwords/model.py b/src/lambo/subwords/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6ec06e30aa6312973c9d01ebd69b0a0ad8fa48b
--- /dev/null
+++ b/src/lambo/subwords/model.py
@@ -0,0 +1,86 @@
+import torch
+from torch.nn import Embedding, LSTM, Linear, LogSoftmax, NLLLoss, Module
+
+
+class LamboSubwordNetwork(Module):
+    """
+    LAMBO subword splitting neural network model. The network has five layers:
+
+    * embedding layer for characters, representing each as a 64-long vector,
+    * bidirectional LSTM layer, taking a character embedding as input and outputting a 2*64-long state vector,
+    * dense linear layer, converting LSTM state vectors to the 64-dimensional embedding space,
+    * inverted embedding layer, converting back to characters using the same matrix as the embedding layer,
+    * softmax layer, computing the probability of every character.
+    """
+
+    def __init__(self, max_len, dict, pretrained=None):
+        """
+        Create a LAMBO subword neural network.
+
+        :param max_len: maximum length of an input sequence (i.e. characters in a token),
+        :param dict: character dictionary
+        :param pretrained: either ``None`` (for new models) or an instance of ``LamboNetwork`` (if using pretraining data)
+        """
+        super(LamboSubwordNetwork, self).__init__()
+        self.max_len = max_len
+        self.dict = dict
+        if pretrained is not None:
+            # Copy the weights of the embedding of the pretraining model
+            self.embedding_layer = Embedding.from_pretrained(pretrained.embedding_layer.weight, freeze=False,
+                                                             padding_idx=None)
+        else:
+            self.embedding_layer = Embedding(len(dict), 64, dict['<PAD>'])
+        self.lstm_layer = LSTM(input_size=self.embedding_layer.embedding_dim, hidden_size=64, batch_first=True,
+                               bidirectional=True)
+        self.linear_layer = Linear(self.lstm_layer.hidden_size * 2, self.embedding_layer.embedding_dim)
+        self.softmax_layer = LogSoftmax(2)
+        self.loss_fn = NLLLoss()
+
+    def forward(self, x_char):
+        """
+        Computation of the network output (B = batch size, L = maximum sequence length, V = number of words in the dictionary)
+
+        :param x_char: a tensor of BxL character indices,
+        :return: a tensor of BxLxV class scores
+        """
+        embedded = self.embedding_layer(x_char)
+        hidden = self.lstm_layer(embedded)[0]
+        reduced = self.linear_layer(hidden)
+
+        # Computing inverted embedding as a cosine similarity score of the transformed representation and original embeddings
+        scores = self.inverted_embedding(reduced, self.embedding_layer)
+
+        probabilities = self.softmax_layer(scores)
+        return probabilities
+
+    @staticmethod
+    def inverted_embedding(input, embedding_layer):
+        """
+        Inverted embeddings matrix. Finds the best items (i.e. characters or words) in the dictionary of the
+        original embedding layer (B = batch size, L = maximum sequence length, E = embedding size, V = number of words
+        in the dictionary) for the input in the embedding space.
+
+        :param input: a tensor in the hidden space of shape BxLxE
+        :param embedding_layer: an embedding layer with VxE parameter matrix
+        :return: dot product similarity tensor of shape BxLxV
+        """
+        # Normalise both matrices
+        input_normalised = torch.nn.functional.normalize(input, dim=2)
+        weights_normalised = torch.nn.functional.normalize(embedding_layer.weight.data, dim=1)
+        # Dot product of normalised vectors equals cosine similarity
+        scores = torch.matmul(input_normalised, torch.transpose(weights_normalised, 0, 1))
+        return scores
+
+    def compute_loss(self, pred, true, Xs):
+        """
+        Compute cross-entropy loss.
+
+        :param pred: tensor with predicted character probabilities
+        :param true: tensor with true classes
+        :param Xs: (not used)
+        :return: loss value
+        """
+        pred = torch.reshape(pred, (-1, len(self.dict)))
+        true = torch.reshape(true, (-1,))
+        output = self.loss_fn(pred, true)
+        return output
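The most unusual part of `LamboSubwordNetwork` is its output layer: instead of a separate classifier, the projected hidden states are scored against the (normalised) character embedding matrix, i.e. by cosine similarity. A self-contained sketch of that operation with toy sizes:
```
import torch
from torch.nn import Embedding

# Toy setup: a 10-character dictionary, 64-dimensional embeddings,
# and a batch of 2 sequences of length 5 already projected into the embedding space.
embedding = Embedding(10, 64)
projected = torch.randn(2, 5, 64)

# Normalise both sides; a dot product then gives cosine similarity to every embedding vector.
projected_n = torch.nn.functional.normalize(projected, dim=2)
weights_n = torch.nn.functional.normalize(embedding.weight.data, dim=1)
scores = torch.matmul(projected_n, weights_n.transpose(0, 1))  # shape: 2 x 5 x 10
best_chars = scores.argmax(dim=2)  # index of the closest dictionary entry per position
```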
diff --git a/src/lambo/subwords/preprocessing.py b/src/lambo/subwords/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b880e08558241837bedd63fc8d1cfc4bff36474
--- /dev/null
+++ b/src/lambo/subwords/preprocessing.py
@@ -0,0 +1,103 @@
+"""
+Functions for preprocessing data for subword splitting
+"""
+import random
+
+import torch
+from torch.utils.data import TensorDataset, DataLoader
+
+
+def encode_test(text, dictionary, maximum_length):
+    """
+    Encode text for test purposes (no Y available).
+
+    :param text: string of text to be encoded
+    :param dictionary: the dictionary
+    :param maximum_length: maximum length of a token (the rest will be trimmed)
+    :return: a 1-tuple with a pytorch tensor of the given maximum length
+    """
+    Xchar = [dictionary[char] if char in dictionary else dictionary['<UNK>'] for char in text]
+    Xchar += [dictionary['<PAD>']] * (maximum_length - (len(Xchar) % maximum_length))
+    Xchar = Xchar[:maximum_length]
+    return torch.Tensor([Xchar]).to(torch.int64),
+
+
+def encode_subwords(documents, dictionary, maximum_length):
+    """
+    Encode subwords as neural network inputs and outputs
+
+    :param documents: list of documents
+    :param dictionary: character dictionary
+    :param maximum_length: maximum length of network input and output (the rest will be trimmed)
+    :return: a pair of network input/output tensors: character encodings, true categories (split words)
+    """
+    tokenCount = 0
+    multiwordCount = 0
+    # Count true multi-word tokens
+    for document in documents:
+        for turn in document.turns:
+            for sentence in turn.sentences:
+                for token in sentence.tokens:
+                    tokenCount += 1
+                    if token.is_multi_word:
+                        multiwordCount += 1
+    thrs = multiwordCount / tokenCount
+    Xchars = []
+    Ychars = []
+    random.seed(1)
+    for document in documents:
+        for turn in document.turns:
+            for sentence in turn.sentences:
+                for token in sentence.tokens:
+                    r = random.random()
+                    if token.is_multi_word or r < thrs:
+                        # Token is added to training if truly multi-word or randomly selected according to threshold
+                        Xchar = [dictionary[char] if char in dictionary else dictionary['<UNK>'] for char in token.text]
+                        Xchar += [dictionary['<PAD>']] * (maximum_length - (len(Xchar) % maximum_length))
+                        Xchar = Xchar[:maximum_length]
+                        subwords = token.subwords
+                        if len(subwords) == 0:
+                            subwords = [token.text]
+                        Ychar = []
+                        for subword in subwords:
+                            if len(Ychar) != 0:
+                                Ychar += [dictionary['<PAD>']]
+                            Ychar += [dictionary[char] if char in dictionary else dictionary['<UNK>'] for char in
+                                      subword]
+                        Ychar += [dictionary['<PAD>']] * (maximum_length - (len(Ychar) % maximum_length))
+                        Ychar = Ychar[:maximum_length]
+                        Xchars += [Xchar]
+                        Ychars += [Ychar]
+
+    return Xchars, Ychars
+
+
+def prepare_subwords_dataloaders(train_docs, test_docs, max_len, batch_size, dict):
+    """
+    Prepare Pytorch dataloaders for the documents.
+
+    :param train_docs: list of training documents
+    :param test_docs: list of test documents
+    :param max_len: maximum length of network input
+    :param batch_size: batch size
+    :param dict: character dictionary
+    :return: a pair with train dataloader and test dataloader (or ``None, None`` if there is not enough data)
+    """
+    train_X_char, train_Y = encode_subwords(train_docs, dict, max_len)
+    if len(train_X_char) < 256:
+        # Not enough data for training
+        return None, None
+    train_X_char = torch.Tensor(train_X_char).to(torch.int64)
+    train_Y = torch.Tensor(train_Y).to(torch.int64)
+    train_dataset = TensorDataset(train_X_char, train_Y)
+    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+
+    test_X_char, test_Y = encode_subwords(test_docs, dict, max_len)
+    if len(test_X_char) < 64:
+        # Not enough data for testing
+        return None, None
+    test_X_char = torch.Tensor(test_X_char).to(torch.int64)
+    test_Y = torch.Tensor(test_Y).to(torch.int64)
+    test_dataset = TensorDataset(test_X_char, test_Y)
+    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+    return train_dataloader, test_dataloader
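The target sequences built by `encode_subwords()` are simply the characters of the consecutive sub-words with a single `<PAD>` index between them, padded out to the fixed length. A hand-rolled illustration for the Spanish *al* = *a* + *el* case from the README, using a toy dictionary rather than a real model's `.dict` file:
```
dictionary = {'<PAD>': 0, '<UNK>': 1, 'a': 2, 'l': 3, 'e': 4}
max_len = 8

subwords = ['a', 'el']
Ychar = []
for subword in subwords:
    if Ychar:
        Ychar.append(dictionary['<PAD>'])  # separator between consecutive sub-words
    Ychar += [dictionary.get(ch, dictionary['<UNK>']) for ch in subword]
Ychar += [dictionary['<PAD>']] * (max_len - len(Ychar))  # pad to the fixed length

print(Ychar)  # [2, 0, 4, 3, 0, 0, 0, 0]
```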
diff --git a/src/lambo/subwords/splitter.py b/src/lambo/subwords/splitter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b06a66e18811d54fd077c0a61b7974525db8222c
--- /dev/null
+++ b/src/lambo/subwords/splitter.py
@@ -0,0 +1,65 @@
+import torch
+
+from lambo.segmenter.lambo import Lambo
+from lambo.subwords.preprocessing import encode_test
+
+
+class LamboSplitter():
+    """
+    Class for splitting tokens into sub-words (wrapper for the neural network)
+    """
+
+    @classmethod
+    def from_path(cls, model_path, model_name):
+        """
+        Obtain a LAMBO subword splitter by reading a model from a given path.
+
+        :param model_path: directory including the model files
+        :param model_name: model name
+        :return: instance of LamboSplitter
+        """
+        model = torch.load(model_path / (model_name + '_subwords.pth'), map_location=torch.device('cpu'))
+        dict = Lambo.read_dict(model_path / (model_name + '.dict'))
+        return cls(model, dict)
+
+    def __init__(self, model, dict):
+        """
+        Create a new LAMBO subword splitter from a given model and dictionary.
+
+        :param model: prediction Pytorch model
+        :param dict: dictionary
+        """
+        self.model = model
+        self.dict = dict
+        self.inv_dict = {dict[key]: key for key in dict}
+
+    def split(self, token_text):
+        """
+        Split a given token text
+
+        :param token_text: string with a token to split
+        :return: list of subwords
+        """
+        # Too long for the maximum length
+        if len(token_text) >= self.model.max_len:
+            return [token_text]
+        Xs = encode_test(token_text, self.dict, self.model.max_len)
+        with torch.no_grad():
+            Y = self.model(*Xs)
+        codes = Y.argmax(2).numpy()[0]
+        decisions = [self.inv_dict[code] for code in codes]
+        # Recover the subwords from the network output
+        result = ['']
+        for char in decisions:
+            if len(char) == 1:
+                result[-1] += char
+            elif char == '<PAD>':
+                if result[-1] == '':
+                    break
+                result.append('')
+            else:
+                return [token_text]
+        result = [subword for subword in result if subword != '']
+        if len(result) == 0:
+            return [token_text]
+        return result
diff --git a/src/lambo/subwords/train.py b/src/lambo/subwords/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..da6c492fcadefc9f9a12c44bfb44c41ea802b54c
--- /dev/null
+++ b/src/lambo/subwords/train.py
@@ -0,0 +1,45 @@
+import torch
+
+from lambo.learning.train import tune
+from lambo.subwords.model import LamboSubwordNetwork
+from lambo.subwords.preprocessing import prepare_subwords_dataloaders
+from lambo.utils.ud_reader import read_treebank
+
+
+def train_subwords_and_save(model_name, treebank_path, save_path, lambo_segmenter, epochs=20, device=torch.device('cpu')):
+    """
+    Train a new LAMBO subwords model and save it in filesystem.
+
+    :param model_name: type of model trained, currently only ``LAMBO-BILSTM`` is supported
+    :param treebank_path: path to the treebank training data
+    :param save_path: path to save the generated model
+    :param lambo_segmenter: LAMBO segmenter to base on
+    :param epochs: number of epochs to run for (default: 20)
+    :param device: the device to use for computation
+    :return: no value returned
+    """
+    if model_name not in ['LAMBO-BILSTM']:
+        print(" Unrecognised model name: " + model_name)
+        return
+
+    print("Reading data.")
+    train_doc, dev_doc, test_doc = read_treebank(treebank_path, True)
+
+    print("Preparing data")
+    BATCH_SIZE = 32
+    print("Initiating the model.")
+
+    MAX_LEN = 32
+    train_dataloader, test_dataloader = prepare_subwords_dataloaders([train_doc, dev_doc], [test_doc],
+                                                                     MAX_LEN,
+                                                                     BATCH_SIZE, lambo_segmenter.dict)
+    if train_dataloader is None:
+        print("Not enough data to train, moving on.")
+        return
+
+    model = LamboSubwordNetwork(MAX_LEN, lambo_segmenter.dict, lambo_segmenter.model)
+
+    tune(model, train_dataloader, test_dataloader, epochs, device)
+
+    print("Saving")
+    torch.save(model, save_path / (treebank_path.name + '_subwords.pth'))
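A trained splitter can also be used on its own, outside the segmenter. A short sketch, assuming a directory that contains both `<model>_subwords.pth` and the matching `<model>.dict`; the directory and model name are placeholders and the actual split depends on the trained weights.
```
from pathlib import Path

from lambo.subwords.splitter import LamboSplitter

splitter = LamboSplitter.from_path(Path('models/full-subwords'), 'UD_Spanish-AnCora')

print(splitter.split('del'))      # e.g. ['de', 'el'] if the model decides to split
print(splitter.split('x' * 100))  # longer than max_len, so returned unchanged in a one-element list
```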
diff --git a/src/lambo/utils/download.py b/src/lambo/utils/download.py
index 0aef6dfb767ec19e3711e32cfcab7031a604a626..a54e10e2a519aa3af34f0361a0a79052dcf1509a 100644
--- a/src/lambo/utils/download.py
+++ b/src/lambo/utils/download.py
@@ -22,7 +22,11 @@
 TYPE_TO_PATH = {
 # The adress of the remote repository
 _URL = "http://home.ipipan.waw.pl/p.przybyla/lambo/{type}/{treebank}.{extension}"
 
+# The address of the remote repository for optional variants
+_URL_VAR = "http://home.ipipan.waw.pl/p.przybyla/lambo/{type}/{treebank}_{variant}.{extension}"
+
 _HOME_DIR = os.getenv("HOME", os.curdir)
+
 # Models are stored in ~/.lambo/
 _CACHE_DIR = os.getenv("LAMBO_DIR", os.path.join(_HOME_DIR, ".lambo"))
@@ -39,6 +43,7 @@ def download_model(model_name, force=False):
     type = model_name.split("-", 1)[0]
     treebank = model_name.split("-", 1)[1]
     locations = []
+    # First download pass for the main model
    for extension in ['dict', 'pth']:
         url = _URL.format(type=TYPE_TO_PATH[type], treebank=treebank, extension=extension)
         local_filename = model_name + '.' + extension
@@ -60,8 +65,36 @@
                                 f.write(chunk)
                                 pbar.update(len(chunk))
         except exceptions.RetryError:
-            raise ConnectionError(f"Couldn't find or download model {model_name}.tar.gz. "
+            raise ConnectionError(f"Couldn't find or download model. "
                                   "Check if model name is correct or try again later!")
 
+    # Second download pass for the variants
+    for variant in ['subwords']:
+        extension = 'pth'
+        url = _URL_VAR.format(type=TYPE_TO_PATH[type], treebank=treebank, variant=variant, extension=extension)
+        local_filename = model_name + '_' + variant + '.' + extension
+        location = os.path.join(_CACHE_DIR, local_filename)
+        if os.path.exists(location) and not force:
+            logger.debug("Using cached data.")
+            locations.append(Path(location))
+            continue
+        chunk_size = 1024
+        logger.info(url)
+        try:
+            with _requests_retry_session(retries=2).get(url, stream=True) as r:
+                pbar = tqdm.tqdm(unit="B", total=int(r.headers.get("content-length")),
+                                 unit_divisor=chunk_size, unit_scale=True)
+                with open(location, "wb") as f:
+                    with pbar:
+                        for chunk in r.iter_content(chunk_size):
+                            if chunk:
+                                f.write(chunk)
+                                pbar.update(len(chunk))
+        except exceptions.RetryError:
+            # This is normal if the splitter was not trained
+            print("Couldn't find or download model variant (" + variant + ") -- might be unavailable.")
+            locations.append(None)
+            continue
+        locations.append(Path(location))
     return locations
 
diff --git a/src/lambo/utils/generate_languages_txt.py b/src/lambo/utils/generate_languages_txt.py
index d3982e4ca1af327c65c264a5138f19ca8a403840..77c3dc5c082f97201fd6396db426c78be08e1cb8 100644
--- a/src/lambo/utils/generate_languages_txt.py
+++ b/src/lambo/utils/generate_languages_txt.py
@@ -1,3 +1,9 @@
+"""
+Rough procedure to generate languages.txt from a UD folder. Includes all languages that have test and dev portions.
+Uses the previous version of the file to translate language names to ISO codes. Selects the largest corpus as
+preferred for the language. May require manual adjustment to exclude abnormal treebanks (e.g. UD_Arabic-NYUAD)
+or to add missing ISO codes for new languages.
+"""
 from pathlib import Path
 
 old_languages_txt = ''
+""" from pathlib import Path old_languages_txt = '' diff --git a/src/lambo/utils/printer.py b/src/lambo/utils/printer.py index 25ccf393e927380f694a37c73179bfeaa6cbf5a7..08001e896bffef8bacbf52ad2bceabeae4f180f7 100644 --- a/src/lambo/utils/printer.py +++ b/src/lambo/utils/printer.py @@ -20,7 +20,7 @@ def print_document_to_screen(document): formatted = '' for token in sentence.tokens: if token.is_multi_word: - formatted += '((' + token.text + '))' + formatted += '(' + token.text+ '=' + '-'.join(token.subwords) + ')' else: formatted += '(' + token.text + ')' print('TOKENS: ' + formatted) @@ -48,13 +48,13 @@ def print_document_to_conll(document, path): token_text = token_text_with_whitespace_for_conllu(token, document, turn, sentence).strip() if token_text == '': continue - if token.is_multi_word: + if token.is_multi_word and len(token.subwords) > 1: file1.write(str(token_id)) - file1.write('-' + str(token_id + 1)) + file1.write('-' + str(token_id + len(token.subwords) - 1)) file1.write('\t' + token_text + '\t_\t_\t_\t_\t_\t_\t_\t_\n') - token_id += 2 - file1.write(str(token_id - 2) + '\t_\t_\t_\t_\t_\t' + str(token_id - 3) + '\t_\t_\t_\n') - file1.write(str(token_id - 1) + '\t_\t_\t_\t_\t_\t' + str(token_id - 2) + '\t_\t_\t_\n') + for word in token.subwords: + file1.write(str(token_id) + '\t' + word + '\t_\t_\t_\t_\t' + str(token_id - 1) + '\t_\t_\t_\n') + token_id += 1 else: file1.write(str(token_id)) file1.write('\t' + token_text + '\t_\t_\t_\t_\t' + str(token_id - 1) + '\t_\t_\t_\n') diff --git a/src/lambo/utils/ud_reader.py b/src/lambo/utils/ud_reader.py index 1fecc8ed116f44b5b17aa5cb907f5af032d68d00..99fdbada4f20302500590a8d19389900696d6bba 100644 --- a/src/lambo/utils/ud_reader.py +++ b/src/lambo/utils/ud_reader.py @@ -125,9 +125,10 @@ def read_document(file_path, random_separators): turn_text = "" sentence = Sentence() sentence_text = "" - banned_range = [0, 0] + word_range = [0, 0] current_offset = 0 separator = '' + lastToken = None for line in file_path.read_text().split('\n'): if line.startswith('#'): # Comment, ignore @@ -149,7 +150,7 @@ def read_document(file_path, random_separators): turn.add_sentence(sentence) sentence = Sentence() sentence_text = "" - banned_range = [0, 0] + word_range = [0, 0] else: parts = line.split('\t') is_copy = any(x.startswith('CopyOf=') for x in parts[-1].split('|')) or ('.' in parts[0]) @@ -159,9 +160,9 @@ def read_document(file_path, random_separators): form = parts[1] space_after_no = ('SpaceAfter=No' in parts[-1].split('|')) if len(numbers) == 1: - if banned_range[0] <= numbers[0] <= banned_range[1]: + if word_range[0] <= numbers[0] <= word_range[1]: # Individual word within multi-word token - pass + lastToken.addSubword(form) else: # Individual word not covered token = Token(current_offset, current_offset + len(form), form, False) @@ -186,7 +187,8 @@ def read_document(file_path, random_separators): sentence_text += separator current_offset += len(separator) sentence.add_token(token) - banned_range = numbers + word_range = numbers + lastToken = token turn.set_text(turn_text) document = Document() document.set_text(turn_text)