From 8ed20add41c768019d660c4d74da978b47604ab6 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Sat, 11 Nov 2023 14:50:59 +1100
Subject: [PATCH] Minor fixes

---
 combo/combo_model.py                               |  6 +++---
 combo/data/tokenizers/sentence_splitter.py         |  4 ++--
 combo/data/vocabulary.py                           |  2 +-
 combo/default_model.py                             |  4 ++--
 combo/modules/lemma.py                             |  2 +-
 combo/predict.py                                   |  9 ++++-----
 tests/config/test_archive.py                       | 14 ++++++++------
 .../token_indexers/test_single_id_token_indexer.py | 10 +++++-----
 8 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/combo/combo_model.py b/combo/combo_model.py
index d6ad22e..5372992 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -18,9 +18,9 @@ from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from combo.nn import RegularizerApplicator, base
 from combo.nn.utils import get_text_field_mask
 from combo.utils import metrics
-from data import Instance
-from data.batch import Batch
-from data.dataset_loaders.dataset_loader import TensorDict
+from combo.data import Instance
+from combo.data.batch import Batch
+from combo.data.dataset_loaders.dataset_loader import TensorDict
 from combo.nn import utils
 
 
diff --git a/combo/data/tokenizers/sentence_splitter.py b/combo/data/tokenizers/sentence_splitter.py
index 0313140..c0383c9 100644
--- a/combo/data/tokenizers/sentence_splitter.py
+++ b/combo/data/tokenizers/sentence_splitter.py
@@ -7,11 +7,11 @@ from typing import List, Dict, Any
 import spacy
 
 from combo.config import Registry
-from combo.config.from_parameters import register_arguments
+from combo.config.from_parameters import register_arguments, FromParameters
 from combo.utils.spacy import get_spacy_model
 
 
-class SentenceSplitter:
+class SentenceSplitter(FromParameters):
     """
     A `SentenceSplitter` splits strings into sentences.
     """
diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py
index 73cb541..9f711aa 100644
--- a/combo/data/vocabulary.py
+++ b/combo/data/vocabulary.py
@@ -628,7 +628,7 @@ class Vocabulary(FromParameters):
             return self._vocab[namespace].get_stoi()[token]
         except KeyError:
             try:
-                return self._vocab[namespace].get_stoi()[token][self._oov_token]
+                return self._vocab[namespace].get_stoi()[self._oov_token]
             except KeyError:
                 raise KeyError("Namespace %s doesn't contain token %s or default OOV token %s"
                                % (namespace, repr(token), repr(self._oov_token)))
diff --git a/combo/default_model.py b/combo/default_model.py
index febd4fb..927966f 100644
--- a/combo/default_model.py
+++ b/combo/default_model.py
@@ -10,7 +10,7 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \
     TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer
 from combo.data.tokenizers import CharacterTokenizer
 from combo.data.vocabulary import Vocabulary
-from combo_model import ComboModel
+from combo.combo_model import ComboModel
 from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM
 from combo.modules.dilated_cnn import DilatedCnnEncoder
 from combo.modules.lemma import LemmatizerModel
@@ -22,7 +22,7 @@ from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivatio
 from combo.modules import FeedForwardPredictor
 from combo.nn.base import Linear
 from combo.nn.regularizers.regularizers import L2Regularizer
-from nn import RegularizerApplicator
+from combo.nn import RegularizerApplicator
 
 
 def default_character_indexer(namespace=None,
diff --git a/combo/modules/lemma.py b/combo/modules/lemma.py
index 3f7fd24..c4c0495 100644
--- a/combo/modules/lemma.py
+++ b/combo/modules/lemma.py
@@ -13,7 +13,7 @@ from combo.nn.base import Predictor
 from combo.nn.activations import Activation
 from combo.nn.utils import masked_cross_entropy
 from combo.utils import ConfigurationError
-from models.base import TimeDistributed
+from combo.models.base import TimeDistributed
 
 
 @Registry.register('combo_lemma_predictor_from_vocab')
diff --git a/combo/predict.py b/combo/predict.py
index b1b5259..72e97f6 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,25 +1,24 @@
 import logging
 import os
 import sys
-from typing import List, Union, Dict, Any, Optional, Type
+from typing import List, Union, Dict, Any
 
 import numpy as np
 import torch
 from overrides import overrides
 
-from combo import data, models, common
+from combo import data, models
 from combo.common import util
 from combo.config import Registry
-from combo.config.from_parameters import register_arguments, resolve
+from combo.config.from_parameters import register_arguments
 from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
 from combo.data.dataset_loaders.dataset_loader import TensorDict
 from combo.data.dataset_readers.dataset_reader import DatasetReader
-
 from combo.data.instance import JsonDict
+from combo.default_model import default_ud_dataset_reader
 from combo.predictors import PredictorModule
 from combo.utils import download, graph
 from modules.model import Model
-from combo.default_model import default_ud_dataset_reader
 
 
 logger = logging.getLogger(__name__)
diff --git a/tests/config/test_archive.py b/tests/config/test_archive.py
index b34bfd1..f8bae7a 100644
--- a/tests/config/test_archive.py
+++ b/tests/config/test_archive.py
@@ -1,16 +1,18 @@
+import os
 import unittest
 from tempfile import TemporaryDirectory
-from combo.default_model import default_model, default_ud_dataset_reader, default_data_loader, default_vocabulary
-from data.vocabulary import Vocabulary
-from combo_model import ComboModel
-from modules import archive
-import os
+
+from combo.combo_model import ComboModel
+from combo.data.vocabulary import Vocabulary
+from combo.default_model import default_model
+from combo.modules import archive
 
 TEMP_FILE_PATH = 'temp_serialization_dir'
 
 
 def _test_vocabulary() -> Vocabulary:
-    return Vocabulary.from_files(os.path.normpath(os.path.join(os.getcwd(), '../fixtures/train_vocabulary')), oov_token='_', padding_token='__PAD__')
+    return Vocabulary.from_files(os.path.normpath(os.path.join(os.getcwd(), '../fixtures/train_vocabulary')),
+                                 oov_token='_', padding_token='__PAD__')
 
 
 class ArchivalTest(unittest.TestCase):
diff --git a/tests/data/token_indexers/test_single_id_token_indexer.py b/tests/data/token_indexers/test_single_id_token_indexer.py
index 81c9867..1684c6b 100644
--- a/tests/data/token_indexers/test_single_id_token_indexer.py
+++ b/tests/data/token_indexers/test_single_id_token_indexer.py
@@ -9,20 +9,20 @@ class VocabularyTest(unittest.TestCase):
 
     def setUp(self):
         self.vocabulary = Vocabulary.from_files(
-            os.path.join(os.getcwd(), '../../fixtures/large_vocab'),
+            os.path.join(os.getcwd(), '../../fixtures/train_vocabulary'),
            oov_token='_',
             padding_token='__PAD__'
         )
-        self.single_id_indexer = SingleIdTokenIndexer(namespace='tokens')
+        self.single_id_indexer = SingleIdTokenIndexer(namespace='token_characters')
 
     def test_get_index_to_token(self):
         token = Token(idx=0, text='w')
-        counter = {'tokens': {'w': 0}}
+        counter = {'token_characters': {'w': 0}}
         self.single_id_indexer.count_vocab_items(token, counter)
-        self.assertEqual(counter['tokens']['w'], 1)
+        self.assertEqual(counter['token_characters']['w'], 1)
 
     def test_tokens_to_indices(self):
         self.assertEqual(
             self.single_id_indexer.tokens_to_indices(
                 [Token('w'), Token('nawet')], self.vocabulary),
-            {'tokens': [4, 87]})
+            {'tokens': [11, 127]})
-- 
GitLab
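
Note on the one behavioural fix in this patch, the combo/data/vocabulary.py hunk: the old fallback indexed the result of the token lookup a second time (get_stoi()[token][self._oov_token]), but that branch only runs after get_stoi()[token] has already raised KeyError, so the out-of-vocabulary fallback could never succeed. Below is a minimal sketch of the corrected lookup-with-fallback pattern, with a plain dict standing in for the get_stoi() string-to-index mapping; the helper name is illustrative rather than part of the combo API, and the sample indices are borrowed from the updated test expectations above.

    def token_to_index(stoi, token, oov_token):
        """Return the index of `token`, falling back to the OOV token's index."""
        try:
            return stoi[token]
        except KeyError:
            try:
                # Corrected fallback: look up the OOV token itself. The old form,
                # stoi[token][oov_token], repeated the lookup that had just failed,
                # so it raised KeyError again before the OOV subscript was applied.
                return stoi[oov_token]
            except KeyError:
                raise KeyError("Vocabulary contains neither %r nor OOV token %r"
                               % (token, oov_token))

    # Sample indices mirror the expectations in test_tokens_to_indices.
    stoi = {'_': 0, '__PAD__': 1, 'w': 11, 'nawet': 127}
    assert token_to_index(stoi, 'nawet', '_') == 127  # known token
    assert token_to_index(stoi, 'xyz', '_') == 0      # unknown -> OOV index

The remaining hunks are import-path corrections (relative imports replaced by absolute combo.* imports), plus test updates that point the indexer at the token_characters namespace of the train_vocabulary fixture, with expected indices adjusted to match.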