diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 66122cf32038cc71b431c46134df74cac80ac354..8cc3df5c3039bee577217dc8f78eb5c9f21bcfe7 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -8,3 +8,4 @@ from .lemma import LemmatizerModel
 from .combo_model import ComboModel
 from .morpho import MorphologicalFeatures
 from .model import Model
+from .archival import *
\ No newline at end of file
diff --git a/combo/predict.py b/combo/predict.py
index aeef5cb3cc5310894bf02f59c3d1a8c610461b28..aa547e428004ff97ff893a7c91c81d8a7e3f6e6e 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -230,26 +230,22 @@ class COMBO(Predictor):
                              dataset_reader: DatasetReader):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
-    # @classmethod
-    # def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-    #                     batch_size: int = 1024,
-    #                     cuda_device: int = -1):
-    #     util.import_module_and_submodules("combo.commands")
-    #     util.import_module_and_submodules("combo.models")
-    #     util.import_module_and_submodules("combo.training")
-    #
-    #     if os.path.exists(path):
-    #         model_path = path
-    #     else:
-    #         try:
-    #             logger.debug("Downloading model.")
-    #             model_path = download.download_file(path)
-    #         except Exception as e:
-    #             logger.error(e)
-    #             raise e
-    #
-    #     archive = models.load_archive(model_path, cuda_device=cuda_device)
-    #     model = archive.model
-    #     dataset_reader = DatasetReader.from_params(
-    #         archive.config["dataset_reader"])
-    #     return cls(model, dataset_reader, tokenizer, batch_size)
+    @classmethod
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
+                        batch_size: int = 1024,
+                        cuda_device: int = -1):
+        if os.path.exists(path):
+            model_path = path
+        else:
+            try:
+                logger.debug("Downloading model.")
+                model_path = download.download_file(path)
+            except Exception as e:
+                logger.error(e)
+                raise e
+
+        archive = models.load_archive(model_path, cuda_device=cuda_device)
+        model = archive.model
+        dataset_reader = DatasetReader.from_params(
+            archive.config["dataset_reader"])
+        return cls(model, dataset_reader, tokenizer, batch_size)
diff --git a/requirements.txt b/requirements.txt
index 1a928d8172ef71919fef3ff5d25cbf55db360741..f7222b4321a04fe45bdf1f8ad96336675d5451b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,6 @@ importlib-resources~=5.12.0
 overrides~=7.3.1
 torch~=2.0.0
 torchtext~=0.15.1
-lambo~=2.0.0
 numpy~=1.24.1
 pytorch-lightning~=2.0.01
 requests~=2.28.2
diff --git a/tests/data/data_readers/test_conll.py b/tests/data/data_readers/test_conll.py
index 134d0a94236a8f1d8bb98eec34dd898f536a095f..f516a24abd554a5366f0417fe30a0a809a43f971 100644
--- a/tests/data/data_readers/test_conll.py
+++ b/tests/data/data_readers/test_conll.py
@@ -1,7 +1,6 @@
 import unittest
 
 from combo.data import ConllDatasetReader
-from torch.utils.data import DataLoader
 
 
 class ConllDatasetReaderTest(unittest.TestCase):
@@ -10,11 +9,6 @@ class ConllDatasetReaderTest(unittest.TestCase):
         tokens = [token for token in reader('conll_test_file.txt')]
         self.assertEqual(len(tokens), 6)
 
-    def test_read_all_tokens_data_loader(self):
-        reader = ConllDatasetReader(coding_scheme='IOB2')
-        loader = DataLoader(reader('conll_test_file.txt'), batch_size=16)
-        print(next(iter(loader)))
-
     def test_tokenize_correct_tokens(self):
         reader = ConllDatasetReader(coding_scheme='IOB2')
         token = next(iter(reader('conll_test_file.txt')))
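
Note: a minimal usage sketch of the re-enabled COMBO.from_pretrained, following the
usage pattern the COMBO docs describe. The model name "polish-herbert-base" is one
example; any local archive path or downloadable model name accepted by
download.download_file should work the same way.

    from combo.predict import COMBO

    # from_pretrained resolves `path`: an existing local path is used directly,
    # otherwise the model is downloaded, then the archive is loaded and a
    # predictor is built from the archived model and dataset reader config.
    nlp = COMBO.from_pretrained("polish-herbert-base", cuda_device=-1)
    sentence = nlp("This is an example sentence.")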