From 971f61093535aa6731abbd5cb22be26ac3051e78 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Fri, 12 Jun 2020 14:16:41 +0200
Subject: [PATCH] Add batch prediction for JSON inputs and data instances.
---
combo/models/model.py | 2 +-
combo/predict.py | 76 ++++++++++++++++++++++++++---------
tests/fixtures/example.conllu | 2 +-
3 files changed, 58 insertions(+), 22 deletions(-)
diff --git a/combo/models/model.py b/combo/models/model.py
index 72fc661..77b43e3 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -94,7 +94,7 @@ class SemanticMultitaskModel(allen_models.Model):
sample_weights=sample_weights)
lemma_output = self._optional(self.lemmatizer,
(encoder_emb[:, 1:], sentence.get("char").get("token_characters")
- if sentence.get("char") else None),
+ if sentence.get("char") else None),
mask=word_mask[:, 1:],
labels=lemma.get("char").get("token_characters") if lemma else None,
sample_weights=sample_weights)
diff --git a/combo/predict.py b/combo/predict.py
index 261bd6c..3e81eb5 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,13 +1,10 @@
import collections
-import errno
import logging
import os
import time
-from typing import List
+from typing import List, Union
import conllu
-import requests
-import tqdm
from allennlp import data as allen_data, common, models
from allennlp.common import util
from allennlp.data import tokenizers
@@ -29,19 +26,26 @@ class SemanticMultitaskPredictor(predictor.Predictor):
dataset_reader: allen_data.DatasetReader,
tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None:
super().__init__(model, dataset_reader)
+ self.batch_size = 1000
self.vocab = model.vocab
self._dataset_reader.generate_labels = False
self._tokenizer = tokenizer
@overrides
def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
- tokens = self._tokenizer.tokenize(json_dict["sentence"])
- tree = self._sentence_to_tree([t.text for t in tokens])
+ sentence = json_dict["sentence"]
+ if isinstance(sentence, str):
+ tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
+ elif isinstance(sentence, list):
+ tokens = sentence
+ else:
+ raise ValueError("Input must be either string or list of strings.")
+ tree = self._sentence_to_tree(tokens)
return self._dataset_reader.text_to_instance(tree)
@overrides
def load_line(self, line: str) -> common.JsonDict:
- return {"sentence": line.replace("\n", " ").strip()}
+ return self._to_input_json(line.replace("\n", "").strip())
@overrides
def dump_line(self, outputs: common.JsonDict) -> str:
@@ -52,35 +56,61 @@ class SemanticMultitaskPredictor(predictor.Predictor):
else:
return str(outputs["tree"]) + "\n"
+ def predict(self, sentence: Union[str, List[str]]):
+ if isinstance(sentence, str):
+ return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
+ elif isinstance(sentence, list):
+ sentences = []
+ for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
+ trees = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
+ sentences.extend([data.Sentence.from_json(t) for t in trees])
+ return sentences
+ else:
+ raise ValueError("Input must be either string or list of strings.")
+
+ def __call__(self, sentence: Union[str, List[str]]):
+ return self.predict(sentence)
+
+ @overrides
+ def predict_batch_instance(self, instances: List[allen_data.Instance]) -> List[common.JsonDict]:
+ trees = []
+ predictions = super().predict_batch_instance(instances)
+ for prediction, instance in zip(predictions, instances):
+ tree_json = util.sanitize(self._predictions_as_tree(prediction, instance).serialize())
+ trees.append(collections.OrderedDict([
+ ("tree", tree_json),
+ ]))
+ return trees
+
@overrides
def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict:
- start_time = time.time()
tree = self.predict_instance_as_tree(instance)
tree_json = util.sanitize(tree.serialize())
result = collections.OrderedDict([
("tree", tree_json),
])
- end_time = time.time()
- logger.info(f"Took {(end_time - start_time) * 1000.0} ms")
return result
- def predict(self, sentence: str):
- return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
-
- def __call__(self, sentence: str):
- return self.predict(sentence)
+    @overrides
+    def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[common.JsonDict]:
+        trees = []
+        instances = self._batch_json_to_instances(inputs)
+        predictions = super().predict_batch_instance(instances)  # base impl; super().predict_batch_json would re-dispatch to our own predict_batch_instance override and double-wrap results
+        for prediction, instance in zip(predictions, instances):
+            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance))
+            trees.append(collections.OrderedDict([
+                ("tree", tree_json),
+            ]))
+        return trees
@overrides
def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
- start_time = time.time()
instance = self._json_to_instance(inputs)
tree = self.predict_instance_as_tree(instance)
tree_json = util.sanitize(tree)
result = collections.OrderedDict([
("tree", tree_json),
])
- end_time = time.time()
- logger.info(f"Took {(end_time - start_time) * 1000.0} ms")
return result
def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList:
@@ -97,6 +127,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
metadata=collections.OrderedDict()
)
+ @staticmethod
+ def _to_input_json(sentence: str):
+ return {"sentence": sentence}
+
def _predictions_as_tree(self, predictions, instance):
tree = instance.fields["metadata"]["input"]
field_names = instance.fields["metadata"]["field_names"]
@@ -165,12 +199,14 @@ class SemanticMultitaskPredictor(predictor.Predictor):
model_path = path
else:
try:
+ logger.debug("Downloading model.")
model_path = download.download_file(path)
except Exception as e:
logger.error(e)
raise e
- model = models.Model.from_archive(model_path)
+ archive = models.load_archive(model_path)
+ model = archive.model
dataset_reader = allen_data.DatasetReader.from_params(
- models.load_archive(model_path).config["dataset_reader"])
+ archive.config["dataset_reader"])
return cls(model, dataset_reader, tokenizer)
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
index b58f0f3..039312e 100644
--- a/tests/fixtures/example.conllu
+++ b/tests/fixtures/example.conllu
@@ -1,5 +1,5 @@
# sent_id = test-s1
# text = Easy sentence.
-1 Verylongwordwhichmustbetruncatedbythesystemto30 easy ADJ adj AdpType=Prep|Adp 1 amod _ _
+1 Verylongwordwhichmustbetruncatedbythesystemto30 easy ADJ adj AdpType=Prep|Adp 2 amod _ _
2 Sentence verylonglemmawhichmustbetruncatedbythesystemto30 NOUN nom Number=Sing 0 root _ _
3 . . PUNCT . _ 1 punct _ _
--
GitLab