Skip to content
Snippets Groups Projects
Commit 971f6109 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add batch prediction for jsons and data instances.

parent e3507eba
No related merge requests found
...@@ -94,7 +94,7 @@ class SemanticMultitaskModel(allen_models.Model): ...@@ -94,7 +94,7 @@ class SemanticMultitaskModel(allen_models.Model):
sample_weights=sample_weights) sample_weights=sample_weights)
lemma_output = self._optional(self.lemmatizer, lemma_output = self._optional(self.lemmatizer,
(encoder_emb[:, 1:], sentence.get("char").get("token_characters") (encoder_emb[:, 1:], sentence.get("char").get("token_characters")
if sentence.get("char") else None), if sentence.get("char") else None),
mask=word_mask[:, 1:], mask=word_mask[:, 1:],
labels=lemma.get("char").get("token_characters") if lemma else None, labels=lemma.get("char").get("token_characters") if lemma else None,
sample_weights=sample_weights) sample_weights=sample_weights)
......
import collections import collections
import errno
import logging import logging
import os import os
import time import time
from typing import List from typing import List, Union
import conllu import conllu
import requests
import tqdm
from allennlp import data as allen_data, common, models from allennlp import data as allen_data, common, models
from allennlp.common import util from allennlp.common import util
from allennlp.data import tokenizers from allennlp.data import tokenizers
...@@ -29,19 +26,26 @@ class SemanticMultitaskPredictor(predictor.Predictor): ...@@ -29,19 +26,26 @@ class SemanticMultitaskPredictor(predictor.Predictor):
dataset_reader: allen_data.DatasetReader, dataset_reader: allen_data.DatasetReader,
tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None: tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None:
super().__init__(model, dataset_reader) super().__init__(model, dataset_reader)
self.batch_size = 1000
self.vocab = model.vocab self.vocab = model.vocab
self._dataset_reader.generate_labels = False self._dataset_reader.generate_labels = False
self._tokenizer = tokenizer self._tokenizer = tokenizer
@overrides @overrides
def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance: def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
tokens = self._tokenizer.tokenize(json_dict["sentence"]) sentence = json_dict["sentence"]
tree = self._sentence_to_tree([t.text for t in tokens]) if isinstance(sentence, str):
tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
elif isinstance(sentence, list):
tokens = sentence
else:
raise ValueError("Input must be either string or list of strings.")
tree = self._sentence_to_tree(tokens)
return self._dataset_reader.text_to_instance(tree) return self._dataset_reader.text_to_instance(tree)
@overrides @overrides
def load_line(self, line: str) -> common.JsonDict: def load_line(self, line: str) -> common.JsonDict:
return {"sentence": line.replace("\n", " ").strip()} return self._to_input_json(line.replace("\n", "").strip())
@overrides @overrides
def dump_line(self, outputs: common.JsonDict) -> str: def dump_line(self, outputs: common.JsonDict) -> str:
...@@ -52,35 +56,61 @@ class SemanticMultitaskPredictor(predictor.Predictor): ...@@ -52,35 +56,61 @@ class SemanticMultitaskPredictor(predictor.Predictor):
else: else:
return str(outputs["tree"]) + "\n" return str(outputs["tree"]) + "\n"
def predict(self, sentence: Union[str, List[str]]):
    """Predict a dependency tree for one sentence or for a batch of sentences.

    A single string is tokenized and parsed on its own; a list is treated as a
    batch of sentences and processed in chunks of ``self.batch_size`` to keep
    memory bounded.  Returns one ``data.Sentence`` (str input) or a list of
    them (list input); raises ``ValueError`` for any other input type.
    """
    if isinstance(sentence, str):
        return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
    if isinstance(sentence, list):
        results = []
        # Chunk the batch so very large inputs are predicted incrementally.
        for batch in util.lazy_groups_of(sentence, self.batch_size):
            batch_trees = self.predict_batch_json([self._to_input_json(s) for s in batch])
            results.extend(data.Sentence.from_json(tree) for tree in batch_trees)
        return results
    raise ValueError("Input must be either string or list of strings.")
def __call__(self, sentence: Union[str, List[str]]):
    """Make the predictor directly callable; alias for :meth:`predict`."""
    return self.predict(sentence)
@overrides
def predict_batch_instance(self, instances: List[allen_data.Instance]) -> List[common.JsonDict]:
    """Run the model on a batch of instances, returning one serialized tree per instance.

    Each prediction is converted back to a CoNLL-U tree via
    ``_predictions_as_tree`` and sanitized into a JSON-safe ordered dict.
    """
    predictions = super().predict_batch_instance(instances)
    return [
        collections.OrderedDict([
            ("tree", util.sanitize(self._predictions_as_tree(prediction, instance).serialize())),
        ])
        for prediction, instance in zip(predictions, instances)
    ]
@overrides @overrides
def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict: def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict:
start_time = time.time()
tree = self.predict_instance_as_tree(instance) tree = self.predict_instance_as_tree(instance)
tree_json = util.sanitize(tree.serialize()) tree_json = util.sanitize(tree.serialize())
result = collections.OrderedDict([ result = collections.OrderedDict([
("tree", tree_json), ("tree", tree_json),
]) ])
end_time = time.time()
logger.info(f"Took {(end_time - start_time) * 1000.0} ms")
return result return result
def predict(self, sentence: str): @overrides
return data.Sentence.from_json(self.predict_json({"sentence": sentence})) def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[common.JsonDict]:
trees = []
def __call__(self, sentence: str): instances = self._batch_json_to_instances(inputs)
return self.predict(sentence) predictions = super().predict_batch_json(inputs)
for prediction, instance in zip(predictions, instances):
tree_json = util.sanitize(self._predictions_as_tree(prediction, instance))
trees.append(collections.OrderedDict([
("tree", tree_json),
]))
return trees
@overrides @overrides
def predict_json(self, inputs: common.JsonDict) -> common.JsonDict: def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
start_time = time.time()
instance = self._json_to_instance(inputs) instance = self._json_to_instance(inputs)
tree = self.predict_instance_as_tree(instance) tree = self.predict_instance_as_tree(instance)
tree_json = util.sanitize(tree) tree_json = util.sanitize(tree)
result = collections.OrderedDict([ result = collections.OrderedDict([
("tree", tree_json), ("tree", tree_json),
]) ])
end_time = time.time()
logger.info(f"Took {(end_time - start_time) * 1000.0} ms")
return result return result
def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList: def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList:
...@@ -97,6 +127,10 @@ class SemanticMultitaskPredictor(predictor.Predictor): ...@@ -97,6 +127,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
metadata=collections.OrderedDict() metadata=collections.OrderedDict()
) )
@staticmethod
def _to_input_json(sentence: str):
return {"sentence": sentence}
def _predictions_as_tree(self, predictions, instance): def _predictions_as_tree(self, predictions, instance):
tree = instance.fields["metadata"]["input"] tree = instance.fields["metadata"]["input"]
field_names = instance.fields["metadata"]["field_names"] field_names = instance.fields["metadata"]["field_names"]
...@@ -165,12 +199,14 @@ class SemanticMultitaskPredictor(predictor.Predictor): ...@@ -165,12 +199,14 @@ class SemanticMultitaskPredictor(predictor.Predictor):
model_path = path model_path = path
else: else:
try: try:
logger.debug("Downloading model.")
model_path = download.download_file(path) model_path = download.download_file(path)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
raise e raise e
model = models.Model.from_archive(model_path) archive = models.load_archive(model_path)
model = archive.model
dataset_reader = allen_data.DatasetReader.from_params( dataset_reader = allen_data.DatasetReader.from_params(
models.load_archive(model_path).config["dataset_reader"]) archive.config["dataset_reader"])
return cls(model, dataset_reader, tokenizer) return cls(model, dataset_reader, tokenizer)
# sent_id = test-s1 # sent_id = test-s1
# text = Easy sentence. # text = Easy sentence.
1 Verylongwordwhichmustbetruncatedbythesystemto30 easy ADJ adj AdpType=Prep|Adp 1 amod _ _ 1 Verylongwordwhichmustbetruncatedbythesystemto30 easy ADJ adj AdpType=Prep|Adp 2 amod _ _
2 Sentence verylonglemmawhichmustbetruncatedbythesystemto30 NOUN nom Number=Sing 0 root _ _ 2 Sentence verylonglemmawhichmustbetruncatedbythesystemto30 NOUN nom Number=Sing 0 root _ _
3 . . PUNCT . _ 1 punct _ _ 3 . . PUNCT . _ 1 punct _ _
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment