diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index d5bfc9c609b427a17c022c961f824c13f37deda7..b5f5c30d550c533fa8426ed9f7c298550ae0a3e5 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -230,6 +230,14 @@ def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
     return vocab.slices
 
 
-@dataclass
+@dataclass(init=False, repr=False)
 class _Token(allen_data.Token):
-    feats_: Optional[str] = None
+    __slots__ = allen_data.Token.__slots__ + ['feats_']
+
+    feats_: Optional[str]
+
+    def __init__(self, text: str = None, idx: int = None, idx_end: int = None, lemma_: str = None, pos_: str = None,
+                 tag_: str = None, dep_: str = None, ent_type_: str = None, text_id: int = None, type_id: int = None,
+                 feats_: str = None) -> None:
+        super().__init__(text, idx, idx_end, lemma_, pos_, tag_, dep_, ent_type_, text_id, type_id)
+        self.feats_ = feats_
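Illustrative sketch, not part of the patch: with the dataclass-generated __init__ replaced by an explicit one, the extra feats_ slot is filled through the constructor. A minimal usage example, assuming allennlp 1.0's slot-based Token and a hypothetical UD feats string:

    # Hypothetical usage of the _Token defined above (not part of the patch).
    from combo.data.dataset import _Token

    token = _Token(text="kotem", lemma_="kot", pos_="NOUN",
                   feats_="Case=Ins|Gender=Masc|Number=Sing")
    assert token.feats_ == "Case=Ins|Gender=Masc|Number=Sing"
    assert token.text == "kotem"  # base Token slots still behave as before
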
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 14f4f1ccd3b3e7efba45d75b7cba5a522f3582e6..1b918b3ad66692a761b564c08d0c270745a263cb 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1,3 +1,2 @@
 from .token_characters_indexer import TokenCharactersIndexer
 from .token_features_indexer import TokenFeatsIndexer
-from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
deleted file mode 100644
index e7a662623d98bc488b4d4868839b08b1b758e6f7..0000000000000000000000000000000000000000
--- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from typing import List
-
-from allennlp import data
-from allennlp.data import token_indexers
-from overrides import overrides
-
-
-@token_indexers.TokenIndexer.register("pretrained_transformer_mismatched_tmp_fix")
-class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
-    """TODO remove when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
-
-    @overrides
-    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
-        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
-
-        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize([t.text for t in tokens])
-
-        # For tokens that don't correspond to any word pieces, we put (-1, -1) into the offsets.
-        # That results in the embedding for the token to be all zeros.
-        offsets = [x if x is not None else (-1, -1) for x in offsets]
-
-        output: data.IndexedTokenList = {
-            "token_ids": [t.text_id for t in wordpieces],
-            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
-            "type_ids": [t.type_id for t in wordpieces],
-            "offsets": offsets,
-            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
-        }
-
-        return self._matched_indexer._postprocess_output(output)
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 009ca72f6004fc6d2bc9bf4eeae0d73975d6eb3a..edb37a174adb2a2f5fd8cd3edcc4c21c6bc4fa75 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -137,31 +137,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
         type_ids: Optional[torch.LongTensor] = None,
         segment_concat_mask: Optional[torch.BoolTensor] = None,
     ) -> torch.Tensor:  # type: ignore
-        """TODO remove (and call super) when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
-        # Shape: [batch_size, num_wordpieces, embedding_size].
-        embeddings = self._matched_embedder(
-            token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask
-        )
-
-        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
-        # span_mask: (batch_size, num_orig_tokens, max_span_length)
-        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
-        span_mask = span_mask.unsqueeze(-1)
-        span_embeddings *= span_mask  # zero out paddings
-
-        span_embeddings_sum = span_embeddings.sum(2)
-        span_embeddings_len = span_mask.sum(2)
-        # Shape: (batch_size, num_orig_tokens, embedding_size)
-        orig_embeddings = span_embeddings_sum / span_embeddings_len
-
-        # All the places where the span length is zero, write in zeros.
-        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0
-
-        # TODO end remove
-
+        x = super().forward(token_ids, mask, offsets, wordpiece_mask, type_ids, segment_concat_mask)
         if self.projection_layer:
-            orig_embeddings = self.projection_layer(orig_embeddings)
-        return orig_embeddings
+            x = self.projection_layer(x)
+        return x
 
     @overrides
     def get_output_dim(self):
diff --git a/combo/models/parser.py b/combo/models/parser.py
index 70330890f9050b04197343e6eef2cce9da5282f2..486b2481b96bf17bb19fd8557916f21dcb6c4584 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -67,11 +67,15 @@ class HeadPredictionModel(base.Predictor):
     def _cycle_loss(self, pred: torch.Tensor):
         BATCH_SIZE, _, _ = pred.size()
         loss = pred.new_zeros(BATCH_SIZE)
-        # 1: as using non __ROOT__ tokens
-        x = pred[:, 1:, 1:]
-        for _ in range(self.cycle_loss_n):
-            loss += self._batch_trace(x) / BATCH_SIZE
-            x = x.bmm(pred[:, 1:, 1:])
+        # Index from 1: as using non __ROOT__ tokens
+        pred = pred.softmax(-1)[:, 1:, 1:]
+        x = pred
+        for i in range(self.cycle_loss_n):
+            loss += self._batch_trace(x)
+
+            # Don't multiply on the last iteration
+            if i < self.cycle_loss_n - 1:
+                x = x.bmm(pred)
         return loss
 
 
@@ -83,7 +87,7 @@ class HeadPredictionModel(base.Predictor):
         identity = x.new_tensor(torch.eye(N))
         identity = identity.reshape((1, N, N))
         batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
-        return (x * batch_identity).sum()
+        return (x * batch_identity).sum((-1, -2))
 
     def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
               sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
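Illustrative sketch, not part of the patch: the reworked cycle loss softmaxes the head scores once, accumulates a per-sentence trace for each matrix power, and skips the final unused multiplication, while _batch_trace now returns one trace per batch element instead of a single scalar. A minimal standalone version of the same computation, using torch.diagonal in place of the identity-mask trick:

    import torch

    def cycle_loss(pred: torch.Tensor, n: int) -> torch.Tensor:
        # Sum of trace(A^k) for k = 1..n, per batch element, where A is the
        # softmaxed head-probability matrix with the __ROOT__ row/column dropped.
        adj = pred.softmax(-1)[:, 1:, 1:]
        loss = pred.new_zeros(pred.size(0))
        x = adj
        for i in range(n):
            loss += x.diagonal(dim1=-2, dim2=-1).sum(-1)  # batched trace of A^(i+1)
            if i < n - 1:  # the last power is never used again
                x = x.bmm(adj)
        return loss
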
diff --git a/combo/predict.py b/combo/predict.py
index 1f48e22df3ced5f4cbed70e169ce21fb7814c6e2..35dbbee3af5720b110c37cee4c4e12f72243c8c4 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -29,6 +29,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         self.batch_size = batch_size
         self.vocab = model.vocab
         self._dataset_reader.generate_labels = False
+        self._dataset_reader.lazy = True
         self._tokenizer = tokenizer
 
     def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
@@ -71,8 +72,8 @@ class SemanticMultitaskPredictor(predictor.Predictor):
             raise ValueError("Input must be either string or list of strings.")
 
     @overrides
-    def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True) -> List[
-        common.JsonDict]:
+    def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True
+                               ) -> List[common.JsonDict]:
         trees = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index f7fedad3f8f89f026c9d05f9e2991ba6cb0ed5e0..2096d489a2e4963c1e05bc361a120747760b0143 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -128,7 +128,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                 num_batches,
                 reset=True,
                 world_size=self._world_size,
-                cuda_device=[self.cuda_device],
+                cuda_device=self.cuda_device,
             )
 
             # Check validation metric for early stopping
diff --git a/config.template.jsonnet b/config.template.jsonnet
index 1f47b38fc4fec65c3d4cafcbb83b77496ca99dca..57f02ae3fcaadbad5402f40add0b6a2b5d3a874c 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -112,7 +112,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     use_sem: if in_targets("semrel") then true else false,
     token_indexers: {
         token: if use_transformer then {
-            type: "pretrained_transformer_mismatched_tmp_fix",
+            type: "pretrained_transformer_mismatched",
             model_name: pretrained_transformer_name,
         } else {
             # SingleIdTokenIndexer, token as single int
diff --git a/setup.py b/setup.py
index f74cd325827c723f2de3be2c352725804428d052..ee68a63bd843b576b334b68af5a6b20c4f4b09c1 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ from setuptools import find_packages, setup
 
 REQUIREMENTS = [
     'absl-py==0.9.0',
-    'allennlp==1.0.0rc5',
+    'allennlp==1.0.0',
     'conllu==2.3.2',
     'joblib==0.14.1',
     'jsonnet==0.15.0',