From 02df983498ad9ff444bba5ebf529596fcddfe174 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Mon, 4 Mar 2024 23:46:27 +1100
Subject: [PATCH] Fix reading UD treebanks with "_" values

---
 .../universal_dependencies_dataset_reader.py         |  7 ++++---
 combo/main.py                                        | 12 ++++++------
 pyproject.toml                                       |  2 +-
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
index 22960a2..8ddfa78 100644
--- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
+++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py
@@ -130,7 +130,6 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
                 for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers):
                     yield self.text_to_instance([Token.from_conllu_token(t) for t in annotation if isinstance(t.get("id"), int)])
 
-
     def text_to_instance(self, tree: List[Token]) -> Instance:
         fields_: Dict[str, Field] = {}
 
@@ -193,8 +192,10 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC):
                         # TODO: co robic gdy nie ma xpostag?
                         if all([t is None for t in target_values]):
                             continue
-                        fields_[target_name] = SequenceLabelField(target_values, text_field,
-                                                                  label_namespace=target_name + "_labels")
+                        else:
+                            target_values = [t or '_' for t in target_values]
+                            fields_[target_name] = SequenceLabelField(target_values, text_field,
+                                                                      label_namespace=target_name + "_labels")
 
         # Restore feats fields to string representation
         # parser.serialize_field doesn't handle key without value
diff --git a/combo/main.py b/combo/main.py
index b813074..1fefe6c 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -87,7 +87,7 @@ flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"],
                   help="")
 flags.DEFINE_enum(name="split_level", default="sentence", enum_values=["none", "turn", "sentence"],
                   help="Don\'t segment, or segment into sentences on sentence break or on turn break.")
-flags.DEFINE_boolean(name="split_multiwords", default=False,
+flags.DEFINE_boolean(name="split_multiwords", default=True,
                      help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.")
 flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.")
 
@@ -162,7 +162,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
         # Dataset reader is required to read training data and/or for training (and validation) data loader
         dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name,
                                                    tokenizer=LamboTokenizer(FLAGS.tokenizer_language,
-                                                    default_split_level="TURNS" if FLAGS.turns else "SENTENCES",
+                                                    default_split_level=FLAGS.split_level.upper(),
                                                     default_split_multiwords=FLAGS.split_multiwords)
                                                    )
 
@@ -459,8 +459,8 @@ def run(_):
                 with open(FLAGS.output_file, "w") as file:
                     for tree in tqdm(test_trees):
                         prediction = predictor.predict_instance(tree)
-                        file.writelines(api.serialize_token_list(api.sentence2conllu(prediction,
-                                                                                     keep_semrel=dataset_reader.use_sem)))
+                        file.writelines(api.sentence2conllu(prediction,
+                                                            keep_semrel=dataset_reader.use_sem).serialize())
                         predictions.append(prediction)
 
             else:
@@ -472,8 +472,8 @@ def run(_):
                 predictions = predictor.predict(input_sentences)
                 with open(FLAGS.output_file, "w") as file:
                     for prediction in tqdm(predictions):
-                        file.writelines(api.serialize_token_list(api.sentence2conllu(prediction,
-                                                                                     keep_semrel=dataset_reader.use_sem)))
+                        file.writelines(api.sentence2conllu(prediction,
+                                                            keep_semrel=dataset_reader.use_sem).serialize())
 
             if FLAGS.save_matrices:
                 logger.info("Saving matrices", prefix=prefix)
diff --git a/pyproject.toml b/pyproject.toml
index da51df8..c17c0df 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools"]
 
 [project]
 name = "combo"
-version = "3.2.0"
+version = "3.2.1"
 authors = [
     {name = "Maja Jablonska", email = "maja.jablonska@ipipan.waw.pl"}
 ]
-- 
GitLab