From f3dcbc3931617f829c9682b908b9bbd0006410c9 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 22 Apr 2021 11:55:48 +0200
Subject: [PATCH] Configure validation metric for early stopping.

---
 combo/config.multitask.template.jsonnet | 4 +++-
 combo/config.shared.libsonnet           | 4 ++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/combo/config.multitask.template.jsonnet b/combo/config.multitask.template.jsonnet
index 3553bc1..a977819 100644
--- a/combo/config.multitask.template.jsonnet
+++ b/combo/config.multitask.template.jsonnet
@@ -397,7 +397,9 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             },
         },
     }),
-    trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard),
+    trainer: shared_config.Trainer(cuda_device, num_epochs, learning_rate, use_tensorboard,
+        ["+conllu_EM", "+iob_f1-measure-overall"]
+    ),
     random_seed: 8787,
     pytorch_seed: 8787,
     numpy_seed: 8787,
diff --git a/combo/config.shared.libsonnet b/combo/config.shared.libsonnet
index bb262f6..2804c8c 100644
--- a/combo/config.shared.libsonnet
+++ b/combo/config.shared.libsonnet
@@ -1,5 +1,5 @@
 {
-    local trainer(cuda_device, num_epochs, learning_rate, use_tensorboard) =
+    local trainer(cuda_device, num_epochs, learning_rate, use_tensorboard, validation_metric="+EM") =
         std.prune({
             checkpointer: {
                 type: "finishing_only_checkpointer",
@@ -23,7 +23,7 @@
             learning_rate_scheduler: {
                 type: "combo_scheduler",
             },
-            validation_metric: "+EM",
+            validation_metric: validation_metric,
         }),
 
     local lemma(hidden_size, dropout) = {
-- 
GitLab