From 142eead6f6549f06ae8aaf7f09efdf50417220e0 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Mon, 27 Nov 2023 22:17:13 +1100
Subject: [PATCH] Add an option to ignore turns in text segmentation

---
 combo/data/tokenizers/lambo_tokenizer.py      | 13 ++++++++++---
 combo/main.py                                 |  8 +++++---
 tests/data/tokenizers/test_lambo_tokenizer.py |  8 ++++++++
 3 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index 8d4e4e6..c11b605 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -34,10 +34,11 @@ class LamboTokenizer(Tokenizer):
 
         return tokens
 
-    def segment(self, text: str) -> List[List[str]]:
+    def segment(self, text: str, turns: bool = False) -> List[List[str]]:
         """
-        Full segmentation - segment into sentences
+        Full segmentation - segment into sentences.
         :param text:
+        :param turns: if True, return one token list per turn instead of per sentence. Default: False (sentences).
         :return:
         """
 
@@ -46,13 +47,19 @@ class LamboTokenizer(Tokenizer):
         sentence_tokens = []
 
         for turn in document.turns:
-            for sentence in turn.sentences:
+            if turns:
                 sentence_tokens = []
+            for sentence in turn.sentences:
+                if not turns:
+                    sentence_tokens = []
                 for token in sentence.tokens:
                     if len(token.subwords) > 0:
                         sentence_tokens.extend([s for s in token.subwords])
                     else:
                         sentence_tokens.append(token.text)
+                if not turns:
+                    sentences.append(sentence_tokens)
+            if turns:
                 sentences.append(sentence_tokens)
 
         return sentences
diff --git a/combo/main.py b/combo/main.py
index 1e28106..5fe926f 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -78,10 +78,10 @@ flags.DEFINE_string(name="tensorboard_name", default="combo",
                     help="Name of the model in TensorBoard logs.")
 flags.DEFINE_string(name="config_path", default="",
                     help="Config file path.")
-flags.DEFINE_boolean(name="save_matrices", default=True,
-                     help="Save relation distribution matrices.")
 flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"],
                   help="")
+flags.DEFINE_boolean(name="turns", default=False,
+                     help="If true, segment the input on turn breaks instead of sentence breaks.")
 
 # Finetune after training flags
 flags.DEFINE_string(name="finetuning_training_data_path", default="",
@@ -107,6 +107,8 @@ flags.DEFINE_boolean(name="finetuning", default=False,
                      help="Finetuning mode for training.")
 flags.DEFINE_string(name="tokenizer_language", default="English",
                     help="Tokenizer language.")
+flags.DEFINE_boolean(name="save_matrices", default=True,
+                     help="Save relation distribution matrices.")
 flags.DEFINE_enum(name="predictor_name", default="combo-lambo",
                   enum_values=["combo", "combo-spacy", "combo-lambo"],
                   help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.")
@@ -434,7 +436,7 @@ def run(_):
     else:
         tokenizer = LamboTokenizer(tokenizer_language)
         with open(FLAGS.input_file, "r", encoding='utf-8') as file:
-            input_sentences = tokenizer.segment(file.read())
+            input_sentences = tokenizer.segment(file.read(), turns=FLAGS.turns)
         predictions = predictor.predict(input_sentences)
         with open(FLAGS.output_file, "w") as file:
             for prediction in tqdm(predictions):
diff --git a/tests/data/tokenizers/test_lambo_tokenizer.py b/tests/data/tokenizers/test_lambo_tokenizer.py
index 7c49cc5..2945a82 100644
--- a/tests/data/tokenizers/test_lambo_tokenizer.py
+++ b/tests/data/tokenizers/test_lambo_tokenizer.py
@@ -13,6 +13,14 @@ class LamboTokenizerTest(unittest.TestCase):
         self.assertListEqual([t.text for t in tokens],
                              ['Hello', 'cats', '.', 'I', 'love', 'you'])
 
+    def test_segment_text(self):
+        tokens = self.lambo_tokenizer.segment('Hello cats. I love you.\n\nHi.')
+        self.assertListEqual(tokens, [['Hello', 'cats', '.'], ['I', 'love', 'you', '.'], ['Hi', '.']])
+
+    def test_segment_text_with_turns(self):
+        tokens = self.lambo_tokenizer.segment('Hello cats. I love you.\n\nHi.', turns=True)
+        self.assertListEqual(tokens, [['Hello', 'cats', '.', 'I', 'love', 'you', '.'], ['Hi', '.']])
+
     def test_tokenize_sentence_with_multiword(self):
         tokens = self.lambo_tokenizer.tokenize('I don\'t like apples.')
         self.assertListEqual([t.text for t in tokens],
--
GitLab
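
For reference, a minimal usage sketch of the new option (not part of the patch itself). The import path is inferred from the file location above, the expected outputs are taken from the added tests, and the "English" LAMBO model is assumed to be available locally:

    from combo.data.tokenizers.lambo_tokenizer import LamboTokenizer

    tokenizer = LamboTokenizer("English")
    text = "Hello cats. I love you.\n\nHi."

    # Default behaviour: one token list per sentence.
    tokenizer.segment(text)
    # [['Hello', 'cats', '.'], ['I', 'love', 'you', '.'], ['Hi', '.']]

    # turns=True: one token list per turn (sentences within a turn are merged).
    tokenizer.segment(text, turns=True)
    # [['Hello', 'cats', '.', 'I', 'love', 'you', '.'], ['Hi', '.']]

In prediction mode the same behaviour is selected with the new --turns flag, which is passed through to tokenizer.segment() in combo/main.py.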