Skip to content
Snippets Groups Projects
Commit 142eead6 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Add an option to ignore turns in text segmentation

parent 1e1002f5
Branches
Tags
1 merge request: !46 "Merge COMBO 3.0 into master"
...@@ -34,10 +34,11 @@ class LamboTokenizer(Tokenizer): ...@@ -34,10 +34,11 @@ class LamboTokenizer(Tokenizer):
return tokens return tokens
def segment(self, text: str) -> List[List[str]]: def segment(self, text: str, turns: bool = False) -> List[List[str]]:
""" """
Full segmentation - segment into sentences Full segmentation - segment into sentences.
:param text: :param text:
:param turns: segment into sentences by splitting on sentences or on turns. Default: sentences.
:return: :return:
""" """
...@@ -46,13 +47,19 @@ class LamboTokenizer(Tokenizer): ...@@ -46,13 +47,19 @@ class LamboTokenizer(Tokenizer):
sentence_tokens = [] sentence_tokens = []
for turn in document.turns: for turn in document.turns:
for sentence in turn.sentences: if turns:
sentence_tokens = [] sentence_tokens = []
for sentence in turn.sentences:
if not turns:
sentence_tokens = []
for token in sentence.tokens: for token in sentence.tokens:
if len(token.subwords) > 0: if len(token.subwords) > 0:
sentence_tokens.extend([s for s in token.subwords]) sentence_tokens.extend([s for s in token.subwords])
else: else:
sentence_tokens.append(token.text) sentence_tokens.append(token.text)
if not turns:
sentences.append(sentence_tokens)
if turns:
sentences.append(sentence_tokens) sentences.append(sentence_tokens)
return sentences return sentences
...@@ -78,10 +78,10 @@ flags.DEFINE_string(name="tensorboard_name", default="combo", ...@@ -78,10 +78,10 @@ flags.DEFINE_string(name="tensorboard_name", default="combo",
help="Name of the model in TensorBoard logs.") help="Name of the model in TensorBoard logs.")
flags.DEFINE_string(name="config_path", default="", flags.DEFINE_string(name="config_path", default="",
help="Config file path.") help="Config file path.")
flags.DEFINE_boolean(name="save_matrices", default=True,
help="Save relation distribution matrices.")
flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"], flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"],
help="") help="")
flags.DEFINE_boolean(name="turns", default=False,
help="Segment into sentences on sentence break or on turn break.")
# Finetune after training flags # Finetune after training flags
flags.DEFINE_string(name="finetuning_training_data_path", default="", flags.DEFINE_string(name="finetuning_training_data_path", default="",
...@@ -107,6 +107,8 @@ flags.DEFINE_boolean(name="finetuning", default=False, ...@@ -107,6 +107,8 @@ flags.DEFINE_boolean(name="finetuning", default=False,
help="Finetuning mode for training.") help="Finetuning mode for training.")
flags.DEFINE_string(name="tokenizer_language", default="English", flags.DEFINE_string(name="tokenizer_language", default="English",
help="Tokenizer language.") help="Tokenizer language.")
flags.DEFINE_boolean(name="save_matrices", default=True,
help="Save relation distribution matrices.")
flags.DEFINE_enum(name="predictor_name", default="combo-lambo", flags.DEFINE_enum(name="predictor_name", default="combo-lambo",
enum_values=["combo", "combo-spacy", "combo-lambo"], enum_values=["combo", "combo-spacy", "combo-lambo"],
help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.") help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.")
...@@ -434,7 +436,7 @@ def run(_): ...@@ -434,7 +436,7 @@ def run(_):
else: else:
tokenizer = LamboTokenizer(tokenizer_language) tokenizer = LamboTokenizer(tokenizer_language)
with open(FLAGS.input_file, "r", encoding='utf-8') as file: with open(FLAGS.input_file, "r", encoding='utf-8') as file:
input_sentences = tokenizer.segment(file.read()) input_sentences = tokenizer.segment(file.read(), turns=FLAGS.turns)
predictions = predictor.predict(input_sentences) predictions = predictor.predict(input_sentences)
with open(FLAGS.output_file, "w") as file: with open(FLAGS.output_file, "w") as file:
for prediction in tqdm(predictions): for prediction in tqdm(predictions):
......
...@@ -13,6 +13,14 @@ class LamboTokenizerTest(unittest.TestCase): ...@@ -13,6 +13,14 @@ class LamboTokenizerTest(unittest.TestCase):
self.assertListEqual([t.text for t in tokens], self.assertListEqual([t.text for t in tokens],
['Hello', 'cats', '.', 'I', 'love', 'you']) ['Hello', 'cats', '.', 'I', 'love', 'you'])
def test_segment_text(self):
    """Default segmentation splits the text on sentence boundaries."""
    expected = [['Hello', 'cats', '.'], ['I', 'love', 'you', '.'], ['Hi', '.']]
    segmented = self.lambo_tokenizer.segment('Hello cats. I love you.\n\nHi.')
    self.assertListEqual(segmented, expected)
def test_segment_text_with_turns(self):
    """With turns=True, sentences within one turn are merged into a single token list."""
    expected = [['Hello', 'cats', '.', 'I', 'love', 'you', '.'], ['Hi', '.']]
    segmented = self.lambo_tokenizer.segment('Hello cats. I love you.\n\nHi.', turns=True)
    self.assertListEqual(segmented, expected)
def test_tokenize_sentence_with_multiword(self): def test_tokenize_sentence_with_multiword(self):
tokens = self.lambo_tokenizer.tokenize('I don\'t like apples.') tokens = self.lambo_tokenizer.tokenize('I don\'t like apples.')
self.assertListEqual([t.text for t in tokens], self.assertListEqual([t.text for t in tokens],
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment