Commit 971f6109 authored by Mateusz Klimaszewski

Add batch prediction for jsons and data instances.

parent e3507eba
 import collections
 import errno
 import logging
 import os
 import time
-from typing import List
+from typing import List, Union
 
 import conllu
 import requests
 import tqdm
 from allennlp import data as allen_data, common, models
 from allennlp.common import util
 from allennlp.data import tokenizers
@@ -29,19 +26,26 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None:
         super().__init__(model, dataset_reader)
+        self.batch_size = 1000
         self.vocab = model.vocab
         self._dataset_reader.generate_labels = False
         self._tokenizer = tokenizer
 
     @overrides
     def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
-        tokens = self._tokenizer.tokenize(json_dict["sentence"])
-        tree = self._sentence_to_tree([t.text for t in tokens])
+        sentence = json_dict["sentence"]
+        if isinstance(sentence, str):
+            tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
+        elif isinstance(sentence, list):
+            tokens = sentence
+        else:
+            raise ValueError("Input must be either string or list of strings.")
+        tree = self._sentence_to_tree(tokens)
         return self._dataset_reader.text_to_instance(tree)
 
     @overrides
     def load_line(self, line: str) -> common.JsonDict:
-        return {"sentence": line.replace("\n", " ").strip()}
+        return self._to_input_json(line.replace("\n", "").strip())
 
     @overrides
     def dump_line(self, outputs: common.JsonDict) -> str:
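With this hunk, _json_to_instance accepts either raw text or an already-tokenized sentence. A minimal sketch of the two accepted input shapes, assuming `predictor` is a constructed SemanticMultitaskPredictor (the variable is hypothetical, and _json_to_instance is normally reached indirectly via predict_json):

instance = predictor._json_to_instance({"sentence": "Easy sentence ."})          # whitespace-tokenized
instance = predictor._json_to_instance({"sentence": ["Easy", "sentence", "."]})  # tokens used as-is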
@@ -52,35 +56,61 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         else:
             return str(outputs["tree"]) + "\n"
 
+    def predict(self, sentence: Union[str, List[str]]):
+        if isinstance(sentence, str):
+            return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
+        elif isinstance(sentence, list):
+            sentences = []
+            for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
+                trees = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
+                sentences.extend([data.Sentence.from_json(t) for t in trees])
+            return sentences
+        else:
+            raise ValueError("Input must be either string or list of strings.")
+
+    def __call__(self, sentence: Union[str, List[str]]):
+        return self.predict(sentence)
+
+    @overrides
+    def predict_batch_instance(self, instances: List[allen_data.Instance]) -> List[common.JsonDict]:
+        trees = []
+        predictions = super().predict_batch_instance(instances)
+        for prediction, instance in zip(predictions, instances):
+            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance).serialize())
+            trees.append(collections.OrderedDict([
+                ("tree", tree_json),
+            ]))
+        return trees
+
     @overrides
     def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict:
         start_time = time.time()
         tree = self.predict_instance_as_tree(instance)
         tree_json = util.sanitize(tree.serialize())
         result = collections.OrderedDict([
             ("tree", tree_json),
         ])
         end_time = time.time()
         logger.info(f"Took {(end_time - start_time) * 1000.0} ms")
         return result
 
-    def predict(self, sentence: str):
-        return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
-
-    def __call__(self, sentence: str):
-        return self.predict(sentence)
-
+    @overrides
+    def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[common.JsonDict]:
+        trees = []
+        instances = self._batch_json_to_instances(inputs)
+        predictions = super().predict_batch_json(inputs)
+        for prediction, instance in zip(predictions, instances):
+            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance))
+            trees.append(collections.OrderedDict([
+                ("tree", tree_json),
+            ]))
+        return trees
+
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
         start_time = time.time()
         instance = self._json_to_instance(inputs)
         tree = self.predict_instance_as_tree(instance)
         tree_json = util.sanitize(tree)
         result = collections.OrderedDict([
             ("tree", tree_json),
         ])
         end_time = time.time()
         logger.info(f"Took {(end_time - start_time) * 1000.0} ms")
         return result
 
     def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList:
@@ -97,6 +127,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
             metadata=collections.OrderedDict()
         )
 
+    @staticmethod
+    def _to_input_json(sentence: str):
+        return {"sentence": sentence}
+
     def _predictions_as_tree(self, predictions, instance):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
@@ -165,12 +199,14 @@ class SemanticMultitaskPredictor(predictor.Predictor):
             model_path = path
         else:
             try:
+                logger.debug("Downloading model.")
                 model_path = download.download_file(path)
             except Exception as e:
                 logger.error(e)
                 raise e
-        model = models.Model.from_archive(model_path)
+        archive = models.load_archive(model_path)
+        model = archive.model
         dataset_reader = allen_data.DatasetReader.from_params(
-            models.load_archive(model_path).config["dataset_reader"])
+            archive.config["dataset_reader"])
         return cls(model, dataset_reader, tokenizer)
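Taken together, the new methods give the predictor a batched call path. A minimal usage sketch, assuming `predictor` is a SemanticMultitaskPredictor built via the classmethod at the end of the diff (the variable name is hypothetical):

single = predictor("Easy sentence.")                    # -> data.Sentence
batch = predictor(["Easy sentence.", "Another one."])   # -> [data.Sentence, data.Sentence]
# Lists longer than predictor.batch_size (1000) are chunked lazily with
# allennlp.common.util.lazy_groups_of and routed through predict_batch_json.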
A CoNLL-U test fixture is updated in the same commit; the head of token 1 changes from 1 (a self-referencing head) to 2:

 # sent_id = test-s1
 # text = Easy sentence.
-1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	1	amod	_	_
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
 2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
 3	.	.	PUNCT	.	_	1	punct	_	_
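For reference, the trees the predictor works with are conllu.TokenList objects, and predict_instance/dump_line rely on their serialize() round-trip. A self-contained sketch using the conllu library's public parse/serialize API (the sample sentence here is made up):

import conllu

raw = (
    "1\tEasy\teasy\tADJ\t_\t_\t2\tamod\t_\t_\n"
    "2\tsentence\tsentence\tNOUN\t_\t_\t0\troot\t_\t_\n"
)
tree = conllu.parse(raw)[0]              # a conllu.TokenList
print(tree[0]["form"], tree[0]["head"])  # -> Easy 2
print(tree.serialize())                  # back to CoNLL-U text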