diff --git a/combo/data/api.py b/combo/data/api.py
index 0f407cdc7f3b3a71e88c96f12a0b45d175ff2c1d..d49a531a27d78dd50357fc10581ded59f6d45d5b 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -55,6 +55,54 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.mode
     return _TokenList(tokens=tokens,
                       metadata=sentence.metadata if sentence.metadata is None else Metadata())
 
 
+def serialize_field(field: Any) -> str:
+    if field is None:
+        return '_'
+
+    if isinstance(field, dict):
+        if field == {}:
+            return '_'
+
+        fields = []
+        for key, value in field.items():
+            if value is None:
+                value = "_"
+            if value == "":
+                fields.append(key)
+                continue
+
+            fields.append('='.join((key, value)))
+
+        return '|'.join(fields)
+
+    if isinstance(field, tuple):
+        return "".join([serialize_field(item) for item in field])
+
+    if isinstance(field, list):
+        if len(field[0]) != 2:
+            raise ValueError("Can't serialize '{}', invalid format".format(field))
+        return "|".join([serialize_field(value) + ":" + str(key) for key, value in field])
+
+    return "{}".format(field)
+
+def serialize_token_list(tokenlist: conllu.models.TokenList) -> str:
+    KEYS_ORDER = ['idx', 'text', 'lemma', 'upostag', 'xpostag',
+                  'entity_type', 'feats', 'head', 'deprel', 'deps', 'misc']
+    lines = []
+
+    if tokenlist.metadata:
+        for key, value in tokenlist.metadata.items():
+            if value:
+                line = f"# {key} = {value}"
+            else:
+                line = f"# {key}"
+            lines.append(line)
+
+    for token_data in tokenlist:
+        line = '\t'.join(serialize_field(token_data[k]) for k in KEYS_ORDER)
+        lines.append(line)
+
+    return '\n'.join(lines) + "\n\n"
 def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList:
     return _TokenList(
diff --git a/combo/main.py b/combo/main.py
index a76db04fe90caa40020ddc616282a10a2c2038f6..9eb0a03469f7b9d344cd2797cca8fc8874b259e9 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -27,7 +27,6 @@ from combo.modules.model import Model
 from combo.utils import ConfigurationError
 from combo.utils.matrices import extract_combo_matrices
-import codecs
 
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
 
@@ -103,8 +102,6 @@ flags.DEFINE_string(name="input_file", default=None,
                     help="File to predict path")
 flags.DEFINE_boolean(name="conllu_format", default=True,
                      help="Prediction based on conllu format (instead of raw text).")
-flags.DEFINE_boolean(name="silent", default=True,
-                     help="Silent prediction to file (without printing to console).")
 flags.DEFINE_boolean(name="finetuning", default=False,
                      help="Finetuning mode for training.")
 flags.DEFINE_string(name="tokenizer_language", default="English",
@@ -410,8 +407,8 @@ def run(_):
         with open(FLAGS.output_file, "w") as file:
             for tree in tqdm(test_trees):
                 prediction = predictor.predict_instance(tree)
-                file.writelines(api.sentence2conllu(prediction,
-                                                    keep_semrel=dataset_reader.use_sem).serialize())
+                file.writelines(api.serialize_token_list(api.sentence2conllu(prediction,
+                                                                             keep_semrel=dataset_reader.use_sem)))
                 predictions.append(prediction)
 
     else:
@@ -421,8 +418,8 @@ def run(_):
             predictions = predictor.predict(input_sentences)
         with open(FLAGS.output_file, "w") as file:
             for prediction in tqdm(predictions):
-                file.writelines(api.sentence2conllu(prediction,
-                                                    keep_semrel=dataset_reader.use_sem).serialize())
+                file.writelines(api.serialize_token_list(api.sentence2conllu(prediction,
+                                                                             keep_semrel=dataset_reader.use_sem)))
 
     if FLAGS.save_matrices:
         logger.info("Saving matrices", prefix=prefix)