Skip to content
Snippets Groups Projects
Commit dc482287 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Fix handling sentence embeddings.

parent 29cc2f3a
Branches
Tags
No related merge requests found
......@@ -79,7 +79,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
for prediction, instance in zip(predictions, instances):
tree = self._predictions_as_tree(prediction, instance)
if serialize:
tree = tree.serialize()
tree = self._serialize(tree)
tree_json = util.sanitize(tree)
trees.append(collections.OrderedDict([
("tree", tree_json),
......@@ -89,11 +89,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
@overrides
def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> common.JsonDict:
tree = self.predict_instance_as_tree(instance)
sentence_embedding = tree.metadata.get("sentence_embedding", [])
if serialize:
tree = tree.serialize()
tree = self._serialize(tree)
tree_json = util.sanitize(tree)
result = collections.OrderedDict([
("tree", tree_json),
("sentence_embedding", sentence_embedding)
])
return result
......@@ -118,6 +120,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
result = collections.OrderedDict([
("tree", tree_json),
])
result["sentence_embedding"] = tree.metadata.get("sentence_embedding", [])
return result
def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList:
......@@ -217,9 +220,15 @@ class SemanticMultitaskPredictor(predictor.Predictor):
raise NotImplementedError(f"Unknown field name {field_name}!")
if self._dataset_reader and "sent" in self._dataset_reader._targets:
tree.metadata["sentence_embedding"] = str(predictions["sentence_embedding"])
tree.metadata["sentence_embedding"] = predictions["sentence_embedding"]
return tree
@staticmethod
def _serialize(tree: conllu.TokenList):
if "sentence_embedding" in tree.metadata:
tree.metadata["sentence_embedding"] = str(tree.metadata["sentence_embedding"])
return tree.serialize()
@classmethod
def with_spacy_tokenizer(cls, model: models.Model,
dataset_reader: allen_data.DatasetReader):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment