Commit e3507eba authored by Mateusz Klimaszewski

Fix errors in semrel training, in mapping feats to features, and in handling of finetuning paths.

parent b189603d
@@ -52,7 +52,11 @@ class TokenFeatsIndexer(data.TokenIndexer):
             if feat in ["_", "__ROOT__"]:
                 pass
             else:
-                features.append(feat + "=" + value)
+                # Handle case where feature is binary (doesn't have an associated value)
+                if value:
+                    features.append(feat + "=" + value)
+                else:
+                    features.append(feat)
         return features

     @overrides
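In the new branch, binary Universal Dependencies features (those without a "=value" part) are kept as bare names instead of being glued to an empty value. A minimal standalone sketch of that behaviour, assuming UD-style "|" and "=" separators; the helper name is hypothetical and not code from the repository:

def split_feats(feats_string: str) -> list:
    """Turn a UD FEATS column value into a list of feature strings."""
    features = []
    if feats_string in ("_", "__ROOT__"):
        return features
    for feat in feats_string.split("|"):
        name, _, value = feat.partition("=")
        if value:
            # Regular feature with a value, e.g. "AdpType=Prep".
            features.append(name + "=" + value)
        else:
            # Binary feature without a value, e.g. "Adp" stays "Adp", not "Adp=".
            features.append(name)
    return features

assert split_feats("AdpType=Prep|Adp") == ["AdpType=Prep", "Adp"]
assert split_feats("_") == []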
@@ -56,10 +56,10 @@ flags.DEFINE_boolean(name="tensorboard", default=False,
                      help="When provided model will log tensorboard metrics.")
 # Finetune after training flags
-flags.DEFINE_string(name="finetuning_training_data_path", default="",
-                    help="Training data path")
-flags.DEFINE_string(name="finetuning_validation_data_path", default="",
-                    help="Validation data path")
+flags.DEFINE_list(name="finetuning_training_data_path", default="",
+                  help="Training data path(s)")
+flags.DEFINE_list(name="finetuning_validation_data_path", default="",
+                  help="Validation data path(s)")
 flags.DEFINE_string(name="config_path", default="config.template.jsonnet",
                     help="Config file path.")
@@ -101,7 +101,8 @@ def run(_):
     logger.info(f"Training model stored in: {serialization_dir}")
     if FLAGS.finetuning_training_data_path:
-        checks.file_exists(FLAGS.finetuning_training_data_path)
+        for f in FLAGS.finetuning_training_data_path:
+            checks.file_exists(f)
         # Loading will be performed from stored model.tar.gz
         del model
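Switching the finetuning flags to flags.DEFINE_list makes absl parse a comma-separated flag value into a Python list, which is why run() can now validate each path on its own. A small hedged sketch of the resulting usage; the script and file names are placeholders:

from absl import app, flags

FLAGS = flags.FLAGS

# Mirrors the changed definition: absl's DEFINE_list splits the value on commas.
flags.DEFINE_list(name="finetuning_training_data_path", default="",
                  help="Training data path(s)")


def main(_):
    # Invoked as: python demo.py --finetuning_training_data_path=a.conllu,b.conllu
    # the flag holds ["a.conllu", "b.conllu"], so each path can be checked
    # individually, as run() now does with checks.file_exists.
    for path in FLAGS.finetuning_training_data_path:
        print("would check:", path)


if __name__ == "__main__":
    app.run(main)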
@@ -171,9 +172,9 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         return {}
     return {
         "training_data_path": (
-            ":".join(FLAGS.training_data_path) if not finetuning else FLAGS.finetuning_training_data_path),
+            ",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)),
         "validation_data_path": (
-            ":".join(FLAGS.validation_data_path) if not finetuning else FLAGS.finetuning_validation_data_path),
+            ",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)),
         "pretrained_tokens": FLAGS.pretrained_tokens,
         "pretrained_transformer_name": FLAGS.pretrained_transformer_name,
         "features": " ".join(FLAGS.features),
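The updated expressions first select whichever list of paths applies (regular training or finetuning) and only then flatten it into a single comma-separated string for the jsonnet ext var. A toy illustration with placeholder paths:

# Placeholder values standing in for the absl flags.
training_data_path = ["train_a.conllu", "train_b.conllu"]
finetuning_training_data_path = ["ft_train.conllu"]


def training_ext_var(finetuning: bool) -> str:
    # Select the relevant list first, then join it, matching the updated
    # expression in _get_ext_vars.
    return ",".join(training_data_path if not finetuning
                    else finetuning_training_data_path)


assert training_ext_var(finetuning=False) == "train_a.conllu,train_b.conllu"
assert training_ext_var(finetuning=True) == "ft_train.conllu"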
@@ -42,7 +42,11 @@ class HeadPredictionModel(base.Predictor):
         lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1
         for idx, length in enumerate(lengths):
             probs = x[idx, :].softmax(dim=-1).cpu().numpy()
-            probs[:, 0] = 0
+            # We do not want any word to be the parent of the root node (ROOT, 0).
+            # Setting the column to -1 instead of 0 also fixes an edge case where softmax
+            # made every prediction except ROOT exactly 0.0, which could result in many
+            # ROOT -> word edges.
+            probs[:, 0] = -1
             heads, _ = chu_liu_edmonds.decode_mst(probs.T, length=length, has_labels=False)
             heads[0] = 0
             pred.append(heads)
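The numerical point behind masking with -1 can be shown on a toy probability matrix; the values below are contrived and only illustrate the underflow case described in the comment:

import numpy as np

# Row i holds softmax scores over candidate heads for token i; column 0 is ROOT.
probs = np.array([
    [1.0, 0.0, 0.0],   # softmax underflowed: only ROOT kept any mass
    [0.4, 0.6, 0.0],
])

old_mask = probs.copy()
old_mask[:, 0] = 0      # ROOT now ties with the underflowed heads (all 0.0),
                        # so the MST decoder may attach several words to ROOT.

new_mask = probs.copy()
new_mask[:, 0] = -1     # ROOT scores strictly below any softmax output,
                        # so it can never win such a tie.

assert old_mask[0].max() == 0.0    # every candidate head ties at zero
assert new_mask[0].argmax() != 0   # a real word is preferred over ROOT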
@@ -58,6 +58,7 @@ local loss_weights = {
     feats: 0.2,
     deprel: 0.8,
     head: 0.2,
+    semrel: 0.05,
 };
 # Encoder hidden size, int
 local hidden_size = 512;
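For orientation, per-task weights like these are commonly used to scale each task's loss before summing in multi-task training; the sketch below only illustrates that arithmetic and is an assumption about usage, not the repository's training code:

loss_weights = {"feats": 0.2, "deprel": 0.8, "head": 0.2, "semrel": 0.05}
task_losses = {"feats": 1.3, "deprel": 0.9, "head": 1.1, "semrel": 2.0}  # made-up batch losses

# Weighted sum: semrel now contributes, but only with a small weight (0.05).
total_loss = sum(loss_weights[task] * task_losses[task] for task in loss_weights)
print(round(total_loss, 2))  # 0.26 + 0.72 + 0.22 + 0.10 = 1.3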
 # sent_id = test-s1
 # text = Easy sentence.
-1 Verylongwordwhichmustbetruncatedbythesystemto30 easy ADJ adj AdpType=Prep 1 amod _ _
+1 Verylongwordwhichmustbetruncatedbythesystemto30 easy ADJ adj AdpType=Prep|Adp 1 amod _ _
 2 Sentence verylonglemmawhichmustbetruncatedbythesystemto30 NOUN nom Number=Sing 0 root _ _
 3 . . PUNCT . _ 1 punct _ _