From 02df983498ad9ff444bba5ebf529596fcddfe174 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Mon, 4 Mar 2024 23:46:27 +1100 Subject: [PATCH] Fix reading UD treebanks with "_" values --- .../universal_dependencies_dataset_reader.py | 7 ++++--- combo/main.py | 12 ++++++------ pyproject.toml | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py index 22960a2..8ddfa78 100644 --- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py +++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py @@ -130,7 +130,6 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC): for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers): yield self.text_to_instance([Token.from_conllu_token(t) for t in annotation if isinstance(t.get("id"), int)]) - def text_to_instance(self, tree: List[Token]) -> Instance: fields_: Dict[str, Field] = {} @@ -193,8 +192,10 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC): # TODO: co robic gdy nie ma xpostag? if all([t is None for t in target_values]): continue - fields_[target_name] = SequenceLabelField(target_values, text_field, - label_namespace=target_name + "_labels") + else: + target_values = [t or '_' for t in target_values] + fields_[target_name] = SequenceLabelField(target_values, text_field, + label_namespace=target_name + "_labels") # Restore feats fields to string representation # parser.serialize_field doesn't handle key without value diff --git a/combo/main.py b/combo/main.py index b813074..1fefe6c 100755 --- a/combo/main.py +++ b/combo/main.py @@ -87,7 +87,7 @@ flags.DEFINE_list(name="datasets_for_vocabulary", default=["train"], help="") flags.DEFINE_enum(name="split_level", default="sentence", enum_values=["none", "turn", "sentence"], help="Don\'t segment, or segment into sentences on sentence break or on turn break.") -flags.DEFINE_boolean(name="split_multiwords", default=False, +flags.DEFINE_boolean(name="split_multiwords", default=True, help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.") flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.") @@ -162,7 +162,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader], # Dataset reader is required to read training data and/or for training (and validation) data loader dataset_reader = default_ud_dataset_reader(FLAGS.pretrained_transformer_name, tokenizer=LamboTokenizer(FLAGS.tokenizer_language, - default_split_level="TURNS" if FLAGS.turns else "SENTENCES", + default_split_level=FLAGS.split_level.upper(), default_split_multiwords=FLAGS.split_multiwords) ) @@ -459,8 +459,8 @@ def run(_): with open(FLAGS.output_file, "w") as file: for tree in tqdm(test_trees): prediction = predictor.predict_instance(tree) - file.writelines(api.serialize_token_list(api.sentence2conllu(prediction, - keep_semrel=dataset_reader.use_sem))) + file.writelines(api.sentence2conllu(prediction, + keep_semrel=dataset_reader.use_sem).serialize()) predictions.append(prediction) else: @@ -472,8 +472,8 @@ def run(_): predictions = predictor.predict(input_sentences) with open(FLAGS.output_file, "w") as file: for prediction in tqdm(predictions): - file.writelines(api.serialize_token_list(api.sentence2conllu(prediction, - keep_semrel=dataset_reader.use_sem))) + file.writelines(api.sentence2conllu(prediction, + keep_semrel=dataset_reader.use_sem).serialize()) if FLAGS.save_matrices: logger.info("Saving matrices", prefix=prefix) diff --git a/pyproject.toml b/pyproject.toml index da51df8..c17c0df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools"] [project] name = "combo" -version = "3.2.0" +version = "3.2.1" authors = [ {name = "Maja Jablonska", email = "maja.jablonska@ipipan.waw.pl"} ] -- GitLab