Skip to content
Snippets Groups Projects
Commit 4115a0b1 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Add dependency relation distribution field to the model

parent 74ab8843
No related branches found
No related tags found
1 merge request!46Merge COMBO 3.0 into master
...@@ -155,6 +155,8 @@ class ComboModel(Model, FromParameters): ...@@ -155,6 +155,8 @@ class ComboModel(Model, FromParameters):
"semrel_token_embedding": semrel_output["embedding"], "semrel_token_embedding": semrel_output["embedding"],
"feats_token_embedding": morpho_output["embedding"], "feats_token_embedding": morpho_output["embedding"],
"deprel_token_embedding": parser_output["embedding"], "deprel_token_embedding": parser_output["embedding"],
"deprel_tree_distribution": parser_output["deprel_tree_distribution"],
"deprel_label_distribution": parser_output["deprel_label_distribution"]
} }
if "rel_probability" in enhanced_parser_output: if "rel_probability" in enhanced_parser_output:
......
...@@ -3,6 +3,9 @@ import dataclasses ...@@ -3,6 +3,9 @@ import dataclasses
import json import json
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Union, Tuple from typing import Optional, List, Dict, Any, Union, Tuple
from conllu.models import Metadata
from combo.data.tokenizers import Token from combo.data.tokenizers import Token
import conllu import conllu
...@@ -12,13 +15,17 @@ from overrides import overrides ...@@ -12,13 +15,17 @@ from overrides import overrides
class Sentence: class Sentence:
tokens: List[Token] = field(default_factory=list) tokens: List[Token] = field(default_factory=list)
sentence_embedding: List[float] = field(default_factory=list, repr=False) sentence_embedding: List[float] = field(default_factory=list, repr=False)
relation_distribution: List[float] = field(default_factory=list, repr=False)
relation_label_distribution: List[float] = field(default_factory=list, repr=False)
metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict) metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
def to_json(self): def to_json(self):
return json.dumps({ return json.dumps({
"tokens": [dataclasses.asdict(t) for t in self.tokens], "tokens": [dataclasses.asdict(t) for t in self.tokens],
"sentence_embedding": self.sentence_embedding, "sentence_embedding": self.sentence_embedding,
"metadata": self.metadata, "relation_distribution": self.relation_distribution,
"relation_label_distribution": self.relation_label_distribution,
"metadata": self.metadata
}) })
def __len__(self): def __len__(self):
...@@ -46,7 +53,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.mode ...@@ -46,7 +53,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.mode
if len(dep) > 1 and type(dep[1]) == list: if len(dep) > 1 and type(dep[1]) == list:
dep[1] = tuple(dep[1]) dep[1] = tuple(dep[1])
return _TokenList(tokens=tokens, return _TokenList(tokens=tokens,
metadata=sentence.metadata) metadata=sentence.metadata if sentence.metadata is None else Metadata())
def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList: def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList:
...@@ -60,13 +67,23 @@ def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList: ...@@ -60,13 +67,23 @@ def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList:
def conllu2sentence(tokens: List[Token], def conllu2sentence(tokens: List[Token],
sentence_embedding=None, embeddings=None, sentence_embedding=None, embeddings=None,
metadata=None) -> Sentence: metadata=None,
if embeddings is None: relation_distribution=None,
embeddings = {} relation_label_distribution=None) -> Sentence:
embeddings = embeddings or {}
if relation_distribution is None:
relation_distribution = []
if relation_label_distribution is None:
relation_label_distribution = []
if sentence_embedding is None: if sentence_embedding is None:
sentence_embedding = [] sentence_embedding = []
if embeddings:
for token in tokens:
token.embeddings = embeddings[token["idx"]]
return Sentence( return Sentence(
tokens=tokens, tokens=tokens,
sentence_embedding=sentence_embedding, sentence_embedding=sentence_embedding,
relation_distribution=relation_distribution,
relation_label_distribution=relation_label_distribution,
metadata=metadata metadata=metadata
) )
...@@ -17,7 +17,7 @@ from combo.data.fields.metadata_field import MetadataField ...@@ -17,7 +17,7 @@ from combo.data.fields.metadata_field import MetadataField
from combo.data.fields.sequence_label_field import SequenceLabelField from combo.data.fields.sequence_label_field import SequenceLabelField
from combo.data.fields.text_field import TextField from combo.data.fields.text_field import TextField
from combo.data.token_indexers import TokenIndexer from combo.data.token_indexers import TokenIndexer
from combo.models import parser from combo.modules import parser
from combo.utils import checks, pad_sequence_to_length from combo.utils import checks, pad_sequence_to_length
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -174,6 +174,8 @@ class DependencyRelationModel(Predictor): ...@@ -174,6 +174,8 @@ class DependencyRelationModel(Predictor):
relation_prediction = self.relation_prediction_layer(dep_rel_pred) relation_prediction = self.relation_prediction_layer(dep_rel_pred)
output = head_output output = head_output
output["embedding"] = dep_rel_pred output["embedding"] = dep_rel_pred
output["deprel_label_distribution"] = F.softmax(relation_prediction[:, 1:, 1:], dim=-1)
output["deprel_tree_distribution"] = head_pred_soft
if self.training: if self.training:
output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"]) output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
......
...@@ -105,8 +105,12 @@ class COMBO(PredictorModule): ...@@ -105,8 +105,12 @@ class COMBO(PredictorModule):
sentences = [] sentences = []
predictions = super().predict_batch_instance(instances) predictions = super().predict_batch_instance(instances)
for prediction, instance in zip(predictions, instances): for prediction, instance in zip(predictions, instances):
tree, sentence_embedding, embeddings = self._predictions_as_tree(prediction, instance) (tree, sentence_embedding, embeddings,
sentence = conllu2sentence(tree, sentence_embedding, embeddings) relation_distribution, relation_label_distribution) = self._predictions_as_tree(prediction, instance)
sentence = conllu2sentence(
tree, sentence_embedding, embeddings,
relation_distribution, relation_label_distribution
)
sentences.append(sentence) sentences.append(sentence)
return sentences return sentences
...@@ -119,8 +123,9 @@ class COMBO(PredictorModule): ...@@ -119,8 +123,9 @@ class COMBO(PredictorModule):
@overrides @overrides
def predict_instance(self, instance: Instance, serialize: bool = True) -> data.Sentence: def predict_instance(self, instance: Instance, serialize: bool = True) -> data.Sentence:
predictions = super().predict_instance(instance) predictions = super().predict_instance(instance)
tree, sentence_embedding, embeddings = self._predictions_as_tree(predictions, instance) (tree, sentence_embedding, embeddings,
return conllu2sentence(tree, sentence_embedding, embeddings) relation_distribution, relation_label_distribution) = self._predictions_as_tree(predictions, instance)
return conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution)
@overrides @overrides
def predict_json(self, inputs: JsonDict) -> data.Sentence: def predict_json(self, inputs: JsonDict) -> data.Sentence:
...@@ -166,9 +171,16 @@ class COMBO(PredictorModule): ...@@ -166,9 +171,16 @@ class COMBO(PredictorModule):
field_names = instance.fields["metadata"]["field_names"] field_names = instance.fields["metadata"]["field_names"]
tree_tokens = [t for t in tree if isinstance(t["idx"], int)] tree_tokens = [t for t in tree if isinstance(t["idx"], int)]
embeddings = {t["idx"]: {} for t in tree} embeddings = {t["idx"]: {} for t in tree}
deprel_tree_distribution = None
deprel_label_distribution = None
for field_name in field_names: for field_name in field_names:
if field_name not in predictions: if field_name not in predictions:
continue continue
if field_name == "deprel":
sentence_length = len(tree_tokens)
deprel_tree_distribution = np.matrix(predictions["deprel_tree_distribution"])[:sentence_length + 1,
:sentence_length + 1]
deprel_label_distribution = np.matrix(predictions["deprel_label_distribution"])[:sentence_length, :]
field_predictions = predictions[field_name] field_predictions = predictions[field_name]
for idx, token in enumerate(tree_tokens): for idx, token in enumerate(tree_tokens):
if field_name in {"xpostag", "upostag", "semrel", "deprel"}: if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
...@@ -242,7 +254,9 @@ class COMBO(PredictorModule): ...@@ -242,7 +254,9 @@ class COMBO(PredictorModule):
empty_tokens = graph.restore_collapse_edges(tree_tokens) empty_tokens = graph.restore_collapse_edges(tree_tokens)
tree.tokens.extend(empty_tokens) tree.tokens.extend(empty_tokens)
return tree, predictions["sentence_embedding"], embeddings return tree, predictions["sentence_embedding"], embeddings, \
deprel_tree_distribution, deprel_label_distribution
@classmethod @classmethod
def with_spacy_tokenizer(cls, model: Model, def with_spacy_tokenizer(cls, model: Model,
...@@ -250,7 +264,9 @@ class COMBO(PredictorModule): ...@@ -250,7 +264,9 @@ class COMBO(PredictorModule):
return cls(model, dataset_reader, tokenizers.SpacyTokenizer()) return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
@classmethod @classmethod
def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), def from_pretrained(cls,
path: str,
tokenizer=tokenizers.SpacyTokenizer(),
batch_size: int = 1024, batch_size: int = 1024,
cuda_device: int = -1): cuda_device: int = -1):
if os.path.exists(path): if os.path.exists(path):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment