Skip to content
Snippets Groups Projects
Commit 05aab344 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Minor fixes

parent b28ead01
1 merge request: !46 "Merge COMBO 3.0 into master"
......@@ -8,3 +8,4 @@ from .lemma import LemmatizerModel
from .combo_model import ComboModel
from .morpho import MorphologicalFeatures
from .model import Model
from .archival import *
\ No newline at end of file
......@@ -230,26 +230,22 @@ class COMBO(Predictor):
dataset_reader: DatasetReader):
return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
# @classmethod
# def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
# batch_size: int = 1024,
# cuda_device: int = -1):
# util.import_module_and_submodules("combo.commands")
# util.import_module_and_submodules("combo.models")
# util.import_module_and_submodules("combo.training")
#
# if os.path.exists(path):
# model_path = path
# else:
# try:
# logger.debug("Downloading model.")
# model_path = download.download_file(path)
# except Exception as e:
# logger.error(e)
# raise e
#
# archive = models.load_archive(model_path, cuda_device=cuda_device)
# model = archive.model
# dataset_reader = DatasetReader.from_params(
# archive.config["dataset_reader"])
# return cls(model, dataset_reader, tokenizer, batch_size)
@classmethod
def from_pretrained(cls, path: str, tokenizer=None,
                    batch_size: int = 1024,
                    cuda_device: int = -1):
    """Build a predictor from a serialized model archive.

    Args:
        path: Location of the archive. If nothing exists at this path on
            the local filesystem, the value is passed to
            ``download.download_file`` and the location it returns is
            used instead.
        tokenizer: Tokenizer handed to the constructor. When ``None``
            (the default), a new ``tokenizers.SpacyTokenizer()`` is
            created for this call.
        batch_size: Forwarded unchanged to the constructor.
        cuda_device: Forwarded to ``models.load_archive``.

    Returns:
        An instance of ``cls`` built from the archived model, a dataset
        reader created from the archive's ``"dataset_reader"`` config
        entry, the tokenizer and ``batch_size``.

    Raises:
        Exception: Any error raised by ``download.download_file`` is
            logged and re-raised unchanged.
    """
    # Create the default tokenizer per call rather than in the signature:
    # a default expression is evaluated once, at definition time, and the
    # resulting object would be shared by every caller.
    if tokenizer is None:
        tokenizer = tokenizers.SpacyTokenizer()

    if os.path.exists(path):
        model_path = path
    else:
        try:
            logger.debug("Downloading model.")
            model_path = download.download_file(path)
        except Exception as e:
            logger.error(e)
            # Bare ``raise`` re-raises the active exception as-is.
            raise

    archive = models.load_archive(model_path, cuda_device=cuda_device)
    model = archive.model
    dataset_reader = DatasetReader.from_params(
        archive.config["dataset_reader"])
    return cls(model, dataset_reader, tokenizer, batch_size)
......@@ -9,7 +9,6 @@ importlib-resources~=5.12.0
overrides~=7.3.1
torch~=2.0.0
torchtext~=0.15.1
lambo~=2.0.0
numpy~=1.24.1
pytorch-lightning~=2.0.01
requests~=2.28.2
......
import unittest
from combo.data import ConllDatasetReader
from torch.utils.data import DataLoader
class ConllDatasetReaderTest(unittest.TestCase):
......@@ -10,11 +9,6 @@ class ConllDatasetReaderTest(unittest.TestCase):
tokens = [token for token in reader('conll_test_file.txt')]
self.assertEqual(len(tokens), 6)
def test_read_all_tokens_data_loader(self):
    """Wrap the reader's output in a DataLoader and print the first batch."""
    conll_reader = ConllDatasetReader(coding_scheme='IOB2')
    batches = DataLoader(conll_reader('conll_test_file.txt'), batch_size=16)
    first_batch = next(iter(batches))
    print(first_batch)
def test_tokenize_correct_tokens(self):
reader = ConllDatasetReader(coding_scheme='IOB2')
token = next(iter(reader('conll_test_file.txt')))
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment