From fd0f2ea1f11da7e58a83889114a7fa2c9476f845 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Thu, 18 Jun 2020 12:01:16 +0200
Subject: [PATCH] Update allennlp to 1.0.0.

---
 combo/data/dataset.py                         | 12 ++++++--
 combo/data/token_indexers/__init__.py         |  1 -
 ...etrained_transformer_mismatched_indexer.py | 30 -------------------
 combo/models/embeddings.py                    | 27 ++---------------
 combo/models/parser.py                        | 16 ++++++----
 combo/predict.py                              |  5 ++--
 combo/training/trainer.py                     |  2 +-
 config.template.jsonnet                       |  2 +-
 setup.py                                      |  2 +-
 9 files changed, 29 insertions(+), 68 deletions(-)
 delete mode 100644 combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py

diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index d5bfc9c..b5f5c30 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -230,6 +230,14 @@ def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
         return vocab.slices
 
 
-@dataclass
+@dataclass(init=False, repr=False)
 class _Token(allen_data.Token):
-    feats_: Optional[str] = None
+    __slots__ = allen_data.Token.__slots__ + ['feats_']
+
+    feats_: Optional[str]
+
+    def __init__(self, text: str = None, idx: int = None, idx_end: int = None, lemma_: str = None, pos_: str = None,
+                 tag_: str = None, dep_: str = None, ent_type_: str = None, text_id: int = None, type_id: int = None,
+                 feats_: str = None) -> None:
+        super().__init__(text, idx, idx_end, lemma_, pos_, tag_, dep_, ent_type_, text_id, type_id)
+        self.feats_ = feats_
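
The `_Token` rewrite above is forced by a Python dataclass limitation: `allennlp.data.Token` declares `__slots__` in 1.0.0, and a slotted dataclass field cannot carry a class-level default (the default would shadow the slot descriptor), so the subclass extends `__slots__` and writes `__init__` by hand. A minimal sketch of the same pattern, independent of allennlp (class and field names here are illustrative):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(init=False, repr=False)
class Base:
    __slots__ = ("text",)
    text: Optional[str]

    def __init__(self, text: str = None) -> None:
        self.text = text


@dataclass(init=False, repr=False)
class Child(Base):
    # Extend the parent's slots; assigning a dataclass default instead
    # would create a class attribute that conflicts with the slot.
    __slots__ = Base.__slots__ + ("extra",)
    extra: Optional[str]

    def __init__(self, text: str = None, extra: str = None) -> None:
        super().__init__(text)
        self.extra = extra


assert Child(text="dom", extra="Case=Nom").extra == "Case=Nom"
```
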
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 14f4f1c..1b918b3 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1,3 +1,2 @@
 from .token_characters_indexer import TokenCharactersIndexer
 from .token_features_indexer import TokenFeatsIndexer
-from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
deleted file mode 100644
index e7a6626..0000000
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from typing import List
-
-from allennlp import data
-from allennlp.data import token_indexers
-from overrides import overrides
-
-
-@token_indexers.TokenIndexer.register("pretrained_transformer_mismatched_tmp_fix")
-class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
-    """TODO remove when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
-
-    @overrides
-    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
-        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
-
-        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize([t.text for t in tokens])
-
-        # For tokens that don't correspond to any word pieces, we put (-1, -1) into the offsets.
-        # That results in the embedding for the token to be all zeros.
-        offsets = [x if x is not None else (-1, -1) for x in offsets]
-
-        output: data.IndexedTokenList = {
-            "token_ids": [t.text_id for t in wordpieces],
-            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-            "type_ids": [t.type_id for t in wordpieces],
-            "offsets": offsets,
-            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-        }
-
-        return self._matched_indexer._postprocess_output(output)
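
The module deleted above duplicated `tokens_to_indices` to work around an indexing bug that AllenNLP fixed upstream in commit fc47bf6ae5c0df6d473103d459b75fa7edbdd979; the fix ships in the 1.0.0 release, so the stock indexer can be used directly again. A usage sketch (the model name is only an example):

```python
from allennlp.data.token_indexers import PretrainedTransformerMismatchedIndexer

# With allennlp 1.0.0 the built-in indexer handles word-level tokens
# correctly, so no local "tmp_fix" subclass is needed.
indexer = PretrainedTransformerMismatchedIndexer(model_name="bert-base-cased")
```
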
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 009ca72..edb37a1 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -137,31 +137,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
             type_ids: Optional[torch.LongTensor] = None,
             segment_concat_mask: Optional[torch.BoolTensor] = None,
     ) -> torch.Tensor:  # type: ignore
-        """TODO remove (and call super) when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
-        # Shape: [batch_size, num_wordpieces, embedding_size].
-        embeddings = self._matched_embedder(
-            token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask
-        )
-
-        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
-        # span_mask: (batch_size, num_orig_tokens, max_span_length)
-        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
-        span_mask = span_mask.unsqueeze(-1)
-        span_embeddings *= span_mask  # zero out paddings
-
-        span_embeddings_sum = span_embeddings.sum(2)
-        span_embeddings_len = span_mask.sum(2)
-        # Shape: (batch_size, num_orig_tokens, embedding_size)
-        orig_embeddings = span_embeddings_sum / span_embeddings_len
-
-        # All the places where the span length is zero, write in zeros.
-        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0
-
-        # TODO end remove
-
+        x = super().forward(token_ids, mask, offsets, wordpiece_mask, type_ids, segment_concat_mask)
         if self.projection_layer:
-            orig_embeddings = self.projection_layer(orig_embeddings)
-        return orig_embeddings
+            x = self.projection_layer(x)
+        return x
 
     @overrides
     def get_output_dim(self):
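
With the pooling bug fixed upstream, `forward` can delegate the wordpiece-to-token averaging to `super().forward` and keep only the optional projection. The projection is a linear map over the last (embedding) dimension; a small shape sketch, assuming a `torch.nn.Linear` projection with illustrative sizes:

```python
import torch

projection_layer = torch.nn.Linear(768, 100)  # assumed layer type and sizes

# super().forward(...) returns (batch_size, num_orig_tokens, embedding_size);
# the projection only changes the last dimension.
x = torch.randn(2, 5, 768)
assert projection_layer(x).shape == (2, 5, 100)
```
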
diff --git a/combo/models/parser.py b/combo/models/parser.py
index 7033089..486b248 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -67,11 +67,15 @@ class HeadPredictionModel(base.Predictor):
     def _cycle_loss(self, pred: torch.Tensor):
         BATCH_SIZE, _, _ = pred.size()
         loss = pred.new_zeros(BATCH_SIZE)
-        # 1: as using non __ROOT__ tokens
-        x = pred[:, 1:, 1:]
-        for _ in range(self.cycle_loss_n):
-            loss += self._batch_trace(x) / BATCH_SIZE
-            x = x.bmm(pred[:, 1:, 1:])
+        # Slice from index 1 to drop the __ROOT__ row and column
+        pred = pred.softmax(-1)[:, 1:, 1:]
+        x = pred
+        for i in range(self.cycle_loss_n):
+            loss += self._batch_trace(x)
+
+            # Don't multiply on the last iteration
+            if i < self.cycle_loss_n - 1:
+                x = x.bmm(pred)
 
         return loss
 
@@ -83,7 +87,7 @@ class HeadPredictionModel(base.Predictor):
         identity = x.new_tensor(torch.eye(N))
         identity = identity.reshape((1, N, N))
         batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
-        return (x * batch_identity).sum()
+        return (x * batch_identity).sum((-1, -2))
 
     def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
               sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
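
The rewritten cycle loss relies on a standard fact: for a soft adjacency matrix A, tr(A^k) accumulates the weight of all length-k cycles, so summing the traces of the first n powers penalizes cyclic head assignments. That is also why `_batch_trace` now returns one trace per batch element (summing over the last two dimensions) rather than a single scalar. A self-contained sketch of the same computation (function names here are illustrative):

```python
import torch


def batch_trace(x: torch.Tensor) -> torch.Tensor:
    # Per-matrix trace; equivalent to (x * batch_identity).sum((-1, -2))
    # in the patch, without materializing an identity matrix.
    return torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)


def cycle_loss(pred: torch.Tensor, n: int) -> torch.Tensor:
    # Row/column 0 belong to __ROOT__ and are dropped before the softmax
    # scores are treated as a soft adjacency matrix.
    adj = pred.softmax(-1)[:, 1:, 1:]
    x = adj
    loss = pred.new_zeros(pred.size(0))
    for i in range(n):
        loss = loss + batch_trace(x)  # adds the weight of length-(i+1) cycles
        if i < n - 1:  # no multiply needed after the final trace
            x = x.bmm(adj)
    return loss


scores = torch.randn(2, 6, 6)  # (batch, tokens incl. __ROOT__, tokens)
assert cycle_loss(scores, n=3).shape == (2,)
```
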
diff --git a/combo/predict.py b/combo/predict.py
index 1f48e22..35dbbee 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -29,6 +29,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         self.batch_size = batch_size
         self.vocab = model.vocab
         self._dataset_reader.generate_labels = False
+        self._dataset_reader.lazy = True
         self._tokenizer = tokenizer
 
     def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
@@ -71,8 +72,8 @@ class SemanticMultitaskPredictor(predictor.Predictor):
             raise ValueError("Input must be either string or list of strings.")
 
     @overrides
-    def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True) -> List[
-        common.JsonDict]:
+    def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True
+                               ) -> List[common.JsonDict]:
         trees = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
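
Setting `lazy = True` on the dataset reader makes prediction stream instances one at a time instead of materializing every instance up front, which keeps memory bounded for large inputs. The reader then behaves like a generator rather than a list; a generic sketch of the difference (the file handling here is illustrative, not COMBO's reader):

```python
from typing import Iterable, List


def read_eager(path: str) -> List[str]:
    # Eager: all instances live in memory at once.
    with open(path) as f:
        return [line.strip() for line in f]


def read_lazy(path: str) -> Iterable[str]:
    # Lazy: each instance is produced only when the consumer asks for it.
    with open(path) as f:
        for line in f:
            yield line.strip()
```
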
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index f7fedad..2096d48 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -128,7 +128,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                             num_batches,
                             reset=True,
                             world_size=self._world_size,
-                            cuda_device=[self.cuda_device],
+                            cuda_device=self.cuda_device,
                         )
 
                         # Check validation metric for early stopping
diff --git a/config.template.jsonnet b/config.template.jsonnet
index 1f47b38..57f02ae 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -112,7 +112,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         use_sem: if in_targets("semrel") then true else false,
         token_indexers: {
             token: if use_transformer then {
-                type: "pretrained_transformer_mismatched_tmp_fix",
+                type: "pretrained_transformer_mismatched",
                 model_name: pretrained_transformer_name,
             } else {
                 # SingleIdTokenIndexer, token as single int
diff --git a/setup.py b/setup.py
index f74cd32..ee68a63 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import find_packages, setup
 
 REQUIREMENTS = [
     'absl-py==0.9.0',
-    'allennlp==1.0.0rc5',
+    'allennlp==1.0.0',
     'conllu==2.3.2',
     'joblib==0.14.1',
     'jsonnet==0.15.0',
-- 
GitLab