diff --git a/combo/predict.py b/combo/predict.py
index 3e81eb539ec5737e345b108619afe5528da058c4..310f48fd5f8c78804344926b655df29b3a3d630b 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,7 +1,6 @@
 import collections
 import logging
 import os
-import time
 from typing import List, Union
 
 import conllu
@@ -24,68 +23,74 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def __init__(self,
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
-                 tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer()) -> None:
+                 tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
+                 batch_size: int = 500) -> None:
         super().__init__(model, dataset_reader)
-        self.batch_size = 1000
+        self.batch_size = batch_size
         self.vocab = model.vocab
         self._dataset_reader.generate_labels = False
         self._tokenizer = tokenizer
 
-    @overrides
-    def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
-        sentence = json_dict["sentence"]
-        if isinstance(sentence, str):
-            tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
-        elif isinstance(sentence, list):
-            tokens = sentence
-        else:
-            raise ValueError("Input must be either string or list of strings.")
-        tree = self._sentence_to_tree(tokens)
-        return self._dataset_reader.text_to_instance(tree)
+    def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
+        """Depending on the input, the tokenizer is either used or ignored.
+        When the model is not purely text-based, List[data.Sentence] is the only valid input.
 
-    @overrides
-    def load_line(self, line: str) -> common.JsonDict:
-        return self._to_input_json(line.replace("\n", "").strip())
+        * str - the tokenizer is used
+        * List[str] - the tokenizer is used on each string (treated as a list of raw sentences)
+        * List[List[str]] - the tokenizer is not used (treated as a list of tokenized sentences)
+        * List[data.Sentence] - the tokenizer is not used (treated as a list of tokenized sentences)
 
-    @overrides
-    def dump_line(self, outputs: common.JsonDict) -> str:
-        # Check whether serialized (str) tree or token's list
-        # Serialized tree has already separators between lines
-        if type(outputs["tree"]) == str:
-            return str(outputs["tree"])
-        else:
-            return str(outputs["tree"]) + "\n"
+        :param sentence: sentence(s) representation
+        :return: Sentence or List[Sentence], depending on the input
+        """
+        return self.predict(sentence)
 
-    def predict(self, sentence: Union[str, List[str]]):
+    def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
         if isinstance(sentence, str):
             return data.Sentence.from_json(self.predict_json({"sentence": sentence}))
         elif isinstance(sentence, list):
-            sentences = []
-            for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
-                trees = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
-                sentences.extend([data.Sentence.from_json(t) for t in trees])
-            return sentences
+            if len(sentence) == 0:
+                return []
+            example = sentence[0]
+            if isinstance(example, (str, list)):
+                sentences = []
+                for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
+                    trees = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
+                    sentences.extend([data.Sentence.from_json(t) for t in trees])
+                return sentences
+            elif isinstance(example, data.Sentence):
+                sentences = []
+                for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
+                    trees = self.predict_batch_instance([self._to_input_instance(s) for s in sentences_batch],
+                                                        serialize=False)
+                    sentences.extend([data.Sentence.from_json(t) for t in trees])
+                return sentences
+            else:
+                raise ValueError("List elements must be str, List[str], or data.Sentence objects.")
         else:
             raise ValueError("Input must be either string or list of strings.")
 
-    def __call__(self, sentence: Union[str, List[str]]):
-        return self.predict(sentence)
-
     @overrides
-    def predict_batch_instance(self, instances: List[allen_data.Instance]) -> List[common.JsonDict]:
+    def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True) -> List[
+        common.JsonDict]:
         trees = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
-            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance).serialize())
+            tree = self._predictions_as_tree(prediction, instance)
+            if serialize:
+                tree = tree.serialize()
+            tree_json = util.sanitize(tree)
             trees.append(collections.OrderedDict([
                 ("tree", tree_json),
             ]))
         return trees
 
     @overrides
-    def predict_instance(self, instance: allen_data.Instance) -> common.JsonDict:
+    def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> common.JsonDict:
         tree = self.predict_instance_as_tree(instance)
-        tree_json = util.sanitize(tree.serialize())
+        if serialize:
+            tree = tree.serialize()
+        tree_json = util.sanitize(tree)
         result = collections.OrderedDict([
             ("tree", tree_json),
         ])
@@ -95,9 +100,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[common.JsonDict]:
         trees = []
         instances = self._batch_json_to_instances(inputs)
-        predictions = super().predict_batch_json(inputs)
+        predictions = self.predict_batch_instance(instances, serialize=False)
         for prediction, instance in zip(predictions, instances):
-            tree_json = util.sanitize(self._predictions_as_tree(prediction, instance))
+            # predict_batch_instance has already converted predictions to sanitized trees
+            tree_json = prediction["tree"]
             trees.append(collections.OrderedDict([
                 ("tree", tree_json),
             ]))
@@ -117,10 +123,35 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         predictions = super().predict_instance(instance)
         return self._predictions_as_tree(predictions, instance)
 
+    @overrides
+    def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
+        sentence = json_dict["sentence"]
+        if isinstance(sentence, str):
+            tokens = [t.text for t in self._tokenizer.tokenize(json_dict["sentence"])]
+        elif isinstance(sentence, list):
+            tokens = sentence
+        else:
+            raise ValueError("Input must be either a string or a list of strings.")
+        tree = self._sentence_to_tree(tokens)
+        return self._dataset_reader.text_to_instance(tree)
+
+    @overrides
+    def load_line(self, line: str) -> common.JsonDict:
+        return self._to_input_json(line.replace("\n", "").strip())
+
+    @overrides
+    def dump_line(self, outputs: common.JsonDict) -> str:
+        # Check whether the tree is already serialized (str) or still a token list.
+        # A serialized tree already contains line separators.
+        if isinstance(outputs["tree"], str):
+            return str(outputs["tree"])
+        else:
+            return str(outputs["tree"]) + "\n"
+
     @staticmethod
     def _sentence_to_tree(sentence: List[str]):
         d = collections.OrderedDict
-        return conllu.TokenList(
+        return _TokenList(
             [d({"id": idx, "token": token})
              for idx, token
              in enumerate(sentence)],
@@ -131,6 +162,10 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def _to_input_json(sentence: str):
         return {"sentence": sentence}
 
+    def _to_input_instance(self, sentence: data.Sentence) -> allen_data.Instance:
+        tree = _TokenList([t.__dict__ for t in sentence.tokens])
+        return self._dataset_reader.text_to_instance(tree)
+
     def _predictions_as_tree(self, predictions, instance):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
@@ -190,7 +225,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
     @classmethod
-    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer()):
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), batch_size: int = 500):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
         util.import_module_and_submodules("combo.training")
@@ -209,4 +244,11 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         model = archive.model
         dataset_reader = allen_data.DatasetReader.from_params(
             archive.config["dataset_reader"])
-        return cls(model, dataset_reader, tokenizer)
+        return cls(model, dataset_reader, tokenizer, batch_size)
+
+
+class _TokenList(conllu.TokenList):
+
+    @overrides
+    def __repr__(self):
+        return 'TokenList<' + ', '.join(token['token'] for token in self) + '>'
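For reference, the reworked input handling above can be exercised as in the following minimal sketch. The archive path "model.tar.gz" is illustrative only; any trained COMBO model archive works, for example the tests/fixtures/model.tar.gz fixture added below.

    import combo.data as data
    import combo.predict as predict

    # Hypothetical archive path, for illustration only.
    nlp = predict.SemanticMultitaskPredictor.from_pretrained("model.tar.gz", batch_size=500)

    sentence = nlp("Test.")                      # str: the tokenizer is used
    sentences = nlp(["Test.", "Another one."])   # List[str]: the tokenizer is used per sentence
    sentences = nlp([["Test", "."]])             # List[List[str]]: already tokenized
    sentences = nlp([data.Sentence(tokens=[      # List[data.Sentence]: already tokenized
        data.Token(id=0, token="Test"),
        data.Token(id=1, token="."),
    ])])

Sentences are processed lazily in groups of batch_size, so large inputs do not have to fit in a single model batch.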
diff --git a/tests/fixtures/example.conllu b/tests/fixtures/example.conllu
index 039312e5ce6036083fe4992c51312149bff10547..1125392e17d71f9db09c5236e1fb27ac0968d410 100644
--- a/tests/fixtures/example.conllu
+++ b/tests/fixtures/example.conllu
@@ -3,3 +3,4 @@
 1	Verylongwordwhichmustbetruncatedbythesystemto30	easy	ADJ	adj	AdpType=Prep|Adp	2	amod	_	_
 2	Sentence	verylonglemmawhichmustbetruncatedbythesystemto30	NOUN	nom	Number=Sing	0	root	_	_
 3	.	.	PUNCT	.	_	1	punct	_	_
+
diff --git a/tests/fixtures/model.tar.gz b/tests/fixtures/model.tar.gz
new file mode 100644
index 0000000000000000000000000000000000000000..ca7cf263d1a2e7b85e4a4810a8d993712d53ea6c
Binary files /dev/null and b/tests/fixtures/model.tar.gz differ
diff --git a/tests/test_predict.py b/tests/test_predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..e536561f9cfe68e5afa3d339d491230a73c63429
--- /dev/null
+++ b/tests/test_predict.py
@@ -0,0 +1,48 @@
+import os
+import pathlib
+import shutil
+import unittest
+from unittest import mock
+
+import combo.data as data
+import combo.predict as predict
+
+
+class PredictionTest(unittest.TestCase):
+    PROJECT_ROOT = (pathlib.Path(__file__).parent / "..").resolve()
+    MODULE_ROOT = PROJECT_ROOT / "combo"
+    TESTS_ROOT = PROJECT_ROOT / "tests"
+    FIXTURES_ROOT = TESTS_ROOT / "fixtures"
+
+    def setUp(self) -> None:
+        def _cleanup_archive_dir_without_logging(path: str):
+            if os.path.exists(path):
+                shutil.rmtree(path)
+
+        self.patcher = mock.patch(
+            "allennlp.models.archival._cleanup_archive_dir", _cleanup_archive_dir_without_logging
+        )
+        self.mock_cleanup_archive_dir = self.patcher.start()
+        self.addCleanup(self.patcher.stop)
+
+    def test_predictions_are_equal_given_the_same_input_in_different_forms(self):
+        # given
+        raw_sentence = "Test."
+        raw_sentence_collection = ["Test."]
+        tokenized_sentence_collection = [["Test", "."]]
+        wrapped_tokenized_sentence = [data.Sentence(tokens=[
+            data.Token(id=0, token="Test"),
+            data.Token(id=1, token=".")
+        ])]
+        nlp = predict.SemanticMultitaskPredictor.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz"))
+
+        # when
+        results = [
+            nlp(raw_sentence),
+            nlp(raw_sentence_collection)[0],
+            nlp(tokenized_sentence_collection)[0],
+            nlp(wrapped_tokenized_sentence)[0]
+        ]
+
+        # then
+        self.assertTrue(all(x == results[0] for x in results))
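A usage note on the serialize flag added to predict_instance and predict_batch_instance: with serialize=True (the default) outputs["tree"] holds the serialized CoNLL-U string, while with serialize=False it holds the sanitized token list that predict() consumes internally. A minimal sketch, reusing the illustrative archive path from above (_to_input_instance is a private helper, shown here only to demonstrate the flag):

    import combo.data as data
    import combo.predict as predict

    nlp = predict.SemanticMultitaskPredictor.from_pretrained("model.tar.gz")  # hypothetical path
    instance = nlp._to_input_instance(data.Sentence(tokens=[
        data.Token(id=0, token="Test"),
        data.Token(id=1, token="."),
    ]))

    as_string = nlp.predict_instance(instance)["tree"]                   # CoNLL-U string
    as_tokens = nlp.predict_instance(instance, serialize=False)["tree"]  # sanitized token list

The new test exercises all four input forms against the bundled fixture archive and can be run from the repository root with, for example, python -m unittest tests.test_predict.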