Commit e31253c7 authored by piotrmp

Subword splitting integration in segmenter.

parent 9fcf1532
1 merge request: !2 Multiword generation
@@ -30,6 +30,7 @@ class Lambo():
         :param provided_name: either a full model name (``LAMBO_no_pretraining-UD_Polish-PDB``), or language name (``Polish``) or ISO 639-1 code (``pl``)
         :return: LAMBO segmenter based on the expected model
         """
+        # TODO handle splitter here
         if '-' in provided_name:
             # It's s regular name
             model_name = provided_name
@@ -77,9 +78,13 @@ class Lambo():
         """
         model = torch.load(model_path / (model_name + '.pth'), map_location=torch.device('cpu'))
         dict = Lambo.read_dict(model_path / (model_name + '.dict'))
-        return cls(model, dict)
+        splitter=None
+        if (model_path / (model_name + '_subwords.pth')).exists():
+            from lambo.subwords.splitter import LamboSplitter
+            splitter = LamboSplitter.from_path(model_path, model_name)
+        return cls(model, dict, splitter)

-    def __init__(self, model, dict):
+    def __init__(self, model, dict, splitter=None):
         """
         Create a new LAMBO segmenter from a given model and dictionary.

@@ -88,6 +93,7 @@ class Lambo():
         """
         self.model = model
         self.dict = dict
+        self.splitter = splitter

     @staticmethod
     def read_dict(dict_path):
@@ -150,7 +156,7 @@ class Lambo():
         Perform the segmentation of the text. This involves:

         * splitting the document into turns using turn markers from ``turn_regexp.txt``
-        * splitting the turns into sentences and tokens according to the model's predictions
+        * splitting the turns into sentences and tokens according to the model's predictions (including splitting into subwords)
         * modifying the output to account for special tokens (emojis and pauses)

         :param text: input text
@@ -223,6 +229,12 @@ class Lambo():
            if token_end:
                # End of token
                token = Token(turn_offset + token_begin, turn_offset + i + 1, text[token_begin:(i + 1)], mwtoken_end)
+                if mwtoken_end:
+                    subwords=self.splitter.split(token.text)
+                    if len(subwords)==1:
+                        token.is_multi_word=False
+                    else:
+                        token.subwords=subwords
                sentence.add_token(token)
                token_begin = -1
            if sentence_end:
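Taken together, these changes mean that a segmenter built by from_path (and hence by Lambo.get) will carry a LamboSplitter whenever a '<model_name>_subwords.pth' file sits next to the main model, and segment will then attach subwords to multi-word tokens. A minimal usage sketch, assuming the document/turn/sentence/token structure used elsewhere in LAMBO; the chosen language and the sample text are illustrative only:

    from lambo.segmenter.lambo import Lambo

    # Load the model for the given language; if the model directory also contains
    # '<model_name>_subwords.pth', a LamboSplitter is attached automatically.
    lambo = Lambo.get('Polish')

    document = lambo.segment('Zrobiłem to dla niego.')
    for turn in document.turns:
        for sentence in turn.sentences:
            for token in sentence.tokens:
                if token.is_multi_word:
                    # Multi-word token: the splitter's prediction is stored in token.subwords
                    print(token.text + ' -> ' + '-'.join(token.subwords))
                else:
                    print(token.text)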
@@ -20,7 +20,7 @@ def print_document_to_screen(document):
        formatted = ''
        for token in sentence.tokens:
            if token.is_multi_word:
-                formatted += '(' + token.text+ '=' + '-'.join(token.words) + ')'
+                formatted += '(' + token.text+ '=' + '-'.join(token.subwords) + ')'
            else:
                formatted += '(' + token.text + ')'
        print('TOKENS: ' + formatted)
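For reference, a self-contained sketch of what the renamed field changes in the printed output, using a stand-in Token object and hypothetical subword values (English clitic splitting is used purely as an illustration):

    class Token:
        # Minimal stand-in for LAMBO's Token, limited to the fields used below
        def __init__(self, text, is_multi_word=False, subwords=()):
            self.text = text
            self.is_multi_word = is_multi_word
            self.subwords = list(subwords)

    tokens = [Token('I'), Token("don't", is_multi_word=True, subwords=['do', "n't"]), Token('know')]

    formatted = ''
    for token in tokens:
        if token.is_multi_word:
            # Multi-word tokens now render their subwords (previously token.words)
            formatted += '(' + token.text + '=' + '-'.join(token.subwords) + ')'
        else:
            formatted += '(' + token.text + ')'

    print('TOKENS: ' + formatted)  # TOKENS: (I)(don't=do-n't)(know)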