Commit 4c556fd9 authored and committed by Mateusz Klimaszewski

Refactor tensor indexing and prediction mapping.

parent bdd9d5bb
Part of 2 merge requests: !9 Enhanced dependency parsing develop to master, and !8 Enhanced dependency parsing.
```diff
 import collections
+import dataclasses
+import json
 from dataclasses import dataclass, field
 from typing import Optional, List, Dict, Any, Union, Tuple

 import conllu
-from dataclasses_json import dataclass_json
 from overrides import overrides


-@dataclass_json
 @dataclass
 class Token:
     id: Optional[Union[int, Tuple]] = None
@@ -23,13 +23,19 @@ class Token:
     semrel: Optional[str] = None


-@dataclass_json
 @dataclass
 class Sentence:
     tokens: List[Token] = field(default_factory=list)
     sentence_embedding: List[float] = field(default_factory=list)
     metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
+
+    def to_json(self):
+        return json.dumps({
+            "tokens": [dataclasses.asdict(t) for t in self.tokens],
+            "sentence_embedding": self.sentence_embedding,
+            "metadata": self.metadata,
+        })
```
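The `dataclass_json` decorator is gone; serialization is now a plain `json.dumps` over `dataclasses.asdict`. A minimal, self-contained sketch of the new path (the `token` field and values here are illustrative; the real `Token` carries the full set of CoNLL-U columns):

```python
import collections
import dataclasses
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union


@dataclass
class Token:
    # Illustrative subset of the real Token fields.
    id: Optional[Union[int, Tuple]] = None
    token: Optional[str] = None
    semrel: Optional[str] = None


@dataclass
class Sentence:
    tokens: List[Token] = field(default_factory=list)
    sentence_embedding: List[float] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)

    def to_json(self):
        # dataclasses.asdict recurses into nested dataclasses, so no
        # dataclasses-json mixin is needed anymore.
        return json.dumps({
            "tokens": [dataclasses.asdict(t) for t in self.tokens],
            "sentence_embedding": self.sentence_embedding,
            "metadata": self.metadata,
        })


print(Sentence(tokens=[Token(id=1, token="Hello")]).to_json())
# {"tokens": [{"id": 1, "token": "Hello", "semrel": null}], ...}
```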
```diff
 class _TokenList(conllu.TokenList):
@@ -41,7 +47,7 @@ class _TokenList(conllu.TokenList):

 def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
     tokens = []
     for token in sentence.tokens:
-        token_dict = collections.OrderedDict(token.to_dict())
+        token_dict = collections.OrderedDict(dataclasses.asdict(token))
         # Remove semrel to have default conllu format.
         if not keep_semrel:
             del token_dict["semrel"]
@@ -52,7 +58,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.TokenList:
             t["id"] = tuple(t["id"])
         if t["deps"]:
             for dep in t["deps"]:
-                if type(dep[1]) == list:
+                if len(dep) > 1 and type(dep[1]) == list:
                     dep[1] = tuple(dep[1])
     return _TokenList(tokens=tokens,
                       metadata=sentence.metadata)
```
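Two things change here: `token.to_dict()` (a dataclasses-json method) becomes the standard-library `dataclasses.asdict(token)`, and the deps normalization gains a length guard. The guard matters because JSON round trips turn tuples into lists, and an entry without a second element would otherwise raise `IndexError`. A standalone illustration of the guarded loop:

```python
# Sketch of the guarded deps normalization (example data, not COMBO's).
# A range head such as (3, ".", 1) comes back from JSON as a list and is
# restored to a tuple; short or malformed entries are now skipped instead
# of raising IndexError.
deps = [["conj", [3, ".", 1]], ["nsubj", 2], ["orphan"]]
for dep in deps:
    if len(dep) > 1 and type(dep[1]) == list:
        dep[1] = tuple(dep[1])
print(deps)  # [['conj', (3, '.', 1)], ['nsubj', 2], ['orphan']]
```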
```diff
@@ -68,9 +74,18 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList:

 def conllu2sentence(conllu_sentence: conllu.TokenList,
-                    sentence_embedding: List[float]) -> Sentence:
+                    sentence_embedding=None) -> Sentence:
+    if sentence_embedding is None:
+        sentence_embedding = []
+    tokens = []
+    for token in conllu_sentence.tokens:
+        tokens.append(
+            Token(
+                **token
+            )
+        )
     return Sentence(
-        tokens=[Token.from_dict(t) for t in conllu_sentence.tokens],
+        tokens=tokens,
         sentence_embedding=sentence_embedding,
         metadata=conllu_sentence.metadata
     )
```
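`sentence_embedding` becomes optional, defaulting to `None` with the empty list created inside the function body. The `None` indirection is the standard workaround for Python evaluating default arguments only once; a bare `[]` default would be one list shared across every call:

```python
# Why the default is None rather than []: a [] default is created once,
# at function definition time, and mutations leak between calls.
def bad(embedding=[]):
    embedding.append(1.0)
    return embedding

print(bad())  # [1.0]
print(bad())  # [1.0, 1.0] -- state carried over from the previous call


def good(embedding=None):
    embedding = [] if embedding is None else embedding
    embedding.append(1.0)
    return embedding

print(good())  # [1.0]
print(good())  # [1.0]
```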
```diff
@@ -27,11 +27,11 @@ class Linear(nn.Linear, common.FromParams):
     def __init__(self,
                  in_features: int,
                  out_features: int,
-                 activation: Optional[allen_nn.Activation] = lambda x: x,
+                 activation: Optional[allen_nn.Activation] = None,
                  dropout_rate: Optional[float] = 0.0):
         super().__init__(in_features, out_features)
-        self.activation = activation
-        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else lambda x: x
+        self.activation = activation if activation else self.identity
+        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else self.identity

     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
         x = super().forward(x)
@@ -41,6 +41,10 @@ class Linear(nn.Linear, common.FromParams):
     def get_output_dim(self) -> int:
         return self.out_features
+
+    @staticmethod
+    def identity(x):
+        return x


 @Predictor.register("feedforward_predictor")
 @Predictor.register("feedforward_predictor_from_vocab", constructor="from_vocab")
```
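The commit message does not state the motivation, but the likely one is serialization: a lambda stored on a module cannot be pickled, while a named `@staticmethod` is pickled by qualified-name reference. A stripped-down sketch (plain `nn.Linear`, without the AllenNLP `FromParams` mixin):

```python
import pickle

import torch.nn as nn


class Linear(nn.Linear):
    """Stripped-down version of the refactored layer."""

    def __init__(self, in_features, out_features, activation=None, dropout_rate=0.0):
        super().__init__(in_features, out_features)
        # Named functions instead of lambdas keep the module picklable.
        self.activation = activation if activation else self.identity
        self.dropout = nn.Dropout(p=dropout_rate) if dropout_rate else self.identity

    @staticmethod
    def identity(x):
        return x

    def forward(self, x):
        return self.dropout(self.activation(super().forward(x)))


layer = Linear(4, 2)
pickle.dumps(layer)  # works; with lambda defaults this raises a PicklingError
```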
```diff
@@ -196,10 +196,10 @@ class FeatsTokenEmbedder(token_embedders.Embedding):
     def forward(self, tokens: torch.Tensor) -> torch.Tensor:
         # (batch_size, sentence_length, features_vocab_length)
-        mask = (tokens > 0).float()
+        mask = tokens.gt(0)
         # (batch_size, sentence_length, features_vocab_length, embedding_dim)
         x = super().forward(tokens)
         # (batch_size, sentence_length, embedding_dim)
         return x.sum(dim=-2) / (
-            (mask.sum(dim=-1) + util.tiny_value_of_dtype(mask.dtype)).unsqueeze(dim=-1)
+            (mask.sum(dim=-1) + util.tiny_value_of_dtype(torch.float)).unsqueeze(dim=-1)
         )
```
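The mask now stays boolean: summing a bool tensor already yields integer counts, and AllenNLP's `util.tiny_value_of_dtype` only accepts floating-point dtypes, so the epsilon dtype is pinned to `torch.float` explicitly rather than taken from the mask. A self-contained sketch of the averaging, with a hard-coded epsilon standing in for `tiny_value_of_dtype`:

```python
import torch

# Morphological-feature indices per token; 0 is padding.
tokens = torch.tensor([[[1, 4, 0],      # token with two features
                        [2, 0, 0]]])    # token with one feature
embedding = torch.nn.Embedding(8, 5, padding_idx=0)

mask = tokens.gt(0)                     # bool, (1, 2, 3)
x = embedding(tokens)                   # (1, 2, 3, 5)
# Average embeddings over the *present* features only; the epsilon
# guards against division by zero for all-padding tokens.
avg = x.sum(dim=-2) / (mask.sum(dim=-1) + 1e-13).unsqueeze(dim=-1)
print(avg.shape)                        # torch.Size([1, 2, 5])
```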
```diff
@@ -60,9 +60,11 @@ class ComboModel(allen_models.Model):
                 enhanced_deprels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
         # Prepare masks
-        char_mask: torch.BoolTensor = sentence["char"]["token_characters"] > 0
+        char_mask = sentence["char"]["token_characters"].gt(0)
         word_mask = util.get_text_field_mask(sentence)
+        device = word_mask.device

         # If enabled weight samples loss by log(sentence_length)
         sample_weights = word_mask.sum(-1).float().log() if self.use_sample_weight else None
@@ -73,45 +75,45 @@ class ComboModel(allen_models.Model):
         # Concatenate the head sentinel (ROOT) onto the sentence representation.
         head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim)
-        encoder_emb = torch.cat([head_sentinel, encoder_emb], 1)
-        word_mask = torch.cat([word_mask.new_ones((batch_size, 1)), word_mask], 1)
+        encoder_emb_with_root = torch.cat([head_sentinel, encoder_emb], 1)
+        word_mask_with_root = torch.cat([torch.ones((batch_size, 1), device=device), word_mask], 1)

         upos_output = self._optional(self.upos_tagger,
-                                     encoder_emb[:, 1:],
-                                     mask=word_mask[:, 1:],
+                                     encoder_emb,
+                                     mask=word_mask,
                                      labels=upostag,
                                      sample_weights=sample_weights)
         xpos_output = self._optional(self.xpos_tagger,
-                                     encoder_emb[:, 1:],
-                                     mask=word_mask[:, 1:],
+                                     encoder_emb,
+                                     mask=word_mask,
                                      labels=xpostag,
                                      sample_weights=sample_weights)
         semrel_output = self._optional(self.semantic_relation,
-                                       encoder_emb[:, 1:],
-                                       mask=word_mask[:, 1:],
+                                       encoder_emb,
+                                       mask=word_mask,
                                        labels=semrel,
                                        sample_weights=sample_weights)
         morpho_output = self._optional(self.morphological_feat,
-                                       encoder_emb[:, 1:],
-                                       mask=word_mask[:, 1:],
+                                       encoder_emb,
+                                       mask=word_mask,
                                        labels=feats,
                                        sample_weights=sample_weights)
         lemma_output = self._optional(self.lemmatizer,
-                                      (encoder_emb[:, 1:], sentence.get("char").get("token_characters")
+                                      (encoder_emb, sentence.get("char").get("token_characters")
                                        if sentence.get("char") else None),
-                                      mask=word_mask[:, 1:],
+                                      mask=word_mask,
                                       labels=lemma.get("char").get("token_characters") if lemma else None,
                                       sample_weights=sample_weights)
         parser_output = self._optional(self.dependency_relation,
-                                       encoder_emb,
+                                       encoder_emb_with_root,
                                        returns_tuple=True,
-                                       mask=word_mask,
+                                       mask=word_mask_with_root,
                                        labels=(deprel, head),
                                        sample_weights=sample_weights)
         enhanced_parser_output = self._optional(self.enhanced_dependency_relation,
-                                                encoder_emb,
+                                                encoder_emb_with_root,
                                                 returns_tuple=True,
-                                                mask=word_mask,
+                                                mask=word_mask_with_root,
                                                 labels=(enhanced_deprels, head, enhanced_heads),
                                                 sample_weights=sample_weights)
         relations_pred, head_pred = parser_output["prediction"]
@@ -126,7 +128,7 @@ class ComboModel(allen_models.Model):
             "deprel": relations_pred,
             "enhanced_head": enhanced_head_pred,
             "enhanced_deprel": enhanced_relations_pred,
-            "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0],
+            "sentence_embedding": torch.max(encoder_emb, dim=1)[0],
         }

         if "rel_probability" in enhanced_parser_output:
@@ -153,7 +155,7 @@ class ComboModel(allen_models.Model):
             "enhanced_head": enhanced_heads,
             "enhanced_deprel": enhanced_deprels,
         }
-        self.scores(output, labels, word_mask[:, 1:])
+        self.scores(output, labels, word_mask)
         relations_loss, head_loss = parser_output["loss"]
         enhanced_relations_loss, enhanced_head_loss = enhanced_parser_output["loss"]
         losses = {
```
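This is the core of the tensor-indexing refactor: instead of prepending the ROOT sentinel onto `encoder_emb`/`word_mask` in place and then slicing `[:, 1:]` back off for every tagger, the ROOT-prepended tensors get their own `*_with_root` names and only the two dependency parsers see them. A shape sketch with illustrative sizes:

```python
import torch

batch_size, sentence_length, encoding_dim = 2, 5, 8
encoder_emb = torch.randn(batch_size, sentence_length, encoding_dim)
word_mask = torch.ones(batch_size, sentence_length)

# ROOT sentinel prepended under new names; the originals stay untouched.
head_sentinel = torch.zeros(1, 1, encoding_dim).expand(batch_size, 1, encoding_dim)
encoder_emb_with_root = torch.cat([head_sentinel, encoder_emb], 1)          # (2, 6, 8)
word_mask_with_root = torch.cat(
    [torch.ones((batch_size, 1), device=word_mask.device), word_mask], 1)   # (2, 6)

# Taggers consume encoder_emb/word_mask directly, so the old
# encoder_emb[:, 1:] slicing (re-dropping the sentinel) disappears.
assert encoder_emb_with_root[:, 1:].shape == encoder_emb.shape
```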
"""Dependency parsing models.""" """Dependency parsing models."""
import math
from typing import Tuple, Dict, Optional, Union, List from typing import Tuple, Dict, Optional, Union, List
import numpy as np import numpy as np
......
```diff
@@ -25,7 +25,7 @@ class COMBO(predictor.Predictor):
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
-                 batch_size: int = 500,
+                 batch_size: int = 32,
                  line_to_conllu: bool = False) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
@@ -52,7 +52,7 @@ class COMBO(predictor.Predictor):
     def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
-            return data.Sentence.from_dict(self.predict_json({"sentence": sentence}))
+            return self.predict_json({"sentence": sentence})
         elif isinstance(sentence, list):
             if len(sentence) == 0:
                 return []
@@ -219,7 +219,7 @@ class COMBO(predictor.Predictor):
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-                        batch_size: int = 500,
+                        batch_size: int = 32,
                         cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
```
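Both the constructor and `from_pretrained` drop the default `batch_size` from 500 to 32, and `predict` on a raw string now returns whatever `predict_json` produces instead of re-wrapping it through `data.Sentence.from_dict` (which suggests `predict_json` itself now yields a `data.Sentence`). Typical usage stays the same; the model name below follows the project docs, so adjust it to whatever pretrained model you have available:

```python
from combo.predict import COMBO

# batch_size now defaults to 32 (was 500); pass it explicitly to override.
nlp = COMBO.from_pretrained("polish-herbert-base")
sentence = nlp("To jest przykładowe zdanie.")
print(sentence.tokens)
```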
With the dataclasses-json import removed from the data API, its pin disappears from setup.py as well:

```diff
@@ -6,7 +6,6 @@ REQUIREMENTS = [
     'allennlp==1.2.1',
     'conllu==2.3.2',
     'dataclasses;python_version<"3.7"',
-    'dataclasses-json==0.5.2',
     'joblib==0.14.1',
     'jsonnet==0.15.0',
     'requests==2.23.0',
```