diff --git a/README.md b/README.md
index 9b88650adcda7a8d7635e5b058e4320286e42109..f25edd21049195dc60f6a66de5fad3d5e339744b 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,7 @@ Works for models where input was text-based only.
 Interactive testing in console (load model and just type sentence in console).
 ```bash
-combo --mode predict --model_path your_model_tar_gz --input_file "-"
+combo --mode predict --model_path your_model_tar_gz --input_file "-" --nosilent
 ```
 
 ### Raw text
 Works for models where input was text-based only.
@@ -91,7 +91,7 @@ import combo.predict as predict
 
 model_path = "your_model.tar.gz"
 nlp = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
-parsed_tree = nlp("Sentence to parse.")["tree"]
+sentence = nlp("Sentence to parse.")
 ```
 
 ## Configuration
diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index 7abffa2b3206b29c93e14f86ae76101a6a859cf3..f5973ab11e1a74eeca4f5239ba537538b1e200f1 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -1,2 +1,3 @@
 from .samplers import TokenCountBatchSampler
 from .token_indexers import TokenCharactersIndexer
+from .api import *
diff --git a/combo/data/api.py b/combo/data/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..62313b38c0ff2bbfe14ea08dfd26d0dd8ffa45ef
--- /dev/null
+++ b/combo/data/api.py
@@ -0,0 +1,31 @@
+from typing import Optional, List
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class Token:
+    token: Optional[str] = None
+    id: Optional[int] = None
+    lemma: Optional[str] = None
+    upostag: Optional[str] = None
+    xpostag: Optional[str] = None
+    head: Optional[int] = None
+    deprel: Optional[str] = None
+    feats: Optional[str] = None
+
+    @staticmethod
+    def from_json(json):
+        return Token(**json)
+
+
+@dataclass
+class Sentence:
+    tokens: List[Token] = field(default_factory=list)
+    embedding: List[float] = field(default_factory=list)
+
+    @staticmethod
+    def from_json(json):
+        return Sentence(tokens=[Token.from_json(t) for t in json['tree']],
+                        embedding=json.get('sentence_embedding', []))
+
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index ab5ce2aecb974a37a62248304281dadf4da8c9f2..d381569cea273dff43dd4ddb37e533b43878b1bb 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -86,7 +86,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
     @overrides
     def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
         fields_: Dict[str, allen_data.Field] = {}
-        tokens = [Token(t['token'],
+        tokens = [_Token(t['token'],
                         pos_=t.get('upostag'),
                         tag_=t.get('xpostag'),
                         lemma_=t.get('lemma'),
@@ -233,5 +233,5 @@ def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
 
 
 @dataclass
-class Token(allen_data.Token):
+class _Token(allen_data.Token):
     feats_: Optional[str] = None
diff --git a/combo/predict.py b/combo/predict.py
index b30da436572f93edde9548dd40f5acff946c0fd5..e346ab9f7e7ef94bdf52bf5a22999e9788c96727 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -10,6 +10,8 @@ from allennlp.data import tokenizers
 from allennlp.predictors import predictor
 from overrides import overrides
 
+from combo import data
+
 logger = logging.getLogger(__name__)
 
 
@@ -58,7 +60,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         return result
 
     def predict(self, sentence: str):
-        return self.predict_json({'sentence': sentence})
+        return data.Sentence.from_json(self.predict_json({'sentence': sentence}))
 
     def __call__(self, sentence: str):
         return self.predict(sentence)
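
For reviewers, a minimal sketch (not part of the diff) of how the new `data.Sentence` / `data.Token` API introduced above could be consumed. It assumes a trained COMBO model archive at `your_model.tar.gz`; the attribute names follow `combo/data/api.py`:

```python
# Sketch only, assuming a trained model archive is available locally.
import combo.predict as predict

nlp = predict.SemanticMultitaskPredictor.from_pretrained("your_model.tar.gz")
sentence = nlp("Sentence to parse.")  # now a data.Sentence instead of a raw JSON dict

for token in sentence.tokens:
    # CoNLL-U-style attributes; a field may be None if the model did not predict it.
    print(token.id, token.token, token.lemma, token.upostag, token.head, token.deprel)

print(len(sentence.embedding))  # sentence embedding, empty list when not returned
```

Returning a dataclass instead of the raw `{'tree': ..., 'sentence_embedding': ...}` dict is what drives the README change from `parsed_tree = nlp(...)["tree"]` to `sentence = nlp(...)`.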