Skip to content
Snippets Groups Projects
Commit 8ed20add authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Minor fixes

parent 9c99e508
1 merge request!46Merge COMBO 3.0 into master
......@@ -18,9 +18,9 @@ from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
from combo.nn import RegularizerApplicator, base
from combo.nn.utils import get_text_field_mask
from combo.utils import metrics
from data import Instance
from data.batch import Batch
from data.dataset_loaders.dataset_loader import TensorDict
from combo.data import Instance
from combo.data.batch import Batch
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.nn import utils
......
......@@ -7,11 +7,11 @@ from typing import List, Dict, Any
import spacy
from combo.config import Registry
from combo.config.from_parameters import register_arguments
from combo.config.from_parameters import register_arguments, FromParameters
from combo.utils.spacy import get_spacy_model
class SentenceSplitter:
class SentenceSplitter(FromParameters):
"""
A `SentenceSplitter` splits strings into sentences.
"""
......
......@@ -628,7 +628,7 @@ class Vocabulary(FromParameters):
return self._vocab[namespace].get_stoi()[token]
except KeyError:
try:
return self._vocab[namespace].get_stoi()[token][self._oov_token]
return self._vocab[namespace].get_stoi()[self._oov_token]
except KeyError:
raise KeyError("Namespace %s doesn't contain token %s or default OOV token %s" %
(namespace, repr(token), repr(self._oov_token)))
......
......@@ -10,7 +10,7 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \
TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer
from combo.data.tokenizers import CharacterTokenizer
from combo.data.vocabulary import Vocabulary
from combo_model import ComboModel
from combo.combo_model import ComboModel
from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM
from combo.modules.dilated_cnn import DilatedCnnEncoder
from combo.modules.lemma import LemmatizerModel
......@@ -22,7 +22,7 @@ from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivatio
from combo.modules import FeedForwardPredictor
from combo.nn.base import Linear
from combo.nn.regularizers.regularizers import L2Regularizer
from nn import RegularizerApplicator
from combo.nn import RegularizerApplicator
def default_character_indexer(namespace=None,
......
......@@ -13,7 +13,7 @@ from combo.nn.base import Predictor
from combo.nn.activations import Activation
from combo.nn.utils import masked_cross_entropy
from combo.utils import ConfigurationError
from models.base import TimeDistributed
from combo.models.base import TimeDistributed
@Registry.register('combo_lemma_predictor_from_vocab')
......
import logging
import os
import sys
from typing import List, Union, Dict, Any, Optional, Type
from typing import List, Union, Dict, Any
import numpy as np
import torch
from overrides import overrides
from combo import data, models, common
from combo import data, models
from combo.common import util
from combo.config import Registry
from combo.config.from_parameters import register_arguments, resolve
from combo.config.from_parameters import register_arguments
from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.data.dataset_readers.dataset_reader import DatasetReader
from combo.data.instance import JsonDict
from combo.default_model import default_ud_dataset_reader
from combo.predictors import PredictorModule
from combo.utils import download, graph
from modules.model import Model
from combo.default_model import default_ud_dataset_reader
logger = logging.getLogger(__name__)
......
import os
import unittest
from tempfile import TemporaryDirectory
from combo.default_model import default_model, default_ud_dataset_reader, default_data_loader, default_vocabulary
from data.vocabulary import Vocabulary
from combo_model import ComboModel
from modules import archive
import os
from combo.combo_model import ComboModel
from combo.data.vocabulary import Vocabulary
from combo.default_model import default_model
from combo.modules import archive
TEMP_FILE_PATH = 'temp_serialization_dir'
def _test_vocabulary() -> Vocabulary:
return Vocabulary.from_files(os.path.normpath(os.path.join(os.getcwd(), '../fixtures/train_vocabulary')), oov_token='_', padding_token='__PAD__')
return Vocabulary.from_files(os.path.normpath(os.path.join(os.getcwd(), '../fixtures/train_vocabulary')),
oov_token='_', padding_token='__PAD__')
class ArchivalTest(unittest.TestCase):
......
......@@ -9,20 +9,20 @@ class VocabularyTest(unittest.TestCase):
def setUp(self):
self.vocabulary = Vocabulary.from_files(
os.path.join(os.getcwd(), '../../fixtures/large_vocab'),
os.path.join(os.getcwd(), '../../fixtures/train_vocabulary'),
oov_token='_',
padding_token='__PAD__'
)
self.single_id_indexer = SingleIdTokenIndexer(namespace='tokens')
self.single_id_indexer = SingleIdTokenIndexer(namespace='token_characters')
def test_get_index_to_token(self):
token = Token(idx=0, text='w')
counter = {'tokens': {'w': 0}}
counter = {'token_characters': {'w': 0}}
self.single_id_indexer.count_vocab_items(token, counter)
self.assertEqual(counter['tokens']['w'], 1)
self.assertEqual(counter['token_characters']['w'], 1)
def test_tokens_to_indices(self):
self.assertEqual(
self.single_id_indexer.tokens_to_indices(
[Token('w'), Token('nawet')], self.vocabulary),
{'tokens': [4, 87]})
{'tokens': [11, 127]})
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment