Skip to content
Snippets Groups Projects
Commit 28eee417 authored by Mateusz Klimaszewski
Browse files

Add Python API wrappers for sentences and tokens.

parent 3f7e1a00
No related branches found
No related tags found
No related merge requests found
......@@ -67,7 +67,7 @@ Works for models where input was text-based only.
Interactive testing in the console (load the model, then type a sentence):
```bash
combo --mode predict --model_path your_model_tar_gz --input_file "-"
combo --mode predict --model_path your_model_tar_gz --input_file "-" --nosilent
```
### Raw text
Works for models where input was text-based only.
......@@ -91,7 +91,7 @@ import combo.predict as predict
model_path = "your_model.tar.gz"
nlp = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
parsed_tree = nlp("Sentence to parse.")["tree"]
sentence = nlp("Sentence to parse.")
```
## Configuration
......
from .samplers import TokenCountBatchSampler
from .token_indexers import TokenCharactersIndexer
from .api import *
from typing import Optional, List
from dataclasses import dataclass, field
@dataclass
class Token:
    """Plain-data wrapper for one token and its annotations.

    Every field defaults to ``None``, so an instance can be created from a
    mapping that carries only some of the fields.
    """

    # NOTE(review): field names presumably mirror the keys of the per-token
    # dictionaries passed to ``from_json`` -- confirm against the model output.
    token: Optional[str] = None
    id: Optional[int] = None
    lemma: Optional[str] = None
    upostag: Optional[str] = None
    xpostag: Optional[str] = None
    head: Optional[int] = None
    deprel: Optional[str] = None
    feats: Optional[str] = None

    @classmethod
    def from_json(cls, json):
        """Build a token from a mapping of field names to values.

        Args:
            json: Mapping unpacked as keyword arguments of the constructor;
                its keys must be field names of this class.

        Returns:
            A new instance of ``cls`` with the given fields set and the
            remaining ones left at ``None``.

        Raises:
            TypeError: If ``json`` contains a key that is not a field.
        """
        # Use ``cls`` instead of a hard-coded ``Token`` so that subclasses
        # calling ``from_json`` receive an instance of their own type.
        return cls(**json)
@dataclass
class Sentence:
    """Plain-data wrapper for a sentence: its tokens and an embedding.

    Both fields default to fresh empty lists.
    """

    tokens: List[Token] = field(default_factory=list)
    embedding: List[float] = field(default_factory=list)

    @classmethod
    def from_json(cls, json):
        """Build a sentence from a prediction dictionary.

        Args:
            json: Mapping with a ``'tree'`` entry (an iterable of per-token
                mappings accepted by ``Token.from_json``) and, optionally, a
                ``'sentence_embedding'`` entry.

        Returns:
            A new instance of ``cls``. ``embedding`` is an empty list when
            ``'sentence_embedding'`` is absent.

        Raises:
            KeyError: If ``json`` has no ``'tree'`` entry.
        """
        parsed_tokens = [Token.from_json(entry) for entry in json['tree']]
        sentence_embedding = json.get('sentence_embedding', [])
        # Use ``cls`` instead of a hard-coded ``Sentence`` so that subclasses
        # calling ``from_json`` receive an instance of their own type.
        return cls(tokens=parsed_tokens, embedding=sentence_embedding)
......@@ -86,7 +86,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
@overrides
def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
fields_: Dict[str, allen_data.Field] = {}
tokens = [Token(t['token'],
tokens = [_Token(t['token'],
pos_=t.get('upostag'),
tag_=t.get('xpostag'),
lemma_=t.get('lemma'),
......@@ -233,5 +233,5 @@ def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
@dataclass
class Token(allen_data.Token):
class _Token(allen_data.Token):
feats_: Optional[str] = None
......@@ -10,6 +10,8 @@ from allennlp.data import tokenizers
from allennlp.predictors import predictor
from overrides import overrides
from combo import data
logger = logging.getLogger(__name__)
......@@ -58,7 +60,7 @@ class SemanticMultitaskPredictor(predictor.Predictor):
return result
def predict(self, sentence: str):
return self.predict_json({'sentence': sentence})
return data.Sentence.from_json(self.predict_json({'sentence': sentence}))
def __call__(self, sentence: str):
return self.predict(sentence)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment