diff --git a/combo/combo_model.py b/combo/combo_model.py
index d6ad22ef455558476e36f4f745fdcafb0a1b2ed9..5372992e106e5e16df5b2d0087fdd497ed2c3453 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -18,9 +18,9 @@ from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from combo.nn import RegularizerApplicator, base
 from combo.nn.utils import get_text_field_mask
 from combo.utils import metrics
-from data import Instance
-from data.batch import Batch
-from data.dataset_loaders.dataset_loader import TensorDict
+from combo.data import Instance
+from combo.data.batch import Batch
+from combo.data.dataset_loaders.dataset_loader import TensorDict
 
 from combo.nn import utils
 
diff --git a/combo/data/tokenizers/sentence_splitter.py b/combo/data/tokenizers/sentence_splitter.py
index 0313140f242ce7db88e60e1ce97349b9142f8a1c..c0383c994e992e678b905d0fb7aa034ba8cee9d3 100644
--- a/combo/data/tokenizers/sentence_splitter.py
+++ b/combo/data/tokenizers/sentence_splitter.py
@@ -7,11 +7,11 @@ from typing import List, Dict, Any
 import spacy
 
 from combo.config import Registry
-from combo.config.from_parameters import register_arguments
+from combo.config.from_parameters import register_arguments, FromParameters
 from combo.utils.spacy import get_spacy_model
 
 
-class SentenceSplitter:
+class SentenceSplitter(FromParameters):
     """
     A `SentenceSplitter` splits strings into sentences.
     """
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index 73cb5411def457728b939e3c9b99331bd14a709d..9f711aacc4c2fd56d3267e2b8e7f242c6f9b909b 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -628,7 +628,7 @@ class Vocabulary(FromParameters):
             return self._vocab[namespace].get_stoi()[token]
         except KeyError:
             try:
-                return self._vocab[namespace].get_stoi()[token][self._oov_token]
+                return self._vocab[namespace].get_stoi()[self._oov_token]
             except KeyError:
                 raise KeyError("Namespace %s doesn't contain token %s or default OOV token %s"
                                % (namespace, repr(token), repr(self._oov_token)))
diff --git a/combo/default_model.py b/combo/default_model.py
index febd4fbcf3a73915325ea9ef61a35f63c61f3a1e..927966f175f6e9df47504a479a2d12b95f6239ef 100644
--- a/combo/default_model.py
+++ b/combo/default_model.py
@@ -10,7 +10,7 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \
     TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer
 from combo.data.tokenizers import CharacterTokenizer
 from combo.data.vocabulary import Vocabulary
-from combo_model import ComboModel
+from combo.combo_model import ComboModel
 from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM
 from combo.modules.dilated_cnn import DilatedCnnEncoder
 from combo.modules.lemma import LemmatizerModel
@@ -22,7 +22,7 @@ from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation
 from combo.modules import FeedForwardPredictor
 from combo.nn.base import Linear
 from combo.nn.regularizers.regularizers import L2Regularizer
-from nn import RegularizerApplicator
+from combo.nn import RegularizerApplicator
 
 
 def default_character_indexer(namespace=None,
diff --git a/combo/modules/lemma.py b/combo/modules/lemma.py
index 3f7fd24bd2c1b4a2097ed121f2e995e4ebd51871..c4c0495da97fe18f369b108312eb8ebc454293d4 100644
--- a/combo/modules/lemma.py
+++ b/combo/modules/lemma.py
@@ -13,7 +13,7 @@ from combo.nn.base import Predictor
 from combo.nn.activations import Activation
 from combo.nn.utils import masked_cross_entropy
 from combo.utils import ConfigurationError
-from models.base import TimeDistributed
+from combo.models.base import TimeDistributed
 
 
 @Registry.register('combo_lemma_predictor_from_vocab')
diff --git a/combo/predict.py b/combo/predict.py
index b1b52597dd5fac3880085fab80c4d8fadc9f4d1e..72e97f671ef2a999fee08a438722db954fc4bf0d 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,25 +1,24 @@
 import logging
 import os
 import sys
-from typing import List, Union, Dict, Any, Optional, Type
+from typing import List, Union, Dict, Any
 
 import numpy as np
 import torch
 from overrides import overrides
 
-from combo import data, models, common
+from combo import data, models
 from combo.common import util
 from combo.config import Registry
-from combo.config.from_parameters import register_arguments, resolve
+from combo.config.from_parameters import register_arguments
 from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
 from combo.data.dataset_loaders.dataset_loader import TensorDict
 from combo.data.dataset_readers.dataset_reader import DatasetReader
-
 from combo.data.instance import JsonDict
+from combo.default_model import default_ud_dataset_reader
 from combo.predictors import PredictorModule
 from combo.utils import download, graph
 from modules.model import Model
-from combo.default_model import default_ud_dataset_reader
 
 
 logger = logging.getLogger(__name__)
diff --git a/tests/config/test_archive.py b/tests/config/test_archive.py
index b34bfd1fdb18759049d9d0cc38108359b56006f5..f8bae7a29c9687c6ff78a3dfff5d2638aac87ffc 100644
--- a/tests/config/test_archive.py
+++ b/tests/config/test_archive.py
@@ -1,16 +1,18 @@
+import os
 import unittest
 from tempfile import TemporaryDirectory
-from combo.default_model import default_model, default_ud_dataset_reader, default_data_loader, default_vocabulary
-from data.vocabulary import Vocabulary
-from combo_model import ComboModel
-from modules import archive
-import os
+
+from combo.combo_model import ComboModel
+from combo.data.vocabulary import Vocabulary
+from combo.default_model import default_model
+from combo.modules import archive
 
 TEMP_FILE_PATH = 'temp_serialization_dir'
 
 
 def _test_vocabulary() -> Vocabulary:
-    return Vocabulary.from_files(os.path.normpath(os.path.join(os.getcwd(), '../fixtures/train_vocabulary')), oov_token='_', padding_token='__PAD__')
+    return Vocabulary.from_files(os.path.normpath(os.path.join(os.getcwd(), '../fixtures/train_vocabulary')),
+                                 oov_token='_', padding_token='__PAD__')
 
 
 class ArchivalTest(unittest.TestCase):
diff --git a/tests/data/token_indexers/test_single_id_token_indexer.py b/tests/data/token_indexers/test_single_id_token_indexer.py
index 81c9867f9f5382347d7abb4dff866dbe91b99cb7..1684c6b85086562d3f9ead4d6e806518e927d466 100644
--- a/tests/data/token_indexers/test_single_id_token_indexer.py
+++ b/tests/data/token_indexers/test_single_id_token_indexer.py
@@ -9,20 +9,20 @@ class VocabularyTest(unittest.TestCase):
 
     def setUp(self):
         self.vocabulary = Vocabulary.from_files(
-            os.path.join(os.getcwd(), '../../fixtures/large_vocab'),
+            os.path.join(os.getcwd(), '../../fixtures/train_vocabulary'),
            oov_token='_',
             padding_token='__PAD__'
         )
-        self.single_id_indexer = SingleIdTokenIndexer(namespace='tokens')
+        self.single_id_indexer = SingleIdTokenIndexer(namespace='token_characters')
 
     def test_get_index_to_token(self):
         token = Token(idx=0, text='w')
-        counter = {'tokens': {'w': 0}}
+        counter = {'token_characters': {'w': 0}}
         self.single_id_indexer.count_vocab_items(token, counter)
-        self.assertEqual(counter['tokens']['w'], 1)
+        self.assertEqual(counter['token_characters']['w'], 1)
 
     def test_tokens_to_indices(self):
         self.assertEqual(
             self.single_id_indexer.tokens_to_indices(
                 [Token('w'), Token('nawet')], self.vocabulary),
-            {'tokens': [4, 87]})
+            {'tokens': [11, 127]})