From 05aab34446a2a86ecfdb1355bf1ea11551fb7840 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Thu, 17 Aug 2023 16:40:29 +0200
Subject: [PATCH] Minor fixes

---
 combo/models/__init__.py              |  1 +
 combo/predict.py                      | 42 ++++++++++++---------------
 requirements.txt                      |  1 -
 tests/data/data_readers/test_conll.py |  6 ----
 4 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 66122cf..8cc3df5 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -8,3 +8,4 @@ from .lemma import LemmatizerModel
 from .combo_model import ComboModel
 from .morpho import MorphologicalFeatures
 from .model import Model
+from .archival import *
\ No newline at end of file
diff --git a/combo/predict.py b/combo/predict.py
index aeef5cb..aa547e4 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -230,26 +230,22 @@ class COMBO(Predictor):
                          dataset_reader: DatasetReader):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
-    # @classmethod
-    # def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-    #                     batch_size: int = 1024,
-    #                     cuda_device: int = -1):
-    #     util.import_module_and_submodules("combo.commands")
-    #     util.import_module_and_submodules("combo.models")
-    #     util.import_module_and_submodules("combo.training")
-    #
-    #     if os.path.exists(path):
-    #         model_path = path
-    #     else:
-    #         try:
-    #             logger.debug("Downloading model.")
-    #             model_path = download.download_file(path)
-    #         except Exception as e:
-    #             logger.error(e)
-    #             raise e
-    #
-    #     archive = models.load_archive(model_path, cuda_device=cuda_device)
-    #     model = archive.model
-    #     dataset_reader = DatasetReader.from_params(
-    #         archive.config["dataset_reader"])
-    #     return cls(model, dataset_reader, tokenizer, batch_size)
+    @classmethod
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
+                        batch_size: int = 1024,
+                        cuda_device: int = -1):
+        if os.path.exists(path):
+            model_path = path
+        else:
+            try:
+                logger.debug("Downloading model.")
+                model_path = download.download_file(path)
+            except Exception as e:
+                logger.error(e)
+                raise e
+
+        archive = models.load_archive(model_path, cuda_device=cuda_device)
+        model = archive.model
+        dataset_reader = DatasetReader.from_params(
+            archive.config["dataset_reader"])
+        return cls(model, dataset_reader, tokenizer, batch_size)
diff --git a/requirements.txt b/requirements.txt
index 1a928d8..f7222b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,6 @@ importlib-resources~=5.12.0
 overrides~=7.3.1
 torch~=2.0.0
 torchtext~=0.15.1
-lambo~=2.0.0
 numpy~=1.24.1
 pytorch-lightning~=2.0.01
 requests~=2.28.2
diff --git a/tests/data/data_readers/test_conll.py b/tests/data/data_readers/test_conll.py
index 134d0a9..f516a24 100644
--- a/tests/data/data_readers/test_conll.py
+++ b/tests/data/data_readers/test_conll.py
@@ -1,7 +1,6 @@
 import unittest
 
 from combo.data import ConllDatasetReader
-from torch.utils.data import DataLoader
 
 
 class ConllDatasetReaderTest(unittest.TestCase):
@@ -10,11 +9,6 @@ class ConllDatasetReaderTest(unittest.TestCase):
         tokens = [token for token in reader('conll_test_file.txt')]
         self.assertEqual(len(tokens), 6)
 
-    def test_read_all_tokens_data_loader(self):
-        reader = ConllDatasetReader(coding_scheme='IOB2')
-        loader = DataLoader(reader('conll_test_file.txt'), batch_size=16)
-        print(next(iter(loader)))
-
     def test_tokenize_correct_tokens(self):
         reader = ConllDatasetReader(coding_scheme='IOB2')
         token = next(iter(reader('conll_test_file.txt')))
-- 
GitLab