Skip to content
Snippets Groups Projects
Commit 8ed20add authored by Maja Jablonska
Browse files

Minor fixes

parent 9c99e508
Branches
Tags
1 merge request: !46 "Merge COMBO 3.0 into master"
...@@ -18,9 +18,9 @@ from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder ...@@ -18,9 +18,9 @@ from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
from combo.nn import RegularizerApplicator, base from combo.nn import RegularizerApplicator, base
from combo.nn.utils import get_text_field_mask from combo.nn.utils import get_text_field_mask
from combo.utils import metrics from combo.utils import metrics
from data import Instance from combo.data import Instance
from data.batch import Batch from combo.data.batch import Batch
from data.dataset_loaders.dataset_loader import TensorDict from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.nn import utils from combo.nn import utils
......
...@@ -7,11 +7,11 @@ from typing import List, Dict, Any ...@@ -7,11 +7,11 @@ from typing import List, Dict, Any
import spacy import spacy
from combo.config import Registry from combo.config import Registry
from combo.config.from_parameters import register_arguments from combo.config.from_parameters import register_arguments, FromParameters
from combo.utils.spacy import get_spacy_model from combo.utils.spacy import get_spacy_model
class SentenceSplitter: class SentenceSplitter(FromParameters):
""" """
A `SentenceSplitter` splits strings into sentences. A `SentenceSplitter` splits strings into sentences.
""" """
......
...@@ -628,7 +628,7 @@ class Vocabulary(FromParameters): ...@@ -628,7 +628,7 @@ class Vocabulary(FromParameters):
return self._vocab[namespace].get_stoi()[token] return self._vocab[namespace].get_stoi()[token]
except KeyError: except KeyError:
try: try:
return self._vocab[namespace].get_stoi()[token][self._oov_token] return self._vocab[namespace].get_stoi()[self._oov_token]
except KeyError: except KeyError:
raise KeyError("Namespace %s doesn't contain token %s or default OOV token %s" % raise KeyError("Namespace %s doesn't contain token %s or default OOV token %s" %
(namespace, repr(token), repr(self._oov_token))) (namespace, repr(token), repr(self._oov_token)))
......
...@@ -10,7 +10,7 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \ ...@@ -10,7 +10,7 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \
TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer
from combo.data.tokenizers import CharacterTokenizer from combo.data.tokenizers import CharacterTokenizer
from combo.data.vocabulary import Vocabulary from combo.data.vocabulary import Vocabulary
from combo_model import ComboModel from combo.combo_model import ComboModel
from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM
from combo.modules.dilated_cnn import DilatedCnnEncoder from combo.modules.dilated_cnn import DilatedCnnEncoder
from combo.modules.lemma import LemmatizerModel from combo.modules.lemma import LemmatizerModel
...@@ -22,7 +22,7 @@ from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivatio ...@@ -22,7 +22,7 @@ from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivatio
from combo.modules import FeedForwardPredictor from combo.modules import FeedForwardPredictor
from combo.nn.base import Linear from combo.nn.base import Linear
from combo.nn.regularizers.regularizers import L2Regularizer from combo.nn.regularizers.regularizers import L2Regularizer
from nn import RegularizerApplicator from combo.nn import RegularizerApplicator
def default_character_indexer(namespace=None, def default_character_indexer(namespace=None,
......
...@@ -13,7 +13,7 @@ from combo.nn.base import Predictor ...@@ -13,7 +13,7 @@ from combo.nn.base import Predictor
from combo.nn.activations import Activation from combo.nn.activations import Activation
from combo.nn.utils import masked_cross_entropy from combo.nn.utils import masked_cross_entropy
from combo.utils import ConfigurationError from combo.utils import ConfigurationError
from models.base import TimeDistributed from combo.models.base import TimeDistributed
@Registry.register('combo_lemma_predictor_from_vocab') @Registry.register('combo_lemma_predictor_from_vocab')
......
import logging import logging
import os import os
import sys import sys
from typing import List, Union, Dict, Any, Optional, Type from typing import List, Union, Dict, Any
import numpy as np import numpy as np
import torch import torch
from overrides import overrides from overrides import overrides
from combo import data, models, common from combo import data, models
from combo.common import util from combo.common import util
from combo.config import Registry from combo.config import Registry
from combo.config.from_parameters import register_arguments, resolve from combo.config.from_parameters import register_arguments
from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
from combo.data.dataset_loaders.dataset_loader import TensorDict from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.dataset_readers.dataset_reader import DatasetReader
from combo.data.instance import JsonDict from combo.data.instance import JsonDict
from combo.default_model import default_ud_dataset_reader
from combo.predictors import PredictorModule from combo.predictors import PredictorModule
from combo.utils import download, graph from combo.utils import download, graph
from modules.model import Model from modules.model import Model
from combo.default_model import default_ud_dataset_reader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
import os
import unittest import unittest
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from combo.default_model import default_model, default_ud_dataset_reader, default_data_loader, default_vocabulary
from data.vocabulary import Vocabulary from combo.combo_model import ComboModel
from combo_model import ComboModel from combo.data.vocabulary import Vocabulary
from modules import archive from combo.default_model import default_model
import os from combo.modules import archive
TEMP_FILE_PATH = 'temp_serialization_dir' TEMP_FILE_PATH = 'temp_serialization_dir'
def _test_vocabulary() -> Vocabulary: def _test_vocabulary() -> Vocabulary:
return Vocabulary.from_files(os.path.normpath(os.path.join(os.getcwd(), '../fixtures/train_vocabulary')), oov_token='_', padding_token='__PAD__') return Vocabulary.from_files(os.path.normpath(os.path.join(os.getcwd(), '../fixtures/train_vocabulary')),
oov_token='_', padding_token='__PAD__')
class ArchivalTest(unittest.TestCase): class ArchivalTest(unittest.TestCase):
......
...@@ -9,20 +9,20 @@ class VocabularyTest(unittest.TestCase): ...@@ -9,20 +9,20 @@ class VocabularyTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.vocabulary = Vocabulary.from_files( self.vocabulary = Vocabulary.from_files(
os.path.join(os.getcwd(), '../../fixtures/large_vocab'), os.path.join(os.getcwd(), '../../fixtures/train_vocabulary'),
oov_token='_', oov_token='_',
padding_token='__PAD__' padding_token='__PAD__'
) )
self.single_id_indexer = SingleIdTokenIndexer(namespace='tokens') self.single_id_indexer = SingleIdTokenIndexer(namespace='token_characters')
def test_get_index_to_token(self): def test_get_index_to_token(self):
token = Token(idx=0, text='w') token = Token(idx=0, text='w')
counter = {'tokens': {'w': 0}} counter = {'token_characters': {'w': 0}}
self.single_id_indexer.count_vocab_items(token, counter) self.single_id_indexer.count_vocab_items(token, counter)
self.assertEqual(counter['tokens']['w'], 1) self.assertEqual(counter['token_characters']['w'], 1)
def test_tokens_to_indices(self): def test_tokens_to_indices(self):
self.assertEqual( self.assertEqual(
self.single_id_indexer.tokens_to_indices( self.single_id_indexer.tokens_to_indices(
[Token('w'), Token('nawet')], self.vocabulary), [Token('w'), Token('nawet')], self.vocabulary),
{'tokens': [4, 87]}) {'tokens': [11, 127]})
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment