diff --git a/combo/models/model.py b/combo/models/model.py
index 72fc6617c65e381e6bdbf11a72a2afd792cf580f..77b43e3c1a95e09b15c310409af0090f097d47fa 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -94,7 +94,7 @@ class SemanticMultitaskModel(allen_models.Model):
                                       sample_weights=sample_weights)
         lemma_output = self._optional(self.lemmatizer,
                                       (encoder_emb[:, 1:], sentence.get("char").get("token_characters")
-                                      if sentence.get("char") else None),
+                                       if sentence.get("char") else None),
                                       mask=word_mask[:, 1:],
                                       labels=lemma.get("char").get("token_characters") if lemma else None,
                                       sample_weights=sample_weights)
diff --git a/combo/predict.py b/combo/predict.py
index 261bd6c0801ba21489d54f67db6316ce795a3a93..3e81eb539ec5737e345b108619afe5528da058c4 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,13 +1,10 @@
 import collections
-import errno
 import logging
 import os
 import time
-from typing import List
+from typing import List, Union
 
 import conllu
-import requests
-import tqdm
 from allennlp import data as allen_data, common, models
 from allennlp.common import util
 from allennlp.data import tokenizers
@@ -29,19 +26,26 @@ class SemanticMultitaskPredictor(predictor.Predictor):
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None:
         super().__init__(model, dataset_reader)
+        self.batch_size = 1000
         self.vocab = model.vocab
         self._dataset_reader.generate_labels = False
         self._tokenizer = tokenizer
 
     @overrides
     def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
-        tokens = self._tokenizer.tokenize(json_dict["sentence"])
-        tree = self._sentence_to_tree([t.text for t in tokens])
+        sentence = json_dict["sentence"]
+        if isinstance(sentence, str):
+            tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
+        elif isinstance(sentence, list):
+            tokens = sentence
+        else:
+            raise ValueError("Input must be either string or list of strings.")
+        tree = self._sentence_to_tree(tokens)
         return self._dataset_reader.text_to_instance(tree)
 
     @overrides
     def load_line(self, line: str) -> common.JsonDict:
-        return {"sentence": line.replace("\n", " ").strip()}
+        return self._to_input_json(line.replace("\n", "").strip())
 
     @overrides
     def dump_line(self, outputs: common.JsonDict) -> str:
@@ -52,35 +56,61 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         else:
             return str(outputs["tree"]) + "\n"
 
+    def predict(self, sentence: Union[str, List[str]]):
+        if isinstance(sentence, str):
+            return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
+        elif isinstance(sentence, list):
+            sentences = []
+            for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
+                trees = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
+                sentences.extend([data.Sentence.from_json(t) for t in trees])
+            return sentences
+        else:
+            raise ValueError("Input must be either string or list of strings.")
+
+    def __call__(self, sentence: Union[str, List[str]]):
+        return self.predict(sentence)
+
+    @overrides
+    def predict_batch_instance(self, instances: List[allen_data.Instance]) -> List[common.JsonDict]:
+        trees = []
+        predictions = super().predict_batch_instance(instances)
+        for prediction, instance in zip(predictions, instances):
+            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance).serialize())
+            trees.append(collections.OrderedDict([
+                ("tree", tree_json),
+            ]))
+        return trees
+
     @overrides
     def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict:
-        start_time = time.time()
         tree = self.predict_instance_as_tree(instance)
         tree_json = util.sanitize(tree.serialize())
         result = collections.OrderedDict([
             ("tree", tree_json),
         ])
-        end_time = time.time()
-        logger.info(f"Took {(end_time - start_time) * 1000.0} ms")
         return result
 
-    def predict(self, sentence: str):
-        return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
-
-    def __call__(self, sentence: str):
-        return self.predict(sentence)
+    @overrides
+    def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[common.JsonDict]:
+        trees = []
+        instances = self._batch_json_to_instances(inputs)
+        predictions = super().predict_batch_json(inputs)
+        for prediction, instance in zip(predictions, instances):
+            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance))
+            trees.append(collections.OrderedDict([
+                ("tree", tree_json),
+            ]))
+        return trees
 
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
-        start_time = time.time()
         instance = self._json_to_instance(inputs)
         tree = self.predict_instance_as_tree(instance)
         tree_json = util.sanitize(tree)
         result = collections.OrderedDict([
             ("tree", tree_json),
         ])
-        end_time = time.time()
-        logger.info(f"Took {(end_time - start_time) * 1000.0} ms")
         return result
 
     def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList:
@@ -97,6 +127,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
             metadata=collections.OrderedDict()
         )
 
+    @staticmethod
+    def _to_input_json(sentence: str):
+        return {"sentence": sentence}
+
     def _predictions_as_tree(self, predictions, instance):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
@@ -165,12 +199,14 @@ class SemanticMultitaskPredictor(predictor.Predictor):
             model_path = path
         else:
             try:
+                logger.debug("Downloading model.")
                 model_path = download.download_file(path)
             except Exception as e:
                 logger.error(e)
                 raise e
 
-        model = models.Model.from_archive(model_path)
+        archive = models.load_archive(model_path)
+        model = archive.model
         dataset_reader = allen_data.DatasetReader.from_params(
-            models.load_archive(model_path).config["dataset_reader"])
+            archive.config["dataset_reader"])
         return cls(model, dataset_reader, tokenizer)
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
index b58f0f33ebe554d7febcc8019544f49a59cd58cc..039312e5ce6036083fe4992c51312149bff10547 100644
--- a/tests/fixtures/example.conllu
+++ b/tests/fixtures/example.conllu
@@ -1,5 +1,5 @@
 # sent_id = test-s1
 # text = Easy sentence.
-1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	1	amod	_	_
+1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
 2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
 3	.	.	PUNCT	.	_	1	punct	_	_
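Usage note (not part of the diff): with these changes, predict/__call__ accept either a single sentence or a list of sentences, and lists are routed through predict_batch_json in groups of self.batch_size (1000). A minimal sketch follows; the loader name and model path are placeholders, since only the body of the loading classmethod is visible in the last predict.py hunk:

    from combo import predict

    # Hypothetical loader name and model path -- only the body of the loading
    # classmethod appears in this diff.
    nlp = predict.SemanticMultitaskPredictor.from_pretrained("model.tar.gz")

    # A single string returns one data.Sentence.
    sentence = nlp("Easy sentence.")

    # A list of strings returns a list of data.Sentence objects,
    # predicted in batches of nlp.batch_size via predict_batch_json.
    sentences = nlp(["Easy sentence.", "Another sentence."])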