diff --git a/combo/data/api.py b/combo/data/api.py index 525f6d5f9e3e51d6b4ad35618dea7e8e8a0b3734..b0763b6b0871208db3e3268b9bf324e373a1e7d8 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -1,30 +1,65 @@ -from typing import Optional, List - +import collections from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any, Union, Tuple + +import conllu +from dataclasses_json import dataclass_json +from overrides import overrides +@dataclass_json @dataclass class Token: + id: Optional[Union[int, Tuple]] = None token: Optional[str] = None - id: Optional[int] = None lemma: Optional[str] = None upostag: Optional[str] = None xpostag: Optional[str] = None + feats: Optional[str] = None head: Optional[int] = None deprel: Optional[str] = None - feats: Optional[str] = None - - @classmethod - def from_json(cls, json): - return cls(**json) + deps: Optional[str] = None + misc: Optional[str] = None +@dataclass_json @dataclass class Sentence: tokens: List[Token] = field(default_factory=list) - embedding: List[float] = field(default_factory=list) + sentence_embedding: List[float] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict) + + +class _TokenList(conllu.TokenList): + + @overrides + def __repr__(self): + return 'TokenList<' + ', '.join(token['token'] for token in self) + '>' + + +def sentence2conllu(sentence: Sentence) -> conllu.TokenList: + tokens = [collections.OrderedDict(t.to_dict()) for t in sentence.tokens] + # Range tokens must be tuple not list, this is conllu library requirement + for t in tokens: + if type(t["id"]) == list: + t["id"] = tuple(t["id"]) + return _TokenList(tokens=tokens, + metadata=sentence.metadata) + + +def tokens2conllu(tokens: List[str]) -> conllu.TokenList: + return _TokenList( + [collections.OrderedDict({"id": idx, "token": token}) for + idx, token + in enumerate(tokens, start=1)], + metadata=collections.OrderedDict() + ) + - @classmethod - def from_json(cls, json): - return cls(tokens=[Token.from_json(t) for t in json["tree"]], - embedding=json.get("sentence_embedding", [])) +def conllu2sentence(conllu_sentence: conllu.TokenList, + sentence_embedding: List[float]) -> Sentence: + return Sentence( + tokens=[Token.from_dict(t) for t in conllu_sentence.tokens], + sentence_embedding=sentence_embedding, + metadata=conllu_sentence.metadata + ) diff --git a/combo/main.py b/combo/main.py index 918bb9407083e9babe3c53ee1671b412fba5fe74..4dc0056adbef82521c4058655bd32c9e14211aea 100644 --- a/combo/main.py +++ b/combo/main.py @@ -143,10 +143,13 @@ def run(_): file.writelines(predictor.predict_instance_as_tree(tree).serialize()) else: use_dataset_reader = ".conllu" in FLAGS.input_file.lower() + predictor = _get_predictor() + if use_dataset_reader: + predictor.line_to_conllu = True if FLAGS.silent: logging.getLogger("allennlp.common.params").disabled = True manager = allen_predict._PredictManager( - _get_predictor(), + predictor, FLAGS.input_file, FLAGS.output_file, FLAGS.batch_size, diff --git a/combo/predict.py b/combo/predict.py index 2e5693f75a635915d4e7ed4f7c063bcda21fa8e2..0ee80a91fcb8e47829428eaed3b1399909754119 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -1,7 +1,6 @@ -import collections import logging import os -from typing import List, Union +from typing import List, Union, Tuple import conllu from allennlp import data as allen_data, common, models @@ -11,6 +10,7 @@ from allennlp.predictors import predictor from overrides import overrides from combo import data +from combo.data import sentence2conllu, tokens2conllu, conllu2sentence from combo.utils import download logger = logging.getLogger(__name__) @@ -24,13 +24,15 @@ class SemanticMultitaskPredictor(predictor.Predictor): model: models.Model, dataset_reader: allen_data.DatasetReader, tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(), - batch_size: int = 500) -> None: + batch_size: int = 500, + line_to_conllu: bool = False) -> None: super().__init__(model, dataset_reader) self.batch_size = batch_size self.vocab = model.vocab self._dataset_reader.generate_labels = False self._dataset_reader.lazy = True self._tokenizer = tokenizer + self.line_to_conllu = line_to_conllu def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]): """Depending on the input uses (or ignores) tokenizer. @@ -48,7 +50,7 @@ class SemanticMultitaskPredictor(predictor.Predictor): def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]): if isinstance(sentence, str): - return data.Sentence.from_json(self.predict_json({"sentence": sentence})) + return data.Sentence.from_dict(self.predict_json({"sentence": sentence})) elif isinstance(sentence, list): if len(sentence) == 0: return [] @@ -56,15 +58,14 @@ class SemanticMultitaskPredictor(predictor.Predictor): if isinstance(example, str) or isinstance(example, list): sentences = [] for sentences_batch in util.lazy_groups_of(sentence, self.batch_size): - trees = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch]) - sentences.extend([data.Sentence.from_json(t) for t in trees]) + sentences_batch = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch]) + sentences.extend(sentences_batch) return sentences elif isinstance(example, data.Sentence): sentences = [] for sentences_batch in util.lazy_groups_of(sentence, self.batch_size): - trees = self.predict_batch_instance([self._to_input_instance(s) for s in sentences_batch], - serialize=False) - sentences.extend([data.Sentence.from_json(t) for t in trees]) + sentences_batch = self.predict_batch_instance([self._to_input_instance(s) for s in sentences_batch]) + sentences.extend(sentences_batch) return sentences else: raise ValueError("List must have either sentences as str, List[str] or Sentence object.") @@ -72,58 +73,38 @@ class SemanticMultitaskPredictor(predictor.Predictor): raise ValueError("Input must be either string or list of strings.") @overrides - def predict_batch_instance(self, instances: List[allen_data.Instance], serialize: bool = True - ) -> List[common.JsonDict]: - trees = [] + def predict_batch_instance(self, instances: List[allen_data.Instance]) -> List[data.Sentence]: + sentences = [] predictions = super().predict_batch_instance(instances) for prediction, instance in zip(predictions, instances): - tree = self._predictions_as_tree(prediction, instance) - if serialize: - tree = self._serialize(tree) - tree_json = util.sanitize(tree) - trees.append(collections.OrderedDict([ - ("tree", tree_json), - ])) - return trees + tree, sentence_embedding = self.predict_instance_as_tree(instance) + sentence = conllu2sentence(tree, sentence_embedding) + sentences.append(sentence) + return sentences @overrides - def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> common.JsonDict: - tree = self.predict_instance_as_tree(instance) - sentence_embedding = tree.metadata.get("sentence_embedding", []) - if serialize: - tree = self._serialize(tree) - tree_json = util.sanitize(tree) - result = collections.OrderedDict([ - ("tree", tree_json), - ("sentence_embedding", sentence_embedding) - ]) - return result - - @overrides - def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[common.JsonDict]: - trees = [] + def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[data.Sentence]: + sentences = [] instances = self._batch_json_to_instances(inputs) - predictions = self.predict_batch_instance(instances, serialize=False) + predictions = self.predict_batch_instance(instances) for prediction, instance in zip(predictions, instances): - tree = self._predictions_as_tree(prediction, instance) - tree_json = util.sanitize(tree) - trees.append(collections.OrderedDict([ - ("tree", tree_json), - ])) - return trees + tree, sentence_embedding = self.predict_instance_as_tree(instance) + sentence = conllu2sentence(tree, sentence_embedding) + sentences.append(sentence) + return sentences + + @overrides + def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> data.Sentence: + tree, sentence_embedding = self.predict_instance_as_tree(instance) + return conllu2sentence(tree, sentence_embedding) @overrides - def predict_json(self, inputs: common.JsonDict) -> common.JsonDict: + def predict_json(self, inputs: common.JsonDict) -> data.Sentence: instance = self._json_to_instance(inputs) - tree = self.predict_instance_as_tree(instance) - tree_json = util.sanitize(tree) - result = collections.OrderedDict([ - ("tree", tree_json), - ]) - result["sentence_embedding"] = tree.metadata.get("sentence_embedding", []) - return result + tree, sentence_embedding = self.predict_instance_as_tree(instance) + return conllu2sentence(tree, sentence_embedding) - def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList: + def predict_instance_as_tree(self, instance: allen_data.Instance) -> Tuple[conllu.TokenList, List[float]]: predictions = super().predict_instance(instance) return self._predictions_as_tree(predictions, instance) @@ -136,39 +117,27 @@ class SemanticMultitaskPredictor(predictor.Predictor): tokens = sentence else: raise ValueError("Input must be either string or list of strings.") - tree = self._sentence_to_tree(tokens) - return self._dataset_reader.text_to_instance(tree) + return self._dataset_reader.text_to_instance(tokens2conllu(tokens)) @overrides def load_line(self, line: str) -> common.JsonDict: return self._to_input_json(line.replace("\n", "").strip()) @overrides - def dump_line(self, outputs: common.JsonDict) -> str: + def dump_line(self, outputs: data.Sentence) -> str: # Check whether serialized (str) tree or token's list # Serialized tree has already separators between lines - if type(outputs["tree"]) == str: - return str(outputs["tree"]) + if self.line_to_conllu: + return sentence2conllu(outputs).serialize() else: - return str(outputs["tree"]) + "\n" - - @staticmethod - def _sentence_to_tree(sentence: List[str]): - d = collections.OrderedDict - return _TokenList( - [d({"id": idx, "token": token}) for - idx, token - in enumerate(sentence)], - metadata=collections.OrderedDict() - ) + return outputs.to_json() @staticmethod def _to_input_json(sentence: str): return {"sentence": sentence} def _to_input_instance(self, sentence: data.Sentence) -> allen_data.Instance: - tree = _TokenList([t.__dict__ for t in sentence.tokens]) - return self._dataset_reader.text_to_instance(tree) + return self._dataset_reader.text_to_instance(sentence2conllu(sentence)) def _predictions_as_tree(self, predictions, instance): tree = instance.fields["metadata"]["input"] @@ -219,15 +188,7 @@ class SemanticMultitaskPredictor(predictor.Predictor): else: raise NotImplementedError(f"Unknown field name {field_name}!") - if self._dataset_reader and "sent" in self._dataset_reader._targets: - tree.metadata["sentence_embedding"] = predictions["sentence_embedding"] - return tree - - @staticmethod - def _serialize(tree: conllu.TokenList): - if "sentence_embedding" in tree.metadata: - tree.metadata["sentence_embedding"] = str(tree.metadata["sentence_embedding"]) - return tree.serialize() + return tree, predictions["sentence_embedding"] @classmethod def with_spacy_tokenizer(cls, model: models.Model, @@ -257,10 +218,3 @@ class SemanticMultitaskPredictor(predictor.Predictor): dataset_reader = allen_data.DatasetReader.from_params( archive.config["dataset_reader"]) return cls(model, dataset_reader, tokenizer, batch_size) - - -class _TokenList(conllu.TokenList): - - @overrides - def __repr__(self): - return 'TokenList<' + ', '.join(token['token'] for token in self) + '>' diff --git a/setup.py b/setup.py index 03c533a39e4aab6fd5aadc1040642644c08f7995..d3aac1aba248e7bbb76c19e84ae89532c1aa9cf7 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ REQUIREMENTS = [ 'absl-py==0.9.0', 'allennlp==1.0.0', 'conllu==2.3.2', + 'dataclasses-json==0.5.2', 'joblib==0.14.1', 'jsonnet==0.15.0', 'requests==2.23.0', diff --git a/tests/test_predict.py b/tests/test_predict.py index e536561f9cfe68e5afa3d339d491230a73c63429..42bc4936ad627971fe7a07f9c405fead8590633f 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -30,9 +30,10 @@ class PredictionTest(unittest.TestCase): raw_sentence_collection = ["Test."] tokenized_sentence_collection = [["Test", "."]] wrapped_tokenized_sentence = [data.Sentence(tokens=[ - data.Token(id=0, token="Test"), - data.Token(id=1, token=".") + data.Token(id=1, token="Test"), + data.Token(id=2, token=".") ])] + api_wrapped_tokenized_sentence = [data.conllu2sentence(data.tokens2conllu(["Test", "."]), [])] nlp = predict.SemanticMultitaskPredictor.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz")) # when @@ -40,7 +41,8 @@ class PredictionTest(unittest.TestCase): nlp(raw_sentence), nlp(raw_sentence_collection)[0], nlp(tokenized_sentence_collection)[0], - nlp(wrapped_tokenized_sentence)[0] + nlp(wrapped_tokenized_sentence)[0], + nlp(api_wrapped_tokenized_sentence)[0] ] # then