diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index f5973ab11e1a74eeca4f5239ba537538b1e200f1..91426e3950386ecf787e18321e8c7765bc33a187 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -1,3 +1,3 @@
 from .samplers import TokenCountBatchSampler
-from .token_indexers import TokenCharactersIndexer
+from .token_indexers import *
 from .api import *
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 1b918b3ad66692a761b564c08d0c270745a263cb..14f4f1ccd3b3e7efba45d75b7cba5a522f3582e6 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1,2 +1,3 @@
 from .token_characters_indexer import TokenCharactersIndexer
 from .token_features_indexer import TokenFeatsIndexer
+from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a662623d98bc488b4d4868839b08b1b758e6f7
--- /dev/null
+++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py
@@ -0,0 +1,30 @@
+from typing import List
+
+from allennlp import data
+from allennlp.data import token_indexers
+from overrides import overrides
+
+
+@token_indexers.TokenIndexer.register("pretrained_transformer_mismatched_tmp_fix")
+class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
+    """TODO remove when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
+
+    @overrides
+    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
+        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)
+
+        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize([t.text for t in tokens])
+
+        # For tokens that don't correspond to any word pieces, we put (-1, -1) into the offsets.
+        # That results in the embedding for the token to be all zeros.
+        offsets = [x if x is not None else (-1, -1) for x in offsets]
+
+        output: data.IndexedTokenList = {
+            "token_ids": [t.text_id for t in wordpieces],
+            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
+            "type_ids": [t.type_id for t in wordpieces],
+            "offsets": offsets,
+            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
+        }
+
+        return self._matched_indexer._postprocess_output(output)
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 38cbf8fea3d591e04c794a676744b409a63c645e..009ca72f6004fc6d2bc9bf4eeae0d73975d6eb3a 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -127,6 +127,7 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
         self.projection_layer = None
         self.output_dim = super().get_output_dim()
 
+    @overrides
     def forward(
         self,
         token_ids: torch.LongTensor,
@@ -135,12 +136,32 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
         wordpiece_mask: torch.BoolTensor,
         type_ids: Optional[torch.LongTensor] = None,
         segment_concat_mask: Optional[torch.BoolTensor] = None,
-    ) -> torch.Tensor:
-        x = super().forward(token_ids=token_ids, mask=mask, offsets=offsets, wordpiece_mask=wordpiece_mask,
-                            type_ids=type_ids, segment_concat_mask=segment_concat_mask)
+    ) -> torch.Tensor:  # type: ignore
+        """TODO remove (and call super) when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
+        # Shape: [batch_size, num_wordpieces, embedding_size].
+        embeddings = self._matched_embedder(
+            token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask
+        )
+
+        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
+        # span_mask: (batch_size, num_orig_tokens, max_span_length)
+        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
+        span_mask = span_mask.unsqueeze(-1)
+        span_embeddings *= span_mask  # zero out paddings
+
+        span_embeddings_sum = span_embeddings.sum(2)
+        span_embeddings_len = span_mask.sum(2)
+        # Shape: (batch_size, num_orig_tokens, embedding_size)
+        orig_embeddings = span_embeddings_sum / span_embeddings_len
+
+        # All the places where the span length is zero, write in zeros.
+        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0
+
+        # TODO end remove
+
         if self.projection_layer:
-            x = self.projection_layer(x)
-        return x
+            orig_embeddings = self.projection_layer(orig_embeddings)
+        return orig_embeddings
 
     @overrides
     def get_output_dim(self):
diff --git a/combo/utils/download.py b/combo/utils/download.py
index 59abf09306c4ac77346df11b732422d6eab1e8a6..b2c6b2e32579dc91e2adb50298e2b71bf3936536 100644
--- a/combo/utils/download.py
+++ b/combo/utils/download.py
@@ -67,6 +67,6 @@ def _requests_retry_session(
         status_forcelist=status_forcelist,
     )
     adapter = adapters.HTTPAdapter(max_retries=retry)
-    session.mount('http://', adapter)
-    session.mount('https://', adapter)
+    session.mount("http://", adapter)
+    session.mount("https://", adapter)
     return session
diff --git a/config.template.jsonnet b/config.template.jsonnet
index f12c3bd1ba29b690e0f4328137d6da796706591c..154c20b63dc84127a0de99d18929c3f132e9e318 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -111,7 +111,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     use_sem: if in_features("semrel") then true else false,
     token_indexers: {
         token: if use_transformer then {
-            type: "pretrained_transformer_mismatched",
+            type: "pretrained_transformer_mismatched_tmp_fix",
             model_name: pretrained_transformer_name,
         } else {
             # SingleIdTokenIndexer, token as single int
diff --git a/setup.py b/setup.py
index 57df6075816a2071fe1e85e618a001b6527e9b0d..f74cd325827c723f2de3be2c352725804428d052 100644
--- a/setup.py
+++ b/setup.py
@@ -14,6 +14,7 @@ REQUIREMENTS = [
    'torchvision==0.6.0',
    'tqdm==4.43.0',
    'transformers==2.9.1',
+    'urllib3==1.24.2',
]

setup(
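
The patch above is the complete change set. Purely as an illustration (not part of the patch), the minimal sketch below shows how the temporarily registered "pretrained_transformer_mismatched_tmp_fix" indexer selected in config.template.jsonnet could be exercised directly; the model name "bert-base-cased" and the sample tokens are placeholders, not values taken from the repository, and an AllenNLP 1.x-style API is assumed.

# Illustrative sketch only; assumes the AllenNLP version pinned by the project.
from allennlp.data import Token, Vocabulary

from combo.data.token_indexers import PretrainedTransformerMismatchedIndexer

# "bert-base-cased" is a placeholder; any model usable with the pinned
# transformers version should behave the same way.
indexer = PretrainedTransformerMismatchedIndexer(model_name="bert-base-cased")
vocab = Vocabulary()

tokens = [Token("COMBO"), Token("parses"), Token("sentences"), Token(".")]
indexed = indexer.tokens_to_indices(tokens, vocab)

# "mask" is word-level (one entry per original token), while "token_ids",
# "type_ids" and "wordpiece_mask" are wordpiece-level; "offsets" maps each
# original token to its span of wordpieces, with (-1, -1) for tokens that
# produce no wordpieces, so their embeddings come out as zeros downstream.
print(len(indexed["mask"]), len(indexed["token_ids"]), indexed["offsets"])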