# Provenance: helpers added to combo/data/api.py by the patch
# "Switched id and text in token list"
# (commit c5031d73779c261389ece396280b2482c9e65b30, Maja Jablonska, 2023-11-22).
# In that file they sit between sentence2conllu() and tokens2conllu().


def serialize_field(field: Any) -> str:
    """Render a single token field as the text of one CoNLL-U column.

    Serialization rules, by type of ``field``:

    * ``None``, an empty dict or an empty list -> ``"_"``.
    * dict -> ``key=value`` items joined with ``"|"``; a ``None`` value is
      written as ``_`` and an empty-string value emits the bare key.
    * tuple -> every item serialized recursively and concatenated with no
      separator.
    * list -> a sequence of two-item entries ``(first, second)``, each written
      as ``second:first`` and joined with ``"|"``.
    * anything else -> its ``str.format`` representation.

    Raises:
        ValueError: if ``field`` is a list containing an entry that is not a
            two-item sequence.
    """
    if field is None:
        return "_"

    if isinstance(field, dict):
        if not field:
            return "_"

        parts = []
        for key, value in field.items():
            if value == "":
                # Valueless feature: only the key is written.
                parts.append(str(key))
                continue
            # Formatting (rather than str.join) also accepts non-string values.
            parts.append("{}={}".format(key, "_" if value is None else value))
        return "|".join(parts)

    if isinstance(field, tuple):
        return "".join(serialize_field(item) for item in field)

    if isinstance(field, list):
        if not field:
            # Nothing to serialize; avoid indexing into an empty list.
            return "_"

        try:
            malformed = any(len(entry) != 2 for entry in field)
        except TypeError:
            # An entry without a length cannot be a (first, second) pair.
            malformed = True
        if malformed:
            raise ValueError("Can't serialize '{}', invalid format".format(field))

        return "|".join(
            serialize_field(second) + ":" + str(first) for first, second in field
        )

    return "{}".format(field)


def serialize_token_list(tokenlist: "conllu.models.TokenList") -> str:
    """Serialize a token list into one tab-separated text block.

    Each truthy metadata entry is written as ``# key = value`` and each falsy
    one as ``# key``. Every token then becomes one line whose columns follow
    ``KEYS_ORDER`` and are rendered with :func:`serialize_field`. The result
    always ends with a blank line (``"\\n\\n"``).

    Args:
        tokenlist: iterable of mapping-like tokens exposing a ``metadata``
            attribute (a mapping, or a falsy value when there is none).

    Returns:
        The serialized block as a single string.

    Note:
        Tokens are indexed with every key in ``KEYS_ORDER``; a token lacking
        one of them propagates the lookup error raised by the token object.
    """
    # Column order of the emitted lines: 'idx' precedes 'text'.
    KEYS_ORDER = ['idx', 'text', 'lemma', 'upostag', 'xpostag',
                  'entity_type', 'feats', 'head', 'deprel', 'deps', 'misc']
    lines = []

    metadata = tokenlist.metadata
    if metadata:
        for key, value in metadata.items():
            lines.append(f"# {key} = {value}" if value else f"# {key}")

    for token_data in tokenlist:
        lines.append('\t'.join(serialize_field(token_data[k]) for k in KEYS_ORDER))

    return '\n'.join(lines) + "\n\n"
diff --git a/combo/main.py b/combo/main.py index a76db04..9eb0a03 100755 --- a/combo/main.py +++ b/combo/main.py @@ -27,7 +27,6 @@ from combo.modules.model import Model from combo.utils import ConfigurationError from combo.utils.matrices import extract_combo_matrices -import codecs logging.setLoggerClass(ComboLogger) logger = logging.getLogger(__name__) @@ -103,8 +102,6 @@ flags.DEFINE_string(name="input_file", default=None, help="File to predict path") flags.DEFINE_boolean(name="conllu_format", default=True, help="Prediction based on conllu format (instead of raw text).") -flags.DEFINE_boolean(name="silent", default=True, - help="Silent prediction to file (without printing to console).") flags.DEFINE_boolean(name="finetuning", default=False, help="Finetuning mode for training.") flags.DEFINE_string(name="tokenizer_language", default="English", @@ -410,8 +407,8 @@ def run(_): with open(FLAGS.output_file, "w") as file: for tree in tqdm(test_trees): prediction = predictor.predict_instance(tree) - file.writelines(api.sentence2conllu(prediction, - keep_semrel=dataset_reader.use_sem).serialize()) + file.writelines(api.serialize_token_list(api.sentence2conllu(prediction, + keep_semrel=dataset_reader.use_sem))) predictions.append(prediction) else: @@ -421,8 +418,8 @@ def run(_): predictions = predictor.predict(input_sentences) with open(FLAGS.output_file, "w") as file: for prediction in tqdm(predictions): - file.writelines(api.sentence2conllu(prediction, - keep_semrel=dataset_reader.use_sem).serialize()) + file.writelines(api.serialize_token_list(api.sentence2conllu(prediction, + keep_semrel=dataset_reader.use_sem))) if FLAGS.save_matrices: logger.info("Saving matrices", prefix=prefix) -- GitLab