Skip to content
Snippets Groups Projects
Commit 4c556fd9 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski Committed by Mateusz Klimaszewski
Browse files

Refactor tensor indexing and prediction mapping.

parent bdd9d5bb
Branches
Tags
2 merge requests!9Enhanced dependency parsing develop to master,!8Enhanced dependency parsing
import collections
import dataclasses
import json
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Union, Tuple
import conllu
from dataclasses_json import dataclass_json
from overrides import overrides
@dataclass_json
@dataclass
class Token:
id: Optional[Union[int, Tuple]] = None
......@@ -23,13 +23,19 @@ class Token:
semrel: Optional[str] = None
@dataclass_json
@dataclass
class Sentence:
    """A parsed sentence: its tokens, an optional embedding vector and metadata."""

    tokens: List[Token] = field(default_factory=list)
    sentence_embedding: List[float] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)

    def to_json(self):
        """Serialise the sentence to a JSON string (tokens as plain dicts)."""
        payload = {
            "tokens": [dataclasses.asdict(token) for token in self.tokens],
            "sentence_embedding": self.sentence_embedding,
            "metadata": self.metadata,
        }
        return json.dumps(payload)
class _TokenList(conllu.TokenList):
......@@ -41,7 +47,7 @@ class _TokenList(conllu.TokenList):
def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
    # Converts a Sentence dataclass into a conllu.TokenList for serialisation.
    # NOTE(review): this span is a scraped diff fragment — both the pre- and
    # post-commit versions of some statements appear below, and the "@@" hunk
    # marker hides part of the original function body; it is not runnable as-is.
    tokens = []
    for token in sentence.tokens:
        # Pre-commit line (relied on dataclasses_json's Token.to_dict):
        token_dict = collections.OrderedDict(token.to_dict())
        # Post-commit replacement (stdlib dataclasses.asdict, dropping the
        # dataclasses-json dependency — see the removed requirement below):
        token_dict = collections.OrderedDict(dataclasses.asdict(token))
        # Remove semrel to have default conllu format.
        if not keep_semrel:
            del token_dict["semrel"]
......@@ -52,7 +58,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke
        # conllu represents multiword-token ids as tuples, so lists coming
        # from JSON round-trips are converted back.
        t["id"] = tuple(t["id"])
        if t["deps"]:
            for dep in t["deps"]:
                # Pre-commit check (assumed every dep entry had a second element):
                if type(dep[1]) == list:
                # Post-commit check guards against short dep entries:
                if len(dep) > 1 and type(dep[1]) == list:
                    dep[1] = tuple(dep[1])
    return _TokenList(tokens=tokens,
                      metadata=sentence.metadata)
......@@ -68,9 +74,18 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
def conllu2sentence(conllu_sentence: conllu.TokenList,
                    sentence_embedding: Optional[List[float]] = None) -> Sentence:
    """Convert a ``conllu.TokenList`` into a :class:`Sentence`.

    Args:
        conllu_sentence: parsed CoNLL-U sentence; each token dict's keys are
            forwarded to ``Token(**token)``.
        sentence_embedding: optional embedding vector. ``None`` (the default)
            becomes an empty list — the ``None`` sentinel avoids a shared
            mutable default argument.

    Returns:
        A ``Sentence`` with the converted tokens, the embedding and the
        original sentence metadata.
    """
    if sentence_embedding is None:
        sentence_embedding = []
    return Sentence(
        tokens=[Token(**token) for token in conllu_sentence.tokens],
        sentence_embedding=sentence_embedding,
        metadata=conllu_sentence.metadata,
    )
......@@ -27,11 +27,11 @@ class Linear(nn.Linear, common.FromParams):
def __init__(self,
             in_features: int,
             out_features: int,
             activation: Optional[allen_nn.Activation] = None,
             dropout_rate: Optional[float] = 0.0):
    """Linear layer with optional activation and dropout.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        activation: callable applied after the affine map; ``None`` (the
            default) means identity. A ``None`` default replaces the earlier
            ``lambda x: x`` default — named functions stay picklable.
        dropout_rate: dropout probability; ``0.0`` (or falsy) disables dropout.
    """
    super().__init__(in_features, out_features)
    # Fall back to the named static `identity` instead of a lambda so the
    # module can be serialised.
    self.activation = activation if activation else self.identity
    self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else self.identity
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
x = super().forward(x)
......@@ -41,6 +41,10 @@ class Linear(nn.Linear, common.FromParams):
def get_output_dim(self) -> int:
    """Return ``out_features``, the size of this layer's output vectors."""
    return self.out_features
@staticmethod
def identity(x):
    # No-op pass-through used as the default for `activation` and `dropout`;
    # presumably a named function rather than a lambda so the module remains
    # picklable — the lambdas it replaces were removed in this commit.
    return x
@Predictor.register("feedforward_predictor")
@Predictor.register("feedforward_predictor_from_vocab", constructor="from_vocab")
......
......@@ -196,10 +196,10 @@ class FeatsTokenEmbedder(token_embedders.Embedding):
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    """Embed morphological-feature ids and mean-pool over active features.

    Args:
        tokens: integer feature-id tensor; id 0 is treated as padding.
            Assumed shape (batch_size, sentence_length, features_vocab_length)
            per the original shape comments — TODO confirm against caller.

    Returns:
        Tensor of shape (batch_size, sentence_length, embedding_dim): the sum
        of the active feature embeddings divided by the active-feature count.
    """
    # (batch_size, sentence_length, features_vocab_length)
    # Boolean mask of non-padding feature ids; kept boolean (not .float())
    # and summed directly — the epsilon below is added as a float instead.
    mask = tokens.gt(0)
    # (batch_size, sentence_length, features_vocab_length, embedding_dim)
    x = super().forward(tokens)
    # (batch_size, sentence_length, embedding_dim)
    # tiny_value_of_dtype keeps the division safe when a token has no
    # active features (mask sum of 0).
    return x.sum(dim=-2) / (
        (mask.sum(dim=-1) + util.tiny_value_of_dtype(torch.float)).unsqueeze(dim=-1)
    )
......@@ -60,9 +60,11 @@ class ComboModel(allen_models.Model):
enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
# Prepare masks
char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0
char_mask = sentence["char"]["token_characters"].gt(0)
word_mask = util.get_text_field_mask(sentence)
device = word_mask.device
# If enabled weight samples loss by log(sentence_length)
sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None
......@@ -73,45 +75,45 @@ class ComboModel(allen_models.Model):
# Concatenate the head sentinel (ROOT) onto the sentence representation.
head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
encoder_emb = torch.cat([head_sentinel, encoder_emb], 1)
word_mask = torch.cat([word_mask.new_ones((batch_size, 1)), word_mask], 1)
encoder_emb_with_root = torch.cat([head_sentinel, encoder_emb], 1)
word_mask_with_root = torch.cat([torch.ones((batch_size, 1), device=device), word_mask], 1)
upos_output = self._optional(self.upos_tagger,
encoder_emb[:, 1:],
mask=word_mask[:, 1:],
encoder_emb,
mask=word_mask,
labels=upostag,
sample_weights=sample_weights)
xpos_output = self._optional(self.xpos_tagger,
encoder_emb[:, 1:],
mask=word_mask[:, 1:],
encoder_emb,
mask=word_mask,
labels=xpostag,
sample_weights=sample_weights)
semrel_output = self._optional(self.semantic_relation,
encoder_emb[:, 1:],
mask=word_mask[:, 1:],
encoder_emb,
mask=word_mask,
labels=semrel,
sample_weights=sample_weights)
morpho_output = self._optional(self.morphological_feat,
encoder_emb[:, 1:],
mask=word_mask[:, 1:],
encoder_emb,
mask=word_mask,
labels=feats,
sample_weights=sample_weights)
lemma_output = self._optional(self.lemmatizer,
(encoder_emb[:, 1:], sentence.get("char").get("token_characters")
(encoder_emb, sentence.get("char").get("token_characters")
if sentence.get("char") else None),
mask=word_mask[:, 1:],
mask=word_mask,
labels=lemma.get("char").get("token_characters") if lemma else None,
sample_weights=sample_weights)
parser_output = self._optional(self.dependency_relation,
encoder_emb,
encoder_emb_with_root,
returns_tuple=True,
mask=word_mask,
mask=word_mask_with_root,
labels=(deprel, head),
sample_weights=sample_weights)
enhanced_parser_output = self._optional(self.enhanced_dependency_relation,
encoder_emb,
encoder_emb_with_root,
returns_tuple=True,
mask=word_mask,
mask=word_mask_with_root,
labels=(enhanced_deprels, head, enhanced_heads),
sample_weights=sample_weights)
relations_pred, head_pred = parser_output["prediction"]
......@@ -126,7 +128,7 @@ class ComboModel(allen_models.Model):
"deprel": relations_pred,
"enhanced_head": enhanced_head_pred,
"enhanced_deprel": enhanced_relations_pred,
"sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
"sentence_embedding": torch.max(encoder_emb, dim=1)[0],
}
if "rel_probability" in enhanced_parser_output:
......@@ -153,7 +155,7 @@ class ComboModel(allen_models.Model):
"enhanced_head": enhanced_heads,
"enhanced_deprel": enhanced_deprels,
}
self.scores(output, labels, word_mask[:, 1:])
self.scores(output, labels, word_mask)
relations_loss, head_loss = parser_output["loss"]
enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"]
losses = {
......
"""Dependency parsing models."""
import math
from typing import Tuple, Dict, Optional, Union, List
import numpy as np
......
......@@ -25,7 +25,7 @@ class COMBO(predictor.Predictor):
model: models.Model,
dataset_reader: allen_data.DatasetReader,
tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
batch_size: int = 500,
batch_size: int = 32,
line_to_conllu: bool = False) -> None:
super().__init__(model, dataset_reader)
self.batch_size = batch_size
......@@ -52,7 +52,7 @@ class COMBO(predictor.Predictor):
def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
if isinstance(sentence, str):
return data.Sentence.from_dict(self.predict_json({"sentence": sentence}))
return self.predict_json({"sentence": sentence})
elif isinstance(sentence, list):
if len(sentence) == 0:
return []
......@@ -219,7 +219,7 @@ class COMBO(predictor.Predictor):
@classmethod
def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
batch_size: int = 500,
batch_size: int = 32,
cuda_device: int = -1):
util.import_module_and_submodules("combo.commands")
util.import_module_and_submodules("combo.models")
......
......@@ -6,7 +6,6 @@ REQUIREMENTS = [
'allennlp==1.2.1',
'conllu==2.3.2',
'dataclasses;python_version<"3.7"',
'dataclasses-json==0.5.2',
'joblib==0.14.1',
'jsonnet==0.15.0',
'requests==2.23.0',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment