Commit 2512fd9f authored by Mateusz Klimaszewski

Fix mismatched transformer tokenizer and embedder.

parent 44772209
 from .samplers import TokenCountBatchSampler
-from .token_indexers import TokenCharactersIndexer
+from .token_indexers import *
 from .api import *

 from .token_characters_indexer import TokenCharactersIndexer
 from .token_features_indexer import TokenFeatsIndexer
+from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer
from typing import List

from allennlp import data
from allennlp.data import token_indexers
from overrides import overrides


@token_indexers.TokenIndexer.register("pretrained_transformer_mismatched_tmp_fix")
class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
    """TODO remove when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""

    @overrides
    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)

        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize([t.text for t in tokens])

        # For tokens that don't correspond to any word pieces, we put (-1, -1) into the offsets.
        # That results in the embedding for the token to be all zeros.
        offsets = [x if x is not None else (-1, -1) for x in offsets]

        output: data.IndexedTokenList = {
            "token_ids": [t.text_id for t in wordpieces],
            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
            "type_ids": [t.type_id for t in wordpieces],
            "offsets": offsets,
            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
        }

        return self._matched_indexer._postprocess_output(output)
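The workaround indexer above produces wordpiece-level "token_ids"/"type_ids"/"wordpiece_mask" plus word-level "mask"/"offsets" from a single tokenizer pass. A minimal usage sketch (assumptions: AllenNLP ~1.x is installed, the module above is importable, and a Hugging Face model can be downloaded; "bert-base-cased" and the example tokens are only illustrative):

from allennlp.data import Token, Vocabulary

indexer = PretrainedTransformerMismatchedIndexer(model_name="bert-base-cased")
vocab = Vocabulary()
tokens = [Token("Tokenizers"), Token("split"), Token("subwords")]

output = indexer.tokens_to_indices(tokens, vocab)

# "mask" and "offsets" align with the original word-level tokens,
# while "token_ids", "type_ids" and "wordpiece_mask" are wordpiece-level.
assert len(output["mask"]) == len(tokens)
assert len(output["offsets"]) == len(tokens)
assert len(output["token_ids"]) == len(output["wordpiece_mask"])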
@@ -127,6 +127,7 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
         self.projection_layer = None
         self.output_dim = super().get_output_dim()

+    @overrides
     def forward(
         self,
         token_ids: torch.LongTensor,
@@ -135,12 +136,32 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEmbedder):
         wordpiece_mask: torch.BoolTensor,
         type_ids: Optional[torch.LongTensor] = None,
         segment_concat_mask: Optional[torch.BoolTensor] = None,
-    ) -> torch.Tensor:
-        x = super().forward(token_ids=token_ids, mask=mask, offsets=offsets, wordpiece_mask=wordpiece_mask,
-                            type_ids=type_ids, segment_concat_mask=segment_concat_mask)
+    ) -> torch.Tensor:  # type: ignore
+        """TODO remove (and call super) when fixed in AllenNLP (fc47bf6ae5c0df6d473103d459b75fa7edbdd979)"""
+        # Shape: [batch_size, num_wordpieces, embedding_size].
+        embeddings = self._matched_embedder(
+            token_ids, wordpiece_mask, type_ids=type_ids, segment_concat_mask=segment_concat_mask
+        )
+
+        # span_embeddings: (batch_size, num_orig_tokens, max_span_length, embedding_size)
+        # span_mask: (batch_size, num_orig_tokens, max_span_length)
+        span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
+        span_mask = span_mask.unsqueeze(-1)
+        span_embeddings *= span_mask  # zero out paddings
+
+        span_embeddings_sum = span_embeddings.sum(2)
+        span_embeddings_len = span_mask.sum(2)
+        # Shape: (batch_size, num_orig_tokens, embedding_size)
+        orig_embeddings = span_embeddings_sum / span_embeddings_len
+
+        # All the places where the span length is zero, write in zeros.
+        orig_embeddings[(span_embeddings_len == 0).expand(orig_embeddings.shape)] = 0
+        # TODO end remove
         if self.projection_layer:
-            x = self.projection_layer(x)
-        return x
+            orig_embeddings = self.projection_layer(orig_embeddings)
+        return orig_embeddings

     @overrides
     def get_output_dim(self):
...
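The rewritten forward pass mean-pools wordpiece embeddings back to word level using the inclusive (start, end) offsets produced by the indexer. A toy sketch of that pooling step (assumptions: AllenNLP ~1.x provides util.batched_span_select; all shapes and offset values below are made up for illustration):

import torch
from allennlp.nn import util

batch_size, num_wordpieces, dim = 1, 5, 4
embeddings = torch.randn(batch_size, num_wordpieces, dim)

# Two original words: the first spans wordpieces 1..2, the second is wordpiece 3
# (inclusive indices, as the mismatched indexer produces them).
offsets = torch.tensor([[[1, 2], [3, 3]]])

span_embeddings, span_mask = util.batched_span_select(embeddings.contiguous(), offsets)
span_mask = span_mask.unsqueeze(-1)
span_embeddings *= span_mask                                 # zero out padded span slots
word_embeddings = span_embeddings.sum(2) / span_mask.sum(2)  # average wordpieces per word

assert word_embeddings.shape == (batch_size, 2, dim)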
@@ -67,6 +67,6 @@ def _requests_retry_session(
         status_forcelist=status_forcelist,
     )
     adapter = adapters.HTTPAdapter(max_retries=retry)
-    session.mount('http://', adapter)
-    session.mount('https://', adapter)
+    session.mount("http://", adapter)
+    session.mount("https://", adapter)
     return session
@@ -111,7 +111,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
     use_sem: if in_features("semrel") then true else false,
     token_indexers: {
       token: if use_transformer then {
-        type: "pretrained_transformer_mismatched",
+        type: "pretrained_transformer_mismatched_tmp_fix",
         model_name: pretrained_transformer_name,
       } else {
         # SingleIdTokenIndexer, token as single int
...
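The config change switches the token indexer to the temporarily registered name, which AllenNLP resolves through its registry when the experiment is built. A minimal sketch of that lookup (assumptions: AllenNLP ~1.x, the module defining the workaround indexer has already been imported so its @register decorator has run, and the model name is only illustrative):

from allennlp.data import TokenIndexer

indexer_cls = TokenIndexer.by_name("pretrained_transformer_mismatched_tmp_fix")
indexer = indexer_cls(model_name="bert-base-cased")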
@@ -14,6 +14,7 @@ REQUIREMENTS = [
     'torchvision==0.6.0',
     'tqdm==4.43.0',
     'transformers==2.9.1',
+    'urllib3==1.24.2',
 ]

 setup(
...