diff --git a/combo/data/api.py b/combo/data/api.py
index 39f449a46a10ac9720943fe659faa60f412cafcc..5a8b7b7c1488f3ac10a413031be9a50c1e8a582b 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -28,12 +28,16 @@ class Token:
 class Sentence:
     tokens: List[Token] = field(default_factory=list)
     sentence_embedding: List[float] = field(default_factory=list, repr=False)
+    relation_distribution: List[float] = field(default_factory=list, repr=False)
+    relation_label_distribution: List[float] = field(default_factory=list, repr=False)
     metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
 
     def to_json(self):
         return json.dumps({
             "tokens": [dataclasses.asdict(t) for t in self.tokens],
             "sentence_embedding": self.sentence_embedding,
+            "relation_distribution": self.relation_distribution,
+            "relation_label_distribution": self.relation_label_distribution,
             "metadata": self.metadata,
         })
 
@@ -79,7 +83,12 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList:
 
 
 def conllu2sentence(conllu_sentence: conllu.TokenList,
-                    sentence_embedding=None, embeddings=None) -> Sentence:
+                    sentence_embedding=None, embeddings=None, relation_distribution=None,
+                    relation_label_distribution=None) -> Sentence:
+    if relation_distribution is None:
+        relation_distribution = []
+    if relation_label_distribution is None:
+        relation_label_distribution = []
     if embeddings is None:
         embeddings = {}
     if sentence_embedding is None:
@@ -94,5 +103,7 @@ def conllu2sentence(conllu_sentence: conllu.TokenList,
     return Sentence(
         tokens=tokens,
         sentence_embedding=sentence_embedding,
+        relation_distribution=relation_distribution,
+        relation_label_distribution=relation_label_distribution,
         metadata=conllu_sentence.metadata
     )
diff --git a/combo/models/model.py b/combo/models/model.py
index c648453db5ab951ed68573cf79094f1bf21286d2..838979e5ab4e51f5f27590663caaede695dc28cf 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -134,6 +134,8 @@ class ComboModel(allen_models.Model):
                 "semrel_token_embedding": semrel_output["embedding"],
                 "feats_token_embedding": morpho_output["embedding"],
                 "deprel_token_embedding": parser_output["embedding"],
+                "deprel_tree_distribution": parser_output["deprel_tree_distribution"],
+                "deprel_label_distribution": parser_output["deprel_label_distribution"]
             }
 
             if "rel_probability" in enhanced_parser_output:
diff --git a/combo/models/parser.py b/combo/models/parser.py
index b16f0adcff066c39558cb8709122780d69ee8702..bb0fc91834778323d594a879c2918ab651cc50b4 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -41,7 +41,7 @@ class HeadPredictionModel(base.Predictor):
         # Adding non existing in mask ROOT to lengths
         lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1
         for idx, length in enumerate(lengths):
-            probs = x[idx, :].softmax(dim=-1).cpu().numpy()
+            probs = x[idx, :].softmax(dim=-1).cpu().numpy()  # this is the matrix we are looking for
 
             # We do not want any word to be parent of the root node (ROOT, 0).
             # Also setting it to -1 instead of 0 fixes edge case where softmax made all
@@ -154,6 +154,8 @@ class DependencyRelationModel(base.Predictor):
         relation_prediction = self.relation_prediction_layer(dep_rel_pred)
         output = head_output
         output["embedding"] = dep_rel_pred
+        output["deprel_label_distribution"] = F.softmax(relation_prediction[:, 1:, 1:], dim=-1)
+        output["deprel_tree_distribution"] = head_pred_soft
 
         if self.training:
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:],
                                     head_output["prediction"])
diff --git a/combo/predict.py b/combo/predict.py
index 9d0b4a62af383f825f746d9fec8494d1c02882bb..52d009454f32a3aeb3007bf8d2ee8ba251e1fcc2 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -93,8 +93,8 @@ class COMBO(predictor.Predictor):
         sentences = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
-            tree, sentence_embedding, embeddings = self._predictions_as_tree(prediction, instance)
-            sentence = conllu2sentence(tree, sentence_embedding, embeddings)
+            tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution = self._predictions_as_tree(prediction, instance)
+            sentence = conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution)
             sentences.append(sentence)
         return sentences
 
@@ -107,8 +107,8 @@ class COMBO(predictor.Predictor):
     @overrides
     def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> data.Sentence:
         predictions = super().predict_instance(instance)
-        tree, sentence_embedding, embeddings = self._predictions_as_tree(predictions, instance, )
-        return conllu2sentence(tree, sentence_embedding, embeddings)
+        tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution = self._predictions_as_tree(predictions, instance)
+        return conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution)
 
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> data.Sentence:
@@ -153,9 +153,15 @@ class COMBO(predictor.Predictor):
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
         embeddings = {t["id"]: {} for t in tree}
+        deprel_tree_distribution = None
+        deprel_label_distribution = None
         for field_name in field_names:
             if field_name not in predictions:
                 continue
+            if field_name == "deprel":
+                sentence_length = len(tree_tokens)
+                deprel_tree_distribution = np.matrix(predictions["deprel_tree_distribution"])[:sentence_length + 1, :sentence_length + 1]
+                deprel_label_distribution = np.matrix(predictions["deprel_label_distribution"])[:sentence_length, :]
             field_predictions = predictions[field_name]
             for idx, token in enumerate(tree_tokens):
                 if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
@@ -229,7 +235,8 @@ class COMBO(predictor.Predictor):
         empty_tokens = graph.restore_collapse_edges(tree_tokens)
         tree.tokens.extend(empty_tokens)
 
-        return tree, predictions["sentence_embedding"], embeddings
+        return tree, predictions["sentence_embedding"], embeddings, \
+               deprel_tree_distribution, deprel_label_distribution
 
     @classmethod
     def with_spacy_tokenizer(cls, model: models.Model,
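
For reference, a minimal sketch of how the two new `Sentence` fields surface downstream, assuming a pretrained COMBO package loaded via `COMBO.from_pretrained` (the model name and input sentence below are placeholders). Note that `_predictions_as_tree` returns `np.matrix` objects, so at runtime the fields hold matrices rather than the flat lists the `List[float]` annotations suggest:

    import numpy as np
    from combo.predict import COMBO

    nlp = COMBO.from_pretrained("polish-herbert-base")  # placeholder model name
    sentence = nlp("The quick brown fox jumps over the lazy dog.")

    # (n+1) x (n+1) head probabilities over the n tokens plus ROOT,
    # sliced from predictions["deprel_tree_distribution"] above.
    head_probs = np.asarray(sentence.relation_distribution)

    # n x L matrix of dependency-label probabilities per token,
    # sliced from predictions["deprel_label_distribution"].
    label_probs = np.asarray(sentence.relation_label_distribution)

    print(head_probs.shape, label_probs.shape)

For each token row of `label_probs`, `row.argmax()` recovers the predicted deprel index that the existing `deprel` field already reports; the full distributions are what this change newly exposes.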