From 05aab34446a2a86ecfdb1355bf1ea11551fb7840 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Thu, 17 Aug 2023 16:40:29 +0200
Subject: [PATCH] Re-enable COMBO.from_pretrained, drop lambo requirement and DataLoader-based test

---
 combo/models/__init__.py              |  1 +
 combo/predict.py                      | 42 ++++++++++++---------------
 requirements.txt                      |  1 -
 tests/data/data_readers/test_conll.py |  6 ----
 4 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index 66122cf..8cc3df5 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -8,3 +8,4 @@ from .lemma import LemmatizerModel
 from .combo_model import ComboModel
 from .morpho import MorphologicalFeatures
 from .model import Model
+from .archival import *
\ No newline at end of file
diff --git a/combo/predict.py b/combo/predict.py
index aeef5cb..aa547e4 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -230,26 +230,22 @@ class COMBO(Predictor):
                              dataset_reader: DatasetReader):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
 
-    # @classmethod
-    # def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-    #                     batch_size: int = 1024,
-    #                     cuda_device: int = -1):
-    #     util.import_module_and_submodules("combo.commands")
-    #     util.import_module_and_submodules("combo.models")
-    #     util.import_module_and_submodules("combo.training")
-    #
-    #     if os.path.exists(path):
-    #         model_path = path
-    #     else:
-    #         try:
-    #             logger.debug("Downloading model.")
-    #             model_path = download.download_file(path)
-    #         except Exception as e:
-    #             logger.error(e)
-    #             raise e
-    #
-    #     archive = models.load_archive(model_path, cuda_device=cuda_device)
-    #     model = archive.model
-    #     dataset_reader = DatasetReader.from_params(
-    #         archive.config["dataset_reader"])
-    #     return cls(model, dataset_reader, tokenizer, batch_size)
+    @classmethod
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
+                        batch_size: int = 1024,
+                        cuda_device: int = -1):
+        if os.path.exists(path):
+            model_path = path
+        else:
+            try:
+                logger.debug("Downloading model.")
+                model_path = download.download_file(path)
+            except Exception as e:
+                logger.error(e)
+                raise e
+
+        archive = models.load_archive(model_path, cuda_device=cuda_device)
+        model = archive.model
+        dataset_reader = DatasetReader.from_params(
+            archive.config["dataset_reader"])
+        return cls(model, dataset_reader, tokenizer, batch_size)
diff --git a/requirements.txt b/requirements.txt
index 1a928d8..f7222b4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,6 @@ importlib-resources~=5.12.0
 overrides~=7.3.1
 torch~=2.0.0
 torchtext~=0.15.1
-lambo~=2.0.0
 numpy~=1.24.1
 pytorch-lightning~=2.0.01
 requests~=2.28.2
diff --git a/tests/data/data_readers/test_conll.py b/tests/data/data_readers/test_conll.py
index 134d0a9..f516a24 100644
--- a/tests/data/data_readers/test_conll.py
+++ b/tests/data/data_readers/test_conll.py
@@ -1,7 +1,6 @@
 import unittest
 
 from combo.data import ConllDatasetReader
-from torch.utils.data import DataLoader
 
 
 class ConllDatasetReaderTest(unittest.TestCase):
@@ -10,11 +9,6 @@ class ConllDatasetReaderTest(unittest.TestCase):
         tokens = [token for token in reader('conll_test_file.txt')]
         self.assertEqual(len(tokens), 6)
 
-    def test_read_all_tokens_data_loader(self):
-        reader = ConllDatasetReader(coding_scheme='IOB2')
-        loader = DataLoader(reader('conll_test_file.txt'), batch_size=16)
-        print(next(iter(loader)))
-
     def test_tokenize_correct_tokens(self):
         reader = ConllDatasetReader(coding_scheme='IOB2')
         token = next(iter(reader('conll_test_file.txt')))
-- 
GitLab