From c5031d73779c261389ece396280b2482c9e65b30 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Wed, 22 Nov 2023 01:24:25 +1100
Subject: [PATCH] Switched id and text in token list

---
 combo/data/api.py | 48 +++++++++++++++++++++++++++++++++++++++++++++++
 combo/main.py     | 11 ++++-------
 2 files changed, 52 insertions(+), 7 deletions(-)

diff --git a/combo/data/api.py b/combo/data/api.py
index 0f407cd..d49a531 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -55,6 +55,54 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.mode
     return _TokenList(tokens=tokens,
                       metadata=sentence.metadata if sentence.metadata is None else Metadata())
 
+def serialize_field(field: Any) -> str:
+    if field is None:
+        return '_'
+
+    if isinstance(field, dict):
+        if field == {}:
+            return '_'
+
+        fields = []
+        for key, value in field.items():
+            if value is None:
+                value = "_"
+            if value == "":
+                fields.append(key)
+                continue
+
+            fields.append('='.join((key, value)))
+
+        return '|'.join(fields)
+
+    if isinstance(field, tuple):
+        return "".join([serialize_field(item) for item in field])
+
+    if isinstance(field, list):
+        if len(field[0]) != 2:
+            raise ValueError("Can't serialize '{}', invalid format".format(field))
+        return "|".join([serialize_field(value) + ":" + str(key) for key, value in field])
+
+    return "{}".format(field)
+
+def serialize_token_list(tokenlist: conllu.models.TokenList) -> str:
+    KEYS_ORDER = ['idx', 'text', 'lemma', 'upostag', 'xpostag',
+                  'entity_type', 'feats', 'head', 'deprel', 'deps', 'misc']
+    lines = []
+
+    if tokenlist.metadata:
+        for key, value in tokenlist.metadata.items():
+            if value:
+                line = f"# {key} = {value}"
+            else:
+                line = f"# {key}"
+            lines.append(line)
+
+    for token_data in tokenlist:
+        line = '\t'.join(serialize_field(token_data[k]) for k in KEYS_ORDER)
+        lines.append(line)
+
+    return '\n'.join(lines) + "\n\n"
 
 def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList:
     return _TokenList(
diff --git a/combo/main.py b/combo/main.py
index a76db04..9eb0a03 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -27,7 +27,6 @@ from combo.modules.model import Model
 from combo.utils import ConfigurationError
 from combo.utils.matrices import extract_combo_matrices
 
-import codecs
 
 logging.setLoggerClass(ComboLogger)
 logger = logging.getLogger(__name__)
@@ -103,8 +102,6 @@ flags.DEFINE_string(name="input_file", default=None,
                     help="File to predict path")
 flags.DEFINE_boolean(name="conllu_format", default=True,
                      help="Prediction based on conllu format (instead of raw text).")
-flags.DEFINE_boolean(name="silent", default=True,
-                     help="Silent prediction to file (without printing to console).")
 flags.DEFINE_boolean(name="finetuning", default=False,
                      help="Finetuning mode for training.")
 flags.DEFINE_string(name="tokenizer_language", default="English",
@@ -410,8 +407,8 @@ def run(_):
                 with open(FLAGS.output_file, "w") as file:
                     for tree in tqdm(test_trees):
                         prediction = predictor.predict_instance(tree)
-                        file.writelines(api.sentence2conllu(prediction,
-                                                            keep_semrel=dataset_reader.use_sem).serialize())
+                        file.writelines(api.serialize_token_list(api.sentence2conllu(prediction,
+                                                                                     keep_semrel=dataset_reader.use_sem)))
                         predictions.append(prediction)
 
             else:
@@ -421,8 +418,8 @@ def run(_):
                 predictions = predictor.predict(input_sentences)
                 with open(FLAGS.output_file, "w") as file:
                     for prediction in tqdm(predictions):
-                        file.writelines(api.sentence2conllu(prediction,
-                                                            keep_semrel=dataset_reader.use_sem).serialize())
+                        file.writelines(api.serialize_token_list(api.sentence2conllu(prediction,
+                                                                                     keep_semrel=dataset_reader.use_sem)))
 
             if FLAGS.save_matrices:
                 logger.info("Saving matrices", prefix=prefix)
-- 
GitLab