From 4c556fd975e90f8d01798b1ffef022c2f9ed2c4d Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Sun, 3 Jan 2021 11:31:12 +0100 Subject: [PATCH] Refactor tensor indexing and prediction mapping. --- combo/data/api.py | 29 ++++++++++++++++++++------- combo/models/base.py | 10 +++++++--- combo/models/embeddings.py | 4 ++-- combo/models/model.py | 40 ++++++++++++++++++++------------------ combo/models/parser.py | 1 - combo/predict.py | 6 +++--- setup.py | 1 - 7 files changed, 55 insertions(+), 36 deletions(-) diff --git a/combo/data/api.py b/combo/data/api.py index ca8f75a..7d44917 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -1,13 +1,13 @@ import collections +import dataclasses +import json from dataclasses import dataclass, field from typing import Optional, List, Dict, Any, Union, Tuple import conllu -from dataclasses_json import dataclass_json from overrides import overrides -@dataclass_json @dataclass class Token: id: Optional[Union[int, Tuple]] = None @@ -23,13 +23,19 @@ class Token: semrel: Optional[str] = None -@dataclass_json @dataclass class Sentence: tokens: List[Token] = field(default_factory=list) sentence_embedding: List[float] = field(default_factory=list) metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict) + def to_json(self): + return json.dumps({ + "tokens": [dataclasses.asdict(t) for t in self.tokens], + "sentence_embedding": self.sentence_embedding, + "metadata": self.metadata, + }) + class _TokenList(conllu.TokenList): @@ -41,7 +47,7 @@ class _TokenList(conllu.TokenList): def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList: tokens = [] for token in sentence.tokens: - token_dict = collections.OrderedDict(token.to_dict()) + token_dict = collections.OrderedDict(dataclasses.asdict(token)) # Remove semrel to have default conllu format. if not keep_semrel: del token_dict["semrel"] @@ -52,7 +58,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke t["id"] = tuple(t["id"]) if t["deps"]: for dep in t["deps"]: - if type(dep[1]) == list: + if len(dep) > 1 and type(dep[1]) == list: dep[1] = tuple(dep[1]) return _TokenList(tokens=tokens, metadata=sentence.metadata) @@ -68,9 +74,18 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList: def conllu2sentence(conllu_sentence: conllu.TokenList, - sentence_embedding: List[float]) -> Sentence: + sentence_embedding=None) -> Sentence: + if sentence_embedding is None: + sentence_embedding = [] + tokens = [] + for token in conllu_sentence.tokens: + tokens.append( + Token( + **token + ) + ) return Sentence( - tokens=[Token.from_dict(t) for t in conllu_sentence.tokens], + tokens=tokens, sentence_embedding=sentence_embedding, metadata=conllu_sentence.metadata ) diff --git a/combo/models/base.py b/combo/models/base.py index 10e9d37..a5cb5fe 100644 --- a/combo/models/base.py +++ b/combo/models/base.py @@ -27,11 +27,11 @@ class Linear(nn.Linear, common.FromParams): def __init__(self, in_features: int, out_features: int, - activation: Optional[allen_nn.Activation] = lambda x: x, + activation: Optional[allen_nn.Activation] = None, dropout_rate: Optional[float] = 0.0): super().__init__(in_features, out_features) - self.activation = activation - self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else lambda x: x + self.activation = activation if activation else self.identity + self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else self.identity def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: x = super().forward(x) @@ -41,6 +41,10 @@ class Linear(nn.Linear, common.FromParams): def get_output_dim(self) -> int: return self.out_features + @staticmethod + def identity(x): + return x + @Predictor.register("feedforward_predictor") @Predictor.register("feedforward_predictor_from_vocab", constructor="from_vocab") diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py index 5cad959..6ad2559 100644 --- a/combo/models/embeddings.py +++ b/combo/models/embeddings.py @@ -196,10 +196,10 @@ class FeatsTokenEmbedder(token_embedders.Embedding): def forward(self, tokens: torch.Tensor) -> torch.Tensor: # (batch_size, sentence_length, features_vocab_length) - mask = (tokens > 0).float() + mask = tokens.gt(0) # (batch_size, sentence_length, features_vocab_length, embedding_dim) x = super().forward(tokens) # (batch_size, sentence_length, embedding_dim) return x.sum(dim=-2) / ( - (mask.sum(dim=-1) + util.tiny_value_of_dtype(mask.dtype)).unsqueeze(dim=-1) + (mask.sum(dim=-1) + util.tiny_value_of_dtype(torch.float)).unsqueeze(dim=-1) ) diff --git a/combo/models/model.py b/combo/models/model.py index 9d3f817..9866bcb 100644 --- a/combo/models/model.py +++ b/combo/models/model.py @@ -60,9 +60,11 @@ class ComboModel(allen_models.Model): enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]: # Prepare masks - char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0 + char_mask = sentence["char"]["token_characters"].gt(0) word_mask = util.get_text_field_mask(sentence) + device = word_mask.device + # If enabled weight samples loss by log(sentence_length) sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None @@ -73,45 +75,45 @@ class ComboModel(allen_models.Model): # Concatenate the head sentinel (ROOT) onto the sentence representation. head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) - encoder_emb = torch.cat([head_sentinel, encoder_emb], 1) - word_mask = torch.cat([word_mask.new_ones((batch_size, 1)), word_mask], 1) + encoder_emb_with_root = torch.cat([head_sentinel, encoder_emb], 1) + word_mask_with_root = torch.cat([torch.ones((batch_size, 1), device=device), word_mask], 1) upos_output = self._optional(self.upos_tagger, - encoder_emb[:, 1:], - mask=word_mask[:, 1:], + encoder_emb, + mask=word_mask, labels=upostag, sample_weights=sample_weights) xpos_output = self._optional(self.xpos_tagger, - encoder_emb[:, 1:], - mask=word_mask[:, 1:], + encoder_emb, + mask=word_mask, labels=xpostag, sample_weights=sample_weights) semrel_output = self._optional(self.semantic_relation, - encoder_emb[:, 1:], - mask=word_mask[:, 1:], + encoder_emb, + mask=word_mask, labels=semrel, sample_weights=sample_weights) morpho_output = self._optional(self.morphological_feat, - encoder_emb[:, 1:], - mask=word_mask[:, 1:], + encoder_emb, + mask=word_mask, labels=feats, sample_weights=sample_weights) lemma_output = self._optional(self.lemmatizer, - (encoder_emb[:, 1:], sentence.get("char").get("token_characters") + (encoder_emb, sentence.get("char").get("token_characters") if sentence.get("char") else None), - mask=word_mask[:, 1:], + mask=word_mask, labels=lemma.get("char").get("token_characters") if lemma else None, sample_weights=sample_weights) parser_output = self._optional(self.dependency_relation, - encoder_emb, + encoder_emb_with_root, returns_tuple=True, - mask=word_mask, + mask=word_mask_with_root, labels=(deprel, head), sample_weights=sample_weights) enhanced_parser_output = self._optional(self.enhanced_dependency_relation, - encoder_emb, + encoder_emb_with_root, returns_tuple=True, - mask=word_mask, + mask=word_mask_with_root, labels=(enhanced_deprels, head, enhanced_heads), sample_weights=sample_weights) relations_pred, head_pred = parser_output["prediction"] @@ -126,7 +128,7 @@ class ComboModel(allen_models.Model): "deprel": relations_pred, "enhanced_head": enhanced_head_pred, "enhanced_deprel": enhanced_relations_pred, - "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0], + "sentence_embedding": torch.max(encoder_emb, dim=1)[0], } if "rel_probability" in enhanced_parser_output: @@ -153,7 +155,7 @@ class ComboModel(allen_models.Model): "enhanced_head": enhanced_heads, "enhanced_deprel": enhanced_deprels, } - self.scores(output, labels, word_mask[:, 1:]) + self.scores(output, labels, word_mask) relations_loss, head_loss = parser_output["loss"] enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"] losses = { diff --git a/combo/models/parser.py b/combo/models/parser.py index 4b5b126..dfb53ab 100644 --- a/combo/models/parser.py +++ b/combo/models/parser.py @@ -1,5 +1,4 @@ """Dependency parsing models.""" -import math from typing import Tuple, Dict, Optional, Union, List import numpy as np diff --git a/combo/predict.py b/combo/predict.py index 8d3e2f9..21941d9 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -25,7 +25,7 @@ class COMBO(predictor.Predictor): model: models.Model, dataset_reader: allen_data.DatasetReader, tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(), - batch_size: int = 500, + batch_size: int = 32, line_to_conllu: bool = False) -> None: super().__init__(model, dataset_reader) self.batch_size = batch_size @@ -52,7 +52,7 @@ class COMBO(predictor.Predictor): def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]): if isinstance(sentence, str): - return data.Sentence.from_dict(self.predict_json({"sentence": sentence})) + return self.predict_json({"sentence": sentence}) elif isinstance(sentence, list): if len(sentence) == 0: return [] @@ -219,7 +219,7 @@ class COMBO(predictor.Predictor): @classmethod def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), - batch_size: int = 500, + batch_size: int = 32, cuda_device: int = -1): util.import_module_and_submodules("combo.commands") util.import_module_and_submodules("combo.models") diff --git a/setup.py b/setup.py index 5c833d5..fdaa2be 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,6 @@ REQUIREMENTS = [ 'allennlp==1.2.1', 'conllu==2.3.2', 'dataclasses;python_version<"3.7"', - 'dataclasses-json==0.5.2', 'joblib==0.14.1', 'jsonnet==0.15.0', 'requests==2.23.0', -- GitLab