From fed57c5488d8abd06480a785c7b96fa6064a7cad Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Wed, 22 Nov 2023 01:12:56 +1100
Subject: [PATCH] Correct LAMBO segmentation

---
 combo/data/tokenizers/lambo_tokenizer.py |  1 -
 combo/data/tokenizers/tokenizer.py       |  8 ++++++++
 combo/main.py                            | 17 +++++++++++------
 combo/predict.py                         |  5 +++--
 4 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index f37b076..c5f4451 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -49,4 +49,3 @@ class LamboTokenizer(Tokenizer):
                 sentences.append([t.text for t in sentence.tokens])
 
         return sentences
-
diff --git a/combo/data/tokenizers/tokenizer.py b/combo/data/tokenizers/tokenizer.py
index 7b9269c..163036e 100644
--- a/combo/data/tokenizers/tokenizer.py
+++ b/combo/data/tokenizers/tokenizer.py
@@ -72,3 +72,11 @@ class Tokenizer(FromParameters):
         Returns the number of special tokens added for a pair of sequences.
         """
         return 0
+
+    def segment(self, text: str) -> List[List[str]]:
+        """
+        Full segmentation - segment into sentences
+        :param text:
+        :return:
+        """
+        return [[]]
diff --git a/combo/main.py b/combo/main.py
index dd4bdd2..a76db04 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -27,6 +27,8 @@ from combo.modules.model import Model
 from combo.utils import ConfigurationError
 from combo.utils.matrices import extract_combo_matrices
 
+import codecs
+
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
 _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
@@ -383,9 +385,13 @@ def run(_):
         if FLAGS.input_file == '-':
             print("Interactive mode.")
             sentence = input("Sentence: ")
-            prediction = predictor(sentence)
+            prediction = [p.tokens for p in predictor(sentence)]
+            # Flatten the prediction
+            flattened_prediction = []
+            for p in prediction:
+                flattened_prediction.extend(p)
             print("{:15} {:15} {:10} {:10} {:10}".format('TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))
-            for token in prediction.tokens:
+            for token in flattened_prediction:
                 print("{:15} {:15} {:10} {:10} {:10}".format(token.text, token.lemma, token.upostag,
                                                              token.head, token.deprel))
         elif FLAGS.output_file:
@@ -410,14 +416,13 @@ def run(_):
             else:
                 tokenizer = LamboTokenizer(tokenizer_language)
 
-            with open(FLAGS.input_file, "r") as file:
+            with open(FLAGS.input_file, "r", encoding='utf-8') as file:
                 input_sentences = tokenizer.segment(file.read())
+            predictions = predictor.predict(input_sentences)
             with open(FLAGS.output_file, "w") as file:
-                for sentence in tqdm(input_sentences):
-                    prediction = predictor.predict(' '.join(sentence))
+                for prediction in tqdm(predictions):
                     file.writelines(api.sentence2conllu(prediction,
                                                         keep_semrel=dataset_reader.use_sem).serialize())
-                    predictions.append(prediction)
 
         if FLAGS.save_matrices:
             logger.info("Saving matrices", prefix=prefix)
diff --git a/combo/predict.py b/combo/predict.py
index 4450d2f..8363e50 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -74,8 +74,9 @@ class COMBO(PredictorModule):
 
     def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
-            return self.predict_json({"sentence": sentence})
-        elif isinstance(sentence, list):
+            sentence = self.dataset_reader.tokenizer.segment(sentence)
+
+        if isinstance(sentence, list):
             if len(sentence) == 0:
                 return []
             example = sentence[0]
--
GitLab
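
Note (not part of the patch): the new Tokenizer.segment() hook added in tokenizer.py establishes the contract the rest of this change relies on -- a full segmentation returns a list of sentences, each a list of token strings. The toy subclass below is a hypothetical sketch of that contract only; it assumes Tokenizer can be subclassed without extra constructor arguments, and the naive splitting logic is illustrative, not how LamboTokenizer actually segments.

# Hypothetical sketch, not part of the patch: illustrates the List[List[str]]
# contract of the new Tokenizer.segment() hook. Assumes Tokenizer needs no
# constructor arguments; the splitting logic below is deliberately naive.
from typing import List

from combo.data.tokenizers.tokenizer import Tokenizer


class ToySentenceTokenizer(Tokenizer):
    def segment(self, text: str) -> List[List[str]]:
        # Split into sentences on '.', then into tokens on whitespace.
        sentences = [s.strip() for s in text.split(".") if s.strip()]
        return [sentence.split() for sentence in sentences]


print(ToySentenceTokenizer().segment("Ala ma kota. Kot ma Ale."))
# -> [['Ala', 'ma', 'kota'], ['Kot', 'ma', 'Ale']]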
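
Note (not part of the patch): a minimal sketch of the call pattern this change enables in combo/predict.py. Only the calls visible in the diff (predict(), the per-sentence .tokens, and the token fields printed by main.py) are taken from the code; the construction of `predictor` is assumed to have happened earlier, as main.py does before the patched blocks.

# Sketch only. Assumes `predictor` is an already constructed combo.predict.COMBO
# instance whose dataset reader uses a LamboTokenizer, as configured in main.py.
text = "This is the first sentence. And this is the second one."

# After this patch, a raw string is first segmented into sentences via
# self.dataset_reader.tokenizer.segment(...), so predict() returns a list
# with one prediction per sentence instead of a single prediction object.
predictions = predictor.predict(text)

for sentence_prediction in predictions:
    for token in sentence_prediction.tokens:
        print(token.text, token.lemma, token.upostag, token.head, token.deprel)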