From 1f8b632e78ab340fa1a88185940c537bfc229706 Mon Sep 17 00:00:00 2001
From: pszenny <pszenny@e-science.pl>
Date: Tue, 22 Nov 2022 10:03:30 +0100
Subject: [PATCH] Adding sentence attributes:
 - relation_distribution: probabilities of an arc between dependent and head
 - relation_label_distribution: probabilities of each label in the sentence

---
 combo/data/api.py      | 13 ++++++++++++-
 combo/models/model.py  |  2 ++
 combo/models/parser.py |  4 +++-
 combo/predict.py       | 17 ++++++++++++-----
 4 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/combo/data/api.py b/combo/data/api.py
index 308e9e4..adb1b27 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -28,12 +28,16 @@ class Token:
 class Sentence:
     tokens: List[Token] = field(default_factory=list)
     sentence_embedding: List[float] = field(default_factory=list, repr=False)
+    relation_distribution: List[float] = field(default_factory=list, repr=False)
+    relation_label_distribution: List[float] = field(default_factory=list, repr=False)
     metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
 
     def to_json(self):
         return json.dumps({
             "tokens": [dataclasses.asdict(t) for t in self.tokens],
             "sentence_embedding": self.sentence_embedding,
+            "relation_distribution": self.relation_distribution,
+            "relation_label_distribution": self.relation_label_distribution,
             "metadata": self.metadata,
         })
 
@@ -79,7 +83,12 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
 
 
 def conllu2sentence(conllu_sentence: conllu.TokenList,
-                    sentence_embedding=None, embeddings=None) -> Sentence:
+                    sentence_embedding=None, embeddings=None, relation_distribution=None,
+                    relation_label_distribution=None) -> Sentence:
+    if relation_distribution is None:
+        relation_distribution = []
+    if relation_label_distribution is None:
+        relation_label_distribution = []
     if embeddings is None:
         embeddings = {}
     if sentence_embedding is None:
@@ -94,5 +103,7 @@ def conllu2sentence(conllu_sentence: conllu.TokenList,
     return Sentence(
         tokens=tokens,
         sentence_embedding=sentence_embedding,
+        relation_distribution=relation_distribution,
+        relation_label_distribution=relation_label_distribution,
         metadata=conllu_sentence.metadata
     )
diff --git a/combo/models/model.py b/combo/models/model.py
index c648453..838979e 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -134,6 +134,8 @@ class ComboModel(allen_models.Model):
                 "semrel_token_embedding": semrel_output["embedding"],
                 "feats_token_embedding": morpho_output["embedding"],
                 "deprel_token_embedding": parser_output["embedding"],
+                "deprel_tree_distribution": parser_output["deprel_tree_distribution"],
+                "deprel_label_distribution": parser_output["deprel_label_distribution"]
             }
 
             if "rel_probability" in enhanced_parser_output:
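For orientation, a minimal sketch (not part of this patch) of how the two new
Sentence attributes might be consumed once the change is applied. The
COMBO.from_pretrained loader and the model name follow the project README;
both, like the input sentence, are assumptions here.

import numpy as np
from combo.predict import COMBO

nlp = COMBO.from_pretrained("polish-herbert-base")  # assumed model name
sentence = nlp("Commas are our friends.")

# Arc probabilities: rows/columns cover the sentence tokens plus the ROOT node.
arc_probs = np.asarray(sentence.relation_distribution)
# Label probabilities: one row per token, one column per dependency label.
label_probs = np.asarray(sentence.relation_label_distribution)
print(arc_probs.shape, label_probs.shape)

One caveat: _predictions_as_tree (below) fills these fields with np.matrix
objects, which json.dumps cannot serialize, so callers of Sentence.to_json()
may need to convert the values with .tolist() first.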
diff --git a/combo/models/parser.py b/combo/models/parser.py
index b16f0ad..bb0fc91 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -41,7 +41,7 @@ class HeadPredictionModel(base.Predictor):
         # Adding non existing in mask ROOT to lengths
         lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1
         for idx, length in enumerate(lengths):
-            probs = x[idx, :].softmax(dim=-1).cpu().numpy()
+            probs = x[idx, :].softmax(dim=-1).cpu().numpy()  # this is the matrix we are looking for
             # We do not want any word to be parent of the root node (ROOT, 0).
             # Also setting it to -1 instead of 0 fixes edge case where softmax made all
@@ -154,6 +154,8 @@ class DependencyRelationModel(base.Predictor):
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
         output["embedding"] = dep_rel_pred
+        output["deprel_label_distribution"] = F.softmax(relation_prediction[:, 1:, 1:], dim=-1)
+        output["deprel_tree_distribution"] = head_pred_soft
         if self.training:
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:],
                                     head_output["prediction"])
diff --git a/combo/predict.py b/combo/predict.py
index 83b030f..cdcd694 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -88,8 +88,8 @@ class COMBO(predictor.Predictor):
         sentences = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
-            tree, sentence_embedding, embeddings = self._predictions_as_tree(prediction, instance)
-            sentence = conllu2sentence(tree, sentence_embedding, embeddings)
+            tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution = self._predictions_as_tree(prediction, instance)
+            sentence = conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution)
             sentences.append(sentence)
         return sentences
 
@@ -102,8 +102,8 @@ class COMBO(predictor.Predictor):
     @overrides
     def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> data.Sentence:
         predictions = super().predict_instance(instance)
-        tree, sentence_embedding, embeddings = self._predictions_as_tree(predictions, instance)
-        return conllu2sentence(tree, sentence_embedding, embeddings)
+        tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution = self._predictions_as_tree(predictions, instance)
+        return conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution)
 
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> data.Sentence:
@@ -148,9 +148,15 @@ class COMBO(predictor.Predictor):
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
         embeddings = {t["id"]: {} for t in tree}
+        deprel_tree_distribution = None
+        deprel_label_distribution = None
         for field_name in field_names:
             if field_name not in predictions:
                 continue
+            if field_name == "deprel":
+                sentence_length = len(tree_tokens)
+                deprel_tree_distribution = np.matrix(predictions["deprel_tree_distribution"])[:sentence_length + 1, :sentence_length + 1]
+                deprel_label_distribution = np.matrix(predictions["deprel_label_distribution"])[:sentence_length, :]
             field_predictions = predictions[field_name]
             for idx, token in enumerate(tree_tokens):
                 if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
@@ -224,7 +230,8 @@ class COMBO(predictor.Predictor):
         empty_tokens = graph.restore_collapse_edges(tree_tokens)
         tree.tokens.extend(empty_tokens)
 
-        return tree, predictions["sentence_embedding"], embeddings
+        return tree, predictions["sentence_embedding"], embeddings, \
+            deprel_tree_distribution, deprel_label_distribution
 
     @classmethod
     def with_spacy_tokenizer(cls, model: models.Model,
--
GitLab
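The slicing in _predictions_as_tree is easy to get wrong, so here is a
self-contained sketch of the intended shapes. All sizes are made up; the padded
dimension stands for the longest sentence in the batch. The sketch uses
np.asarray where the patch uses np.matrix, which NumPy discourages for new
code. Note also that the F.softmax call added in parser.py assumes
torch.nn.functional is already imported there as F.

import numpy as np

padded_len, n_labels = 10, 27  # assumed padded length and label-set size
sentence_length = 4            # real tokens in this particular sentence

# Padded model outputs, as they would arrive in the `predictions` dict.
tree_dist = np.asarray(np.random.rand(padded_len + 1, padded_len + 1))  # +1 for ROOT
label_dist = np.asarray(np.random.rand(padded_len, n_labels))

# The arc matrix keeps ROOT (index 0), hence one extra row and column ...
relation_distribution = tree_dist[:sentence_length + 1, :sentence_length + 1]
# ... while the label matrix has exactly one row per real token.
relation_label_distribution = label_dist[:sentence_length, :]

assert relation_distribution.shape == (5, 5)
assert relation_label_distribution.shape == (4, 27)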