diff --git a/src/lambo/segmenter/lambo.py b/src/lambo/segmenter/lambo.py
index 32d786c066c636ec1d9d38f0dbcb6c3d89e3e484..bd949b800838ea1b6fef07f2d4012a51f4fb83d6 100644
--- a/src/lambo/segmenter/lambo.py
+++ b/src/lambo/segmenter/lambo.py
@@ -30,6 +30,7 @@ class Lambo():
         :param provided_name: either a full model name (``LAMBO_no_pretraining-UD_Polish-PDB``), or language name (``Polish``) or ISO 639-1 code (``pl``)
         :return: LAMBO segmenter based on the expected model
         """
+        # TODO handle splitter here
         if '-' in provided_name:
             # It's s regular name
             model_name = provided_name
@@ -77,9 +78,13 @@ class Lambo():
         """
         model = torch.load(model_path / (model_name + '.pth'), map_location=torch.device('cpu'))
         dict = Lambo.read_dict(model_path / (model_name + '.dict'))
-        return cls(model, dict)
+        splitter = None
+        if (model_path / (model_name + '_subwords.pth')).exists():
+            from lambo.subwords.splitter import LamboSplitter
+            splitter = LamboSplitter.from_path(model_path, model_name)
+        return cls(model, dict, splitter)
 
-    def __init__(self, model, dict):
+    def __init__(self, model, dict, splitter=None):
         """
         Create a new LAMBO segmenter from a given model and dictionary.
 
@@ -88,6 +93,7 @@ class Lambo():
         """
         self.model = model
         self.dict = dict
+        self.splitter = splitter
 
     @staticmethod
     def read_dict(dict_path):
@@ -150,7 +156,7 @@ class Lambo():
         Perform the segmentation of the text. This involves:
 
         * splitting the document into turns using turn markers from ``turn_regexp.txt``
-        * splitting the turns into sentences and tokens according to the model's predictions
+        * splitting the turns into sentences and tokens according to the model's predictions (including splitting into subwords)
         * modifying the output to account for special tokens (emojis and pauses)
 
         :param text: input text
@@ -223,6 +229,12 @@ class Lambo():
                 if token_end:
                     # End of token
                     token = Token(turn_offset + token_begin, turn_offset + i + 1, text[token_begin:(i + 1)], mwtoken_end)
+                    if mwtoken_end and self.splitter is not None:
+                        subwords = self.splitter.split(token.text)
+                        if len(subwords) == 1:
+                            token.is_multi_word = False
+                        else:
+                            token.subwords = subwords
                     sentence.add_token(token)
                     token_begin = -1
                 if sentence_end:
diff --git a/src/lambo/utils/printer.py b/src/lambo/utils/printer.py
index bd731bae035c5833070440aca249c64501fb34f5..a55ecd980e3f95a7fe53fbbfcb3d75a3a746f856 100644
--- a/src/lambo/utils/printer.py
+++ b/src/lambo/utils/printer.py
@@ -20,7 +20,7 @@ def print_document_to_screen(document):
             formatted = ''
             for token in sentence.tokens:
                 if token.is_multi_word:
-                    formatted += '(' + token.text+ '=' + '-'.join(token.words) + ')'
+                    formatted += '(' + token.text+ '=' + '-'.join(token.subwords) + ')'
                 else:
                     formatted += '(' + token.text + ')'
             print('TOKENS: ' + formatted)
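
For reference, a minimal usage sketch of the behaviour this patch introduces. It is not part of the patch: the model name 'English' and the sample text are placeholders, and the loop simply mirrors the attribute names visible in the diff (is_multi_word, subwords) and the existing printer.

    # Illustrative only -- assumes a LAMBO model (e.g. 'English') is available locally.
    from lambo.segmenter.lambo import Lambo

    # from_path() now also loads '<model>_subwords.pth' as a LamboSplitter, if present
    lambo = Lambo.get('English')
    document = lambo.segment("An example text to segment.")

    for turn in document.turns:
        for sentence in turn.sentences:
            for token in sentence.tokens:
                if token.is_multi_word:
                    # multi-word tokens now carry the splitter's prediction in token.subwords
                    print(token.text + ' = ' + '-'.join(token.subwords))
                else:
                    print(token.text)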