diff --git a/README.md b/README.md index 19847b8e5df935feb36bb9e5ec1e89cc3f1d35ed..c339bda6407f5e60609a75974125cd2276d5e794 100644 --- a/README.md +++ b/README.md @@ -18,9 +18,9 @@ python setup.py develop ``` Run the following lines in your Python console to make predictions with a pre-trained model: ```python -import combo.predict as predict +from combo.predict import COMBO -nlp = predict.SemanticMultitaskPredictor.from_pretrained("polish-herbert-base") +nlp = COMBO.from_pretrained("polish-herbert-base") sentence = nlp("Moje zdanie.") print(sentence.tokens) ``` diff --git a/combo/data/dataset.py b/combo/data/dataset.py index bb56ac33478fb36b5fc7736f79c6b6d68b4a2f59..48b68b14e592dc6f98e3f62e8b5c3bd23899cb4c 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -119,8 +119,8 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): label_namespace=target_name + "_labels") elif target_name == "deps": # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField). - text_field_deps = copy.deepcopy(text_field) - text_field_deps.tokens.insert(0, _Token("ROOT")) + text_field_deps = allen_fields.TextField([_Token("ROOT")] + copy.deepcopy(tokens), + self._token_indexers) enhanced_heads: List[Tuple[int, int]] = [] enhanced_deprels: List[str] = [] for idx, t in enumerate(tree_tokens): diff --git a/combo/main.py b/combo/main.py index 374af69bdbc4584d3425c2814ce4873e74233004..17c960ac7caa84513692841abb989955c7925721 100644 --- a/combo/main.py +++ b/combo/main.py @@ -136,7 +136,7 @@ def run(_): params = common.Params.from_file(FLAGS.config_path, ext_vars=_get_ext_vars())["dataset_reader"] params.pop("type") dataset_reader = dataset.UniversalDependenciesDatasetReader.from_params(params) - predictor = predict.SemanticMultitaskPredictor( + predictor = predict.COMBO( model=model, dataset_reader=dataset_reader ) diff --git a/combo/models/model.py b/combo/models/model.py index 710f72cf8dae932fc8f2b5c92abbaf7639a52ec2..ad0df0e8be53d6470581fdda9f73279ea25327ed 100644 --- a/combo/models/model.py +++ b/combo/models/model.py @@ -126,10 +126,12 @@ class SemanticMultitaskModel(allen_models.Model): "deprel": relations_pred, "enhanced_head": enhanced_head_pred, "enhanced_deprel": enhanced_relations_pred, - "enhanced_deprel_prob": enhanced_parser_output["rel_probability"], "sentence_embedding": torch.max(encoder_emb[:, 1:], dim=1)[0], } + if "rel_probability" in enhanced_parser_output: + output["enhanced_deprel_prob"] = enhanced_parser_output["rel_probability"] + if self._has_labels([upostag, xpostag, lemma, feats, head, deprel, semrel]): # Feats mapping diff --git a/combo/predict.py b/combo/predict.py index c58db250c792ddd1562ab2d7aa4908af52c54191..8d3e2f93106f9dbf89435bf056b0146a8892d929 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) @predictor.Predictor.register("semantic-multitask-predictor") @predictor.Predictor.register("semantic-multitask-predictor-spacy", constructor="with_spacy_tokenizer") -class SemanticMultitaskPredictor(predictor.Predictor): +class COMBO(predictor.Predictor): def __init__(self, model: models.Model, diff --git a/config.graph.template.jsonnet b/config.graph.template.jsonnet index 6975aba449c53b8de135b4cad579df5ad259034c..bc8c46580f17f22924a9f68628d64ce7f1060d55 100644 --- a/config.graph.template.jsonnet +++ b/config.graph.template.jsonnet @@ -204,6 +204,8 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't type: "transformers_word_embeddings", model_name: pretrained_transformer_name, projection_dim: projected_embedding_dim, + tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert") + then {use_fast: false} else {}, } else { type: "embeddings_projected", embedding_dim: embedding_dim, diff --git a/docs/models.md b/docs/models.md index 485f7614cc0b65c6413f88f0139a7dd7cd8a1711..d4346ff2de196d2ad295eed98cfd218aaf071636 100644 --- a/docs/models.md +++ b/docs/models.md @@ -5,9 +5,9 @@ Pre-trained models are available [here](http://mozart.ipipan.waw.pl/~mklimaszews ## Automatic download Python `from_pretrained` method will download the pre-trained model if the provided name (without the extension .tar.gz) matches one of the names in [here](http://mozart.ipipan.waw.pl/~mklimaszewski/models/). ```python -import combo.predict as predict +from combo.predict import COMBO -nlp = predict.SemanticMultitaskPredictor.from_pretrained("polish-herbert-base") +nlp = COMBO.from_pretrained("polish-herbert-base") ``` Otherwise it looks for a model in local env. diff --git a/docs/prediction.md b/docs/prediction.md index 89cc74c27e8de8e4fafb44c60aea8ed260b67a3d..6de5d0e1892389ba5cd18c25b88947db3f717074 100644 --- a/docs/prediction.md +++ b/docs/prediction.md @@ -32,9 +32,19 @@ Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name ## Python ```python -import combo.predict as predict +from combo.predict import COMBO model_path = "your_model.tar.gz" -nlp = predict.SemanticMultitaskPredictor.from_pretrained(model_path) +nlp = COMBO.from_pretrained(model_path) sentence = nlp("Sentence to parse.") ``` + +Using your own tokenization: +```python +from combo.predict import COMBO + +model_path = "your_model.tar.gz" +nlp = COMBO.from_pretrained(model_path) +tokenized_sentence = ["Sentence", "to", "parse", "."] +sentence = nlp([tokenized_sentence]) +``` diff --git a/tests/test_predict.py b/tests/test_predict.py index 2a56bd9baff30a6d34d3ad6bb16ce4ccdea71792..332ced3cfa010723fa51b77ce1742b8d976e1025 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -22,7 +22,7 @@ class PredictionTest(unittest.TestCase): data.Token(id=2, token=".") ])] api_wrapped_tokenized_sentence = [data.conllu2sentence(data.tokens2conllu(["Test", "."]), [])] - nlp = predict.SemanticMultitaskPredictor.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz")) + nlp = predict.COMBO.from_pretrained(os.path.join(self.FIXTURES_ROOT, "model.tar.gz")) # when results = [