diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
index d027d62632027bae0115c5882c49319a63c4eeca..7c591243ec19b0f27c3344cd748d17e3b9aa50f6 100644
--- a/combo/data/token_indexers/token_features_indexer.py
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -52,7 +52,11 @@ class TokenFeatsIndexer(data.TokenIndexer):
             if feat in ["_", "__ROOT__"]:
                 pass
             else:
-                features.append(feat + "=" + value)
+                # Handle case where feature is binary (doesn't have an associated value)
+                if value:
+                    features.append(feat + "=" + value)
+                else:
+                    features.append(feat)
         return features
 
     @overrides
diff --git a/combo/main.py b/combo/main.py
index 366fa93fe7ac94ee74f076987d159ba63b8af0d2..e2e29d437ad31e24c69fe6c95c60d29a8491e085 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -56,10 +56,10 @@ flags.DEFINE_boolean(name="tensorboard", default=False,
                      help="When provided model will log tensorboard metrics.")
 
 # Finetune after training flags
-flags.DEFINE_string(name="finetuning_training_data_path", default="",
-                    help="Training data path")
-flags.DEFINE_string(name="finetuning_validation_data_path", default="",
-                    help="Validation data path")
+flags.DEFINE_list(name="finetuning_training_data_path", default="",
+                  help="Training data path(s)")
+flags.DEFINE_list(name="finetuning_validation_data_path", default="",
+                  help="Validation data path(s)")
 flags.DEFINE_string(name="config_path", default="config.template.jsonnet",
                     help="Config file path.")
 
@@ -101,7 +101,8 @@ def run(_):
         logger.info(f"Training model stored in: {serialization_dir}")
 
         if FLAGS.finetuning_training_data_path:
-            checks.file_exists(FLAGS.finetuning_training_data_path)
+            for f in FLAGS.finetuning_training_data_path:
+                checks.file_exists(f)
 
             # Loading will be performed from stored model.tar.gz
             del model
@@ -171,9 +172,9 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         return {}
     return {
         "training_data_path": (
-            ":".join(FLAGS.training_data_path) if not finetuning else FLAGS.finetuning_training_data_path),
+            ",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
         "validation_data_path": (
-            ":".join(FLAGS.validation_data_path) if not finetuning else FLAGS.finetuning_validation_data_path),
+            ",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
         "pretrained_tokens": FLAGS.pretrained_tokens,
         "pretrained_transformer_name": FLAGS.pretrained_transformer_name,
         "features": " ".join(FLAGS.features),
diff --git a/combo/models/parser.py b/combo/models/parser.py
index de71781790d72fcfec6f1c2bd1c4e607d676302b..70330890f9050b04197343e6eef2cce9da5282f2 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -42,7 +42,11 @@ class HeadPredictionModel(base.Predictor):
         lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1
         for idx, length in enumerate(lengths):
            probs = x[idx, :].softmax(dim=-1).cpu().numpy()
-            probs[:, 0] = 0
+
+            # We do not want any word to be the parent of the root node (ROOT, 0).
+            # Setting the column to -1 instead of 0 also fixes an edge case where softmax makes
+            # every prediction except ROOT exactly 0.0, which could result in many ROOT -> word edges.
+            probs[:, 0] = -1
             heads, _ = chu_liu_edmonds.decode_mst(probs.T, length=length, has_labels=False)
             heads[0] = 0
             pred.append(heads)
diff --git a/config.template.jsonnet b/config.template.jsonnet
index 3bb0dad0c56bf1db897f89fc1f67c3a0ff24f8a8..1f47b38fc4fec65c3d4cafcbb83b77496ca99dca 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -58,6 +58,7 @@ local loss_weights = {
     feats: 0.2,
     deprel: 0.8,
     head: 0.2,
+    semrel: 0.05,
 };
 # Encoder hidden size, int
 local hidden_size = 512;
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
index 0cf30f04351db8904513b7363a6e8414b27cd347..b58f0f33ebe554d7febcc8019544f49a59cd58cc 100644
--- a/tests/fixtures/example.conllu
+++ b/tests/fixtures/example.conllu
@@ -1,5 +1,5 @@
 # sent_id = test-s1
 # text = Easy sentence.
-1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep	1	amod	_	_
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	1	amod	_	_
 2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
 3	.	.	PUNCT	.	_	1	punct	_	_
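Note on the token_features_indexer.py change: a FEATS column may contain a value-less (binary) feature, as in the updated fixture's AdpType=Prep|Adp where Adp appears without an =value part. The snippet below is a minimal standalone sketch of that branch; feats_to_strings is a made-up helper for illustration, not the TokenFeatsIndexer API.

# Hypothetical helper mirroring the new binary-feature branch.
def feats_to_strings(feats: dict) -> list:
    features = []
    for feat, value in feats.items():
        if feat in ["_", "__ROOT__"]:
            continue
        if value:                # regular Feature=Value pair
            features.append(feat + "=" + value)
        else:                    # binary feature without an associated value
            features.append(feat)
    return features

print(feats_to_strings({"AdpType": "Prep", "Adp": None}))  # ['AdpType=Prep', 'Adp']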
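Note on the main.py change: with absl's flags.DEFINE_list the finetuning paths are passed on the command line as one comma-separated value and arrive in Python as a list, which _get_ext_vars then joins back with ",". A minimal sketch of that round trip, with made-up file names:

# Made-up paths; flags.DEFINE_list would produce the same kind of list from "train_a.conllu,train_b.conllu".
finetuning_training_data_path = "train_a.conllu,train_b.conllu".split(",")
training_data_path_ext_var = ",".join(finetuning_training_data_path)
print(finetuning_training_data_path)  # ['train_a.conllu', 'train_b.conllu']
print(training_data_path_ext_var)     # train_a.conllu,train_b.conllu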
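Note on the parser.py change: the new comment refers to float32 underflow. When one head score dominates, softmax can return exactly 0.0 for every other entry, so a column masked with 0 no longer sits strictly below the real probabilities and the MST decoder can pick edges it should avoid; -1 restores a strict gap. A small numpy-only illustration with made-up logits:

import numpy as np

# Made-up scores for one token; a single candidate head dominates.
logits = np.array([0.0, 120.0, 0.0], dtype=np.float32)
probs = np.exp(logits - logits.max())
probs /= probs.sum()
print(probs)  # [0. 1. 0.] -- the non-maximal entries underflow to exactly 0.0 in float32
# A forbidden column masked with 0 would tie with these zeros; -1 keeps it strictly lowest.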