diff --git a/README.md b/README.md
index 4ca38ce0e5701c105cc45e739b361195af5a93d6..aaf04f1bdc25e1ad3c138909a21b18ac62488cf8 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,28 @@ pip install -e .
 import combo
 ```
 
+Pretrained model usage:
+
+```python
+from combo.predict import COMBO
+c = COMBO.from_pretrained("model")
+prediction = c("To jest przykładowe zdanie.")
+
+print("{:15} {:15} {:10} {:10} {:10}".format('TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))
+for token in prediction.tokens:
+    print("{:15} {:15} {:10} {:10} {:10}".format(token.text, token.lemma, token.upostag, token.head, token.deprel))
+```
+
+Example output:
+```
+TOKEN           LEMMA           UPOS       HEAD       DEPREL
+To              to              AUX        4          cop
+jest            być             AUX        4          cop
+przykładowe     przykładowy     ADJ        4          amod
+zdanie          zdanie          NOUN       0          root
+.               .               PUNCT      4          punct
+```
+
 ## Use COMBO CLI
 
 The minimal training example (make sure to download some conllu training and validation files)
diff --git a/combo/modules/model.py b/combo/modules/model.py
index c87cc0ad71423d45e1b631d58c54d3df56ea235e..846545942b71d9af7c7e2ad3e678dcfe2e3bf0bb 100644
--- a/combo/modules/model.py
+++ b/combo/modules/model.py
@@ -356,7 +356,6 @@ class Model(Module, pl.LightningModule, FromParameters):
 
         model_params = config.get("model")
         model_params['parameters']['vocabulary'] = vocab_params
-        print(vocab_params)
 
         # The experiment config tells us how to _train_ a model, including where to get pre-trained
         # embeddings/weights from. We're now _loading_ the model, so those weights will already be
diff --git a/combo/predict.py b/combo/predict.py
index c97b26a9e78bb31e4fa6d16b1de358815e7344da..b1b52597dd5fac3880085fab80c4d8fadc9f4d1e 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from overrides import overrides
 
-from combo import data
+from combo import data, models, common
 from combo.common import util
 from combo.config import Registry
 from combo.config.from_parameters import register_arguments, resolve
@@ -18,8 +18,8 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader
 from combo.data.instance import JsonDict
 from combo.predictors import PredictorModule
 from combo.utils import download, graph
-from combo.modules.archival import load_archive
-from combo.modules.model import Model
+from combo.modules.model import Model
+from combo.default_model import default_ud_dataset_reader
 
 logger = logging.getLogger(__name__)
@@ -245,7 +245,7 @@
         return tree, predictions["sentence_embedding"], embeddings
 
     @classmethod
-    def with_spacy_tokenizer(cls, model: Model,
+    def with_spacy_tokenizer(cls, model: models.Model,
                              dataset_reader: DatasetReader):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
@@ -263,9 +263,7 @@
             logger.error(e)
             raise e
 
-        archive = load_archive(model_path, cuda_device=cuda_device)
+        archive = models.load_archive(model_path, cuda_device=cuda_device)
         model = archive.model
-        dataset_reader = resolve(
-            archive.config["dataset_reader"]
-        )
+        dataset_reader = default_ud_dataset_reader()
         return cls(model, dataset_reader, tokenizer, batch_size)
diff --git a/combo/utils/download.py b/combo/utils/download.py
index ff5ed9b5e8ef9424a2dbfb64c0ac90985e4a3d98..53fd7413dba8d91c86bb0fa5094c90015120c18b 100644
--- a/combo/utils/download.py
+++ b/combo/utils/download.py
@@ -9,15 +9,22 @@ from requests import adapters, exceptions
 
 logger = logging.getLogger(__name__)
 
-_URL = "http://s3.clarin-pl.eu/dspace/combo/prototype/{model}.tar.gz"
+DATA_TO_PATH = {
+    "enhanced": "iwpt_2020",
+    "iwpt2021": "iwpt_2021",
+    "ud25": "ud_25",
+    "ud27": "ud_27",
+    "ud29": "ud_29"}
+_PROTOTYPE_URL = "http://s3.clarin-pl.eu/dspace/combo/prototype/{model}.tar.gz"
+_URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz"
 _HOME_DIR = os.getenv("HOME", os.curdir)
 _CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo"))
 
 
 def download_file(model_name, force=False):
     _make_cache_dir()
-    url = _URL.format(model=model_name)
-    print('URL', url)
+    data = model_name.split("-")[-1]
+    url = _URL.format(model=model_name, data=DATA_TO_PATH[data])
     local_filename = url.split("/")[-1]
     location = os.path.join(_CACHE_DIR, local_filename)
     if os.path.exists(location) and not force: