Skip to content
Snippets Groups Projects
Commit de9301ad authored by Martyna Wiącek's avatar Martyna Wiącek
Browse files

Predict and save deprel matrix

parent 1f8b632e
No related merge requests found
Pipeline #9482 passed with stage
in 3 minutes and 23 seconds
import collections import collections
import dataclasses import dataclasses
import json import json
import os
import string
from dataclasses import dataclass, field from dataclasses import dataclass, field
from json import JSONEncoder
import random
import numpy
from typing import Optional, List, Dict, Any, Union, Tuple from typing import Optional, List, Dict, Any, Union, Tuple
import conllu import conllu
...@@ -32,14 +38,28 @@ class Sentence: ...@@ -32,14 +38,28 @@ class Sentence:
relation_label_distribution: List[float] = field(default_factory=list, repr=False) relation_label_distribution: List[float] = field(default_factory=list, repr=False)
metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict) metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
def to_json(self, save_relation_distribution_path=None):
    """Serialize the sentence to a JSON string.

    Args:
        save_relation_distribution_path: Directory that receives a
            ``<id>.npz`` archive holding ``relation_distribution``,
            ``relation_distribution_trimmed`` (first row and column dropped)
            and ``relation_label_distribution``. The directory is created if
            it does not exist. When ``None`` nothing is written to disk and
            the two distributions are embedded in the JSON instead, so a
            bare ``to_json()`` call keeps working.

    Returns:
        A JSON string with ``tokens`` (one 7-element list per token) plus
        either ``relation_distribution_hash`` / ``path_file`` (archive
        written) or ``relation_distribution`` /
        ``relation_label_distribution`` (no directory given).
    """
    # Local import keeps this method self-contained; uuid4 gives a
    # collision-resistant id that does not depend on the global `random` state.
    import uuid

    class NumpyArrayEncoder(JSONEncoder):
        """JSON encoder that serializes numpy arrays as nested lists."""

        def default(self, obj):
            if isinstance(obj, numpy.ndarray):
                return obj.tolist()
            return super().default(obj)

    tokens = [
        (t.token, t.lemma, t.upostag, t.xpostag, t.feats, t.head, t.deprel)
        for t in self.tokens
    ]
    # The dataclass fields default to plain lists; normalise to arrays so the
    # 2-D slicing below is valid whatever the predictor stored.
    relation_distribution = numpy.asarray(self.relation_distribution)
    relation_label_distribution = numpy.asarray(self.relation_label_distribution)

    if save_relation_distribution_path is None:
        return json.dumps({
            "tokens": tokens,
            "relation_distribution": relation_distribution,
            "relation_label_distribution": relation_label_distribution,
        }, cls=NumpyArrayEncoder)

    os.makedirs(save_relation_distribution_path, exist_ok=True)
    file_id = uuid.uuid4().hex
    path_file = os.path.join(save_relation_distribution_path, file_id + '.npz')

    # NOTE(review): index 0 presumably corresponds to the artificial ROOT
    # token - confirm against the model output. Only a 2-D matrix can be
    # trimmed; anything else (e.g. an empty default) is stored unchanged.
    if relation_distribution.ndim == 2:
        relation_distribution_trimmed = relation_distribution[1:, 1:]
    else:
        relation_distribution_trimmed = relation_distribution

    numpy.savez(
        path_file,
        relation_distribution=relation_distribution,
        relation_distribution_trimmed=relation_distribution_trimmed,
        relation_label_distribution=relation_label_distribution,
    )
    return json.dumps({
        "tokens": tokens,
        "relation_distribution_hash": file_id,
        "path_file": str(path_file),
    }, cls=NumpyArrayEncoder)
def __len__(self): def __len__(self):
return len(self.tokens) return len(self.tokens)
......
...@@ -88,6 +88,8 @@ flags.DEFINE_boolean(name="silent", default=True, ...@@ -88,6 +88,8 @@ flags.DEFINE_boolean(name="silent", default=True,
flags.DEFINE_enum(name="predictor_name", default="combo-spacy", flags.DEFINE_enum(name="predictor_name", default="combo-spacy",
enum_values=["combo", "combo-spacy"], enum_values=["combo", "combo-spacy"],
help="Use predictor with whitespace or spacy tokenizer.") help="Use predictor with whitespace or spacy tokenizer.")
# Output directory for the per-sentence relation distribution archives: the
# value is used as a directory name (one .npz file is written into it per
# serialized sentence), not as a single file path.
flags.DEFINE_string(name="save_relation_distribution_path", default=None,
                    help="Directory where per-sentence relation distribution .npz files are saved.")
def run(_): def run(_):
...@@ -153,6 +155,8 @@ def run(_): ...@@ -153,6 +155,8 @@ def run(_):
use_dataset_reader = False use_dataset_reader = False
predictor.without_sentence_embedding = True predictor.without_sentence_embedding = True
predictor.line_to_conllu = False predictor.line_to_conllu = False
predictor.line_to_conllu = False
predictor.save_relation_distribution_path = FLAGS.save_relation_distribution_path
if FLAGS.silent: if FLAGS.silent:
logging.getLogger("allennlp.common.params").disabled = True logging.getLogger("allennlp.common.params").disabled = True
manager = allen_predict._PredictManager( manager = allen_predict._PredictManager(
......
...@@ -134,7 +134,9 @@ class COMBO(predictor.Predictor): ...@@ -134,7 +134,9 @@ class COMBO(predictor.Predictor):
if self.line_to_conllu: if self.line_to_conllu:
return sentence2conllu(outputs, keep_semrel=self.dataset_reader.use_sem).serialize() return sentence2conllu(outputs, keep_semrel=self.dataset_reader.use_sem).serialize()
else: else:
return outputs.to_json() if not os.path.exists(self.save_relation_distribution_path):
os.makedirs(self.save_relation_distribution_path)
return outputs.to_json(self.save_relation_distribution_path)
@staticmethod @staticmethod
def _to_input_json(sentence: str): def _to_input_json(sentence: str):
......
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment