diff --git a/combo/combo_model.py b/combo/combo_model.py index d2312319a930ab43b7db1c8cc3a510fdb9a1b775..2747ad2dd904e40f218dcb2bb316d5cfbbecaf04 100644 --- a/combo/combo_model.py +++ b/combo/combo_model.py @@ -155,6 +155,8 @@ class ComboModel(Model, FromParameters): "semrel_token_embedding": semrel_output["embedding"], "feats_token_embedding": morpho_output["embedding"], "deprel_token_embedding": parser_output["embedding"], + "deprel_tree_distribution": parser_output["deprel_tree_distribution"], + "deprel_label_distribution": parser_output["deprel_label_distribution"] } if "rel_probability" in enhanced_parser_output: diff --git a/combo/data/api.py b/combo/data/api.py index ae4ef35a0ee4e8958b8663b0721bb064adb170d7..b04192f04b2db84981ea34cfcfa1297769078f26 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -3,6 +3,9 @@ import dataclasses import json from dataclasses import dataclass, field from typing import Optional, List, Dict, Any, Union, Tuple + +from conllu.models import Metadata + from combo.data.tokenizers import Token import conllu @@ -12,13 +15,17 @@ from overrides import overrides class Sentence: tokens: List[Token] = field(default_factory=list) sentence_embedding: List[float] = field(default_factory=list, repr=False) + relation_distribution: List[float] = field(default_factory=list, repr=False) + relation_label_distribution: List[float] = field(default_factory=list, repr=False) metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict) def to_json(self): return json.dumps({ "tokens": [dataclasses.asdict(t) for t in self.tokens], "sentence_embedding": self.sentence_embedding, - "metadata": self.metadata, + "relation_distribution": self.relation_distribution, + "relation_label_distribution": self.relation_label_distribution, + "metadata": self.metadata }) def __len__(self): @@ -46,7 +53,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.mode if len(dep) > 1 and type(dep[1]) == list: dep[1] = tuple(dep[1]) 
return _TokenList(tokens=tokens, - metadata=sentence.metadata) + metadata=sentence.metadata if sentence.metadata is not None else Metadata()) def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList: @@ -60,13 +67,23 @@ def tokens2conllu(tokens: List[str]) -> conllu.models.TokenList: def conllu2sentence(tokens: List[Token], sentence_embedding=None, embeddings=None, - metadata=None) -> Sentence: - if embeddings is None: - embeddings = {} + metadata=None, + relation_distribution=None, + relation_label_distribution=None) -> Sentence: + embeddings = embeddings or {} + if relation_distribution is None: + relation_distribution = [] + if relation_label_distribution is None: + relation_label_distribution = [] if sentence_embedding is None: sentence_embedding = [] + if embeddings: + for token in tokens: + token.embeddings = embeddings[token["idx"]] return Sentence( tokens=tokens, sentence_embedding=sentence_embedding, + relation_distribution=relation_distribution, + relation_label_distribution=relation_label_distribution, metadata=metadata ) diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 3993a55c84b9d9486826f756144ef846bedde481..46d8aba42b6a0ca90d9fac80e74da60d18c0c8f9 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -17,7 +17,7 @@ from combo.data.fields.metadata_field import MetadataField from combo.data.fields.sequence_label_field import SequenceLabelField from combo.data.fields.text_field import TextField from combo.data.token_indexers import TokenIndexer -from combo.models import parser +from combo.modules import parser from combo.utils import checks, pad_sequence_to_length logger = logging.getLogger(__name__) diff --git a/combo/modules/parser.py b/combo/modules/parser.py index ef885f28daf469d86735899ceaf4435c2a047deb..e034566f0fab7913ee348d0bcb519966385f681f 100644 --- a/combo/modules/parser.py +++ b/combo/modules/parser.py @@ -174,6 +174,8 @@ class DependencyRelationModel(Predictor): relation_prediction = 
self.relation_prediction_layer(dep_rel_pred) output = head_output output["embedding"] = dep_rel_pred + output["deprel_label_distribution"] = F.softmax(relation_prediction[:, 1:, 1:], dim=-1) + output["deprel_tree_distribution"] = head_pred_soft if self.training: output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"]) diff --git a/combo/predict.py b/combo/predict.py index c4218f58e5f2cb3bd88433fa934d0420f9027310..e30889d33eaa59956d08ea73a888b351e00f8103 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -105,8 +105,12 @@ class COMBO(PredictorModule): sentences = [] predictions = super().predict_batch_instance(instances) for prediction, instance in zip(predictions, instances): - tree, sentence_embedding, embeddings = self._predictions_as_tree(prediction, instance) - sentence = conllu2sentence(tree, sentence_embedding, embeddings) + (tree, sentence_embedding, embeddings, + relation_distribution, relation_label_distribution) = self._predictions_as_tree(prediction, instance) + sentence = conllu2sentence( + tree, sentence_embedding, embeddings, + relation_distribution, relation_label_distribution + ) sentences.append(sentence) return sentences @@ -119,8 +123,9 @@ class COMBO(PredictorModule): @overrides def predict_instance(self, instance: Instance, serialize: bool = True) -> data.Sentence: predictions = super().predict_instance(instance) - tree, sentence_embedding, embeddings = self._predictions_as_tree(predictions, instance) - return conllu2sentence(tree, sentence_embedding, embeddings) + (tree, sentence_embedding, embeddings, + relation_distribution, relation_label_distribution) = self._predictions_as_tree(predictions, instance) + return conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution) @overrides def predict_json(self, inputs: JsonDict) -> data.Sentence: @@ -166,9 +171,16 @@ class COMBO(PredictorModule): field_names = instance.fields["metadata"]["field_names"] tree_tokens = 
[t for t in tree if isinstance(t["idx"], int)] embeddings = {t["idx"]: {} for t in tree} + deprel_tree_distribution = None + deprel_label_distribution = None for field_name in field_names: if field_name not in predictions: continue + if field_name == "deprel": + sentence_length = len(tree_tokens) + deprel_tree_distribution = np.matrix(predictions["deprel_tree_distribution"])[:sentence_length + 1, + :sentence_length + 1] + deprel_label_distribution = np.matrix(predictions["deprel_label_distribution"])[:sentence_length, :] field_predictions = predictions[field_name] for idx, token in enumerate(tree_tokens): if field_name in {"xpostag", "upostag", "semrel", "deprel"}: @@ -242,7 +254,9 @@ class COMBO(PredictorModule): empty_tokens = graph.restore_collapse_edges(tree_tokens) tree.tokens.extend(empty_tokens) - return tree, predictions["sentence_embedding"], embeddings + return tree, predictions["sentence_embedding"], embeddings, \ + deprel_tree_distribution, deprel_label_distribution + @classmethod def with_spacy_tokenizer(cls, model: Model, @@ -250,7 +264,9 @@ class COMBO(PredictorModule): return cls(model, dataset_reader, tokenizers.SpacyTokenizer()) @classmethod - def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), + def from_pretrained(cls, + path: str, + tokenizer=tokenizers.SpacyTokenizer(), batch_size: int = 1024, cuda_device: int = -1): if os.path.exists(path):