From 2e2525dec74ba9679f053240d618b9c0ef46e910 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 20 May 2021 14:02:09 +0200
Subject: [PATCH] Enable weighted average of LM embeddings.

---
 combo/config.graph.template.jsonnet | 3 ++-
 combo/models/embeddings.py          | 2 ++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet
index c72a057..708cfa3 100644
--- a/combo/config.graph.template.jsonnet
+++ b/combo/config.graph.template.jsonnet
@@ -49,7 +49,7 @@ local lemma_char_dim = 64;
 # Character embedding dim, int
 local char_dim = 64;
 # Word embedding projection dim, int
-local projected_embedding_dim = 100;
+local projected_embedding_dim = 768;
 # Loss weights, dict[str, int]
 local loss_weights = {
     xpostag: 0.05,
@@ -202,6 +202,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                 },
                 token: if use_transformer then {
                     type: "transformers_word_embeddings",
+                    last_layer_only: false,
                     model_name: pretrained_transformer_name,
                     projection_dim: projected_embedding_dim,
                     tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert")
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index d8e9d7a..d8c3d71 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -111,10 +111,12 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
                  projection_activation: Optional[allen_nn.Activation] = lambda x: x,
                  projection_dropout_rate: Optional[float] = 0.0,
                  freeze_transformer: bool = True,
+                 last_layer_only: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
         super().__init__(model_name,
                          train_parameters=not freeze_transformer,
+                         last_layer_only=last_layer_only,
                          tokenizer_kwargs=tokenizer_kwargs,
                          transformer_kwargs=transformer_kwargs)
         if projection_dim:
-- 
GitLab