diff --git a/combo/data/dataset_loaders/simple_data_loader.py b/combo/data/dataset_loaders/simple_data_loader.py
index f2725895a393bb43fe20a5498e26b2e1c8b6e1e8..36690dad6809a1043d4b09f9cefd0bd8c49caea8 100644
--- a/combo/data/dataset_loaders/simple_data_loader.py
+++ b/combo/data/dataset_loaders/simple_data_loader.py
@@ -16,7 +16,7 @@ from combo.data.dataset_loaders import DefaultDataCollator
 from combo.data.dataset_readers import DatasetReader
 from combo.data.instance import Instance
 from combo.data.vocabulary import Vocabulary
-import combo.nn.util as nn_util
+import combo.nn.utils as nn_util
 from combo.data.dataset_loaders.dataset_loader import DataLoader, TensorDict
 
 
diff --git a/combo/main.py b/combo/main.py
index d12b5f471b55fa8fdef35041209ef7ce02ed5002..5b98f23e687fd261133f0f9bc74836a520add34f 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -5,7 +5,7 @@ from typing import Dict
 
 from absl import app
 from absl import flags
-from combo.nn.base import Predictor
+from combo.predictors import Predictor
 from combo.utils import checks
 
 logger = logging.getLogger(__name__)
@@ -82,7 +82,6 @@ flags.DEFINE_enum(name="predictor_name", default="combo-spacy",
 
 
 def run(_):
-    print("COMBO")
     pass
 
 
@@ -90,14 +89,6 @@ def _get_predictor() -> Predictor:
     # Check for GPU
     # allen_checks.check_for_gpu(FLAGS.cuda_device)
     checks.file_exists(FLAGS.model_path)
-    # load model from archive
-    # archive = models.load_archive(
-    #     FLAGS.model_path,
-    #     cuda_device=FLAGS.cuda_device,
-    # )
-    # return predictors.Predictor.from_archive(
-    #     archive, FLAGS.predictor_name
-    # )
     return Predictor()
 
 
diff --git a/combo/models/combo_model.py b/combo/models/combo_model.py
index d2ada67d8fbe4a3933449828183871fa6c877a5d..249b24132813b17010a49c967fb281c3feadd22b 100644
--- a/combo/models/combo_model.py
+++ b/combo/models/combo_model.py
@@ -13,7 +13,7 @@ from combo.modules import TextFieldEmbedder
 from combo.modules.model import Model
 from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from combo.nn import RegularizerApplicator, base
-from combo.nn.util import get_text_field_mask
+from combo.nn.utils import get_text_field_mask
 from combo.utils import metrics
 
 
diff --git a/combo/models/encoder.py b/combo/models/encoder.py
index 12932adf579f2392b7c2ff09b2ae7482255e7ba1..b9cd26d8b4cf18a1598bfe148d464ff6cf71516d 100644
--- a/combo/models/encoder.py
+++ b/combo/models/encoder.py
@@ -6,14 +6,12 @@ and COMBO (Author: Mateusz Klimaszewski)
 from typing import Optional, Tuple, List
 import torch
 import torch.nn.utils.rnn as rnn
-from overrides import overrides
 from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
 
 from combo.config import FromParameters, Registry
 from combo.modules import input_variational_dropout
 from combo.modules.augmented_lstm import AugmentedLstm
 from combo.modules.input_variational_dropout import InputVariationalDropout
-from combo.modules.module import Module
 from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from combo.utils import ConfigurationError
 
diff --git a/combo/models/utils.py b/combo/models/utils.py
deleted file mode 100644
index 8366a7b264064fd51473643eed05cb852285891a..0000000000000000000000000000000000000000
--- a/combo/models/utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-
-def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
-    pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
-    return F.cross_entropy(pred, true, reduction="none") * mask
-
-
-"""
-Adapted from AllenNLP
-"""
-def tiny_value_of_dtype(dtype: torch.dtype):
-    """
-    Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical
-    issues such as division by zero.
-    This is different from `info_value_of_dtype(dtype).tiny` because it causes some NaN bugs.
-    Only supports floating point dtypes.
-    """
-    if not dtype.is_floating_point:
-        raise TypeError("Only supports floating point dtypes.")
-    if dtype == torch.float or dtype == torch.double:
-        return 1e-13
-    elif dtype == torch.half:
-        return 1e-4
-    else:
-        raise TypeError("Does not support dtype " + str(dtype))
\ No newline at end of file
diff --git a/combo/modules/augmented_lstm.py b/combo/modules/augmented_lstm.py
index 2131c95fbd842cf1ffed382a3f02ba4224bab8f2..86e2fa6ebdabc3ecde092193089ea36b33da23b9 100644
--- a/combo/modules/augmented_lstm.py
+++ b/combo/modules/augmented_lstm.py
@@ -9,7 +9,7 @@ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_
 
 from combo.config import FromParameters, Registry
 from combo.modules.module import Module
-from combo.nn.util import get_dropout_mask
+from combo.nn.utils import get_dropout_mask
 from combo.nn.initializers import block_orthogonal
 from combo.utils import ConfigurationError
 
diff --git a/combo/modules/encoder.py b/combo/modules/encoder.py
index 4d053a8e95bd7f8a7898183a9ec3b214f2d2fb07..e213847cb15981cd2212f44c212d80be1c5c7ec0 100644
--- a/combo/modules/encoder.py
+++ b/combo/modules/encoder.py
@@ -8,7 +8,7 @@ import torch
 from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence
 
 from combo.modules.module import Module
-from combo.nn.util import get_lengths_from_binary_sequence_mask, sort_batch_by_length
+from combo.nn.utils import get_lengths_from_binary_sequence_mask, sort_batch_by_length
 
 # We have two types here for the state, because storing the state in something
 # which is Iterable (like a tuple, below), is helpful for internal manipulation
diff --git a/combo/modules/feedforward_predictor.py b/combo/modules/feedforward_predictor.py
index 22c742d395d254d7232f877a677ba1f7db8cfff5..3c931d7d987f4fb1addf803edc5d110ac83520c5 100644
--- a/combo/modules/feedforward_predictor.py
+++ b/combo/modules/feedforward_predictor.py
@@ -1,6 +1,5 @@
 from combo.data import Vocabulary
-from combo.models.utils import masked_cross_entropy
-from combo.nn import Activation
+from combo.nn.utils import masked_cross_entropy
 from combo.nn.base import FeedForward
 from combo.predictors import Predictor
 import torch
@@ -12,7 +11,7 @@ from combo.utils import ConfigurationError
 class FeedForwardPredictor(Predictor):
     """Feedforward predictor.
     Should be used on top of Seq2Seq encoder."""
-    def __init__(self, feedforward_network: FeedForward):
+    def __init__(self, feedforward_network: "FeedForward"):
         super().__init__()
         self.feedforward_network = feedforward_network
 
@@ -59,7 +58,7 @@ class FeedForwardPredictor(Predictor):
                    input_dim: int,
                    num_layers: int,
                    hidden_dims: List[int],
-                   activations: Union[Activation, List[Activation]],
+                   activations: Union["Activation", List["Activation"]],
                    dropout: Union[float, List[float]] = 0.0,
                    ):
         if len(hidden_dims) + 1 != num_layers:
diff --git a/combo/modules/model.py b/combo/modules/model.py
index 25a637f1a7ed1544e0ed885b9fda1577d90674b3..ee520bb4d4292b02537a3ebf843b16c80fc3d1cd 100644
--- a/combo/modules/model.py
+++ b/combo/modules/model.py
@@ -18,7 +18,7 @@ from combo.data import Vocabulary, Instance
 from combo.data.batch import Batch
 from combo.data.dataset_loaders.dataset_loader import TensorDict
 from combo.modules.module import Module
-from combo.nn import util, RegularizerApplicator
+from combo.nn import utils, RegularizerApplicator
 from combo.utils import ConfigurationError
 
 logger = logging.getLogger(__name__)
diff --git a/combo/modules/module.py b/combo/modules/module.py
index 4ba1d1ed899df8921a4659b626a32476afedf21e..84311aca6de4b38be350b8ff6d53ae41c906ab9d 100644
--- a/combo/modules/module.py
+++ b/combo/modules/module.py
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
 
 import pytorch_lightning as pl
 import torch
-from combo.nn.util import (
+from combo.nn.utils import (
     _check_incompatible_keys,
     _IncompatibleKeys,
     StateDictType,
diff --git a/combo/modules/scalar_mix.py b/combo/modules/scalar_mix.py
index b4e894e2b688deed4b68c47461f1dd9237913c9a..e31acf18ef1e176e2cd450f6df9849b3b564ee7c 100644
--- a/combo/modules/scalar_mix.py
+++ b/combo/modules/scalar_mix.py
@@ -8,7 +8,7 @@ from typing import List
 
 import torch
 from torch.nn import ParameterList, Parameter
-from combo.nn import util
+from combo.nn import utils
 from combo.utils import ConfigurationError
 
 
diff --git a/combo/modules/token_embedders/embedding.py b/combo/modules/token_embedders/embedding.py
index 7a628e1be9d3230f06e2d5b29c70c49e6cc3894f..84ee2c81212a192429232fb5f7b2adbfb27820d4 100644
--- a/combo/modules/token_embedders/embedding.py
+++ b/combo/modules/token_embedders/embedding.py
@@ -24,7 +24,7 @@ from cached_path import is_url_or_existing_file
 from combo.data.vocabulary import Vocabulary
 from combo.modules.time_distributed import TimeDistributed
 from combo.modules.token_embedders.token_embedder import TokenEmbedder
-from combo.nn import util
+from combo.nn import utils
 
 with warnings.catch_warnings():
     warnings.filterwarnings("ignore", category=FutureWarning)
diff --git a/combo/modules/token_embedders/pretrained_transformer_embedder.py b/combo/modules/token_embedders/pretrained_transformer_embedder.py
index d32549096c019d7b8cf85690792b01bd42deb6ba..3867f2fa14f52a1d3e52a61c3f35a7a46cd76256 100644
--- a/combo/modules/token_embedders/pretrained_transformer_embedder.py
+++ b/combo/modules/token_embedders/pretrained_transformer_embedder.py
@@ -15,7 +15,7 @@ from combo.config import Registry
 from combo.data.tokenizers import PretrainedTransformerTokenizer
 from combo.modules.scalar_mix import ScalarMix
 from combo.modules.token_embedders.token_embedder import TokenEmbedder
-from combo.nn.util import batched_index_select
+from combo.nn.utils import batched_index_select
 from transformers import XLNetConfig
 
 logger = logging.getLogger(__name__)
diff --git a/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py b/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
index d6f7a6f969b03de009c0abbd6fa36bc934114a13..11bfeac405b8abfa0eb7855061b67a7c1d8c207d 100644
--- a/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
+++ b/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
@@ -9,7 +9,7 @@ from typing import Optional, Dict, Any
 
 import torch
 from combo.modules.token_embedders import PretrainedTransformerEmbedder, TokenEmbedder
-from combo.nn import util
+from combo.nn import utils
 from combo.config import Registry
 from combo.utils import ConfigurationError
 
diff --git a/combo/modules/token_embedders/token_embedder.py b/combo/modules/token_embedders/token_embedder.py
index ed8b61981c3bc5e32714df96d9e3d55c08069e56..cf42e81edeaeacaef379ae7eabb47cb579da8ed5 100644
--- a/combo/modules/token_embedders/token_embedder.py
+++ b/combo/modules/token_embedders/token_embedder.py
@@ -2,12 +2,11 @@ from typing import Optional, List
 
 import torch
 from overrides import overrides
-from torch import nn
 from torchtext.vocab import Vectors, GloVe, FastText, CharNGram
 
 from combo.config import FromParameters, Registry
 from combo.data import Vocabulary
-from combo.models.utils import tiny_value_of_dtype
+from combo.nn.utils import tiny_value_of_dtype
 from combo.modules.module import Module
 from combo.modules.time_distributed import TimeDistributed
 from combo.utils import ConfigurationError
diff --git a/combo/nn/base.py b/combo/nn/base.py
index c752066f06fef3e274549a687799ece43337adff..ff6763196ec756661f8e98ca84d41dc2f119a8e2 100644
--- a/combo/nn/base.py
+++ b/combo/nn/base.py
@@ -5,11 +5,9 @@ import torch.nn as nn
 
 from combo.config.registry import Registry
 from combo.config.from_parameters import FromParameters
-from combo.nn.activations import Activation
 import combo.utils.checks as checks
 from combo.data.vocabulary import Vocabulary
-from combo.models.utils import masked_cross_entropy
-from combo.modules.module import Module
+from combo.nn.utils import masked_cross_entropy
 from combo.predictors.predictor import Predictor
 
 
@@ -18,7 +16,7 @@ class Linear(nn.Linear, FromParameters):
     def __init__(self,
                  in_features: int,
                  out_features: int,
-                 activation: Activation = None,
+                 activation: "Activation" = None,
                  dropout_rate: Optional[float] = 0.0):
         super().__init__(in_features, out_features)
         self.activation = activation if activation else self.identity
@@ -37,7 +35,7 @@ class Linear(nn.Linear, FromParameters):
         return x
 
 
-class FeedForward(Module):
+class FeedForward(torch.nn.Module):
     """
     Modified copy of allennlp.modules.feedforward.FeedForward
 
@@ -88,7 +86,7 @@ class FeedForward(Module):
         input_dim: int,
         num_layers: int,
         hidden_dims: Union[int, List[int]],
-        activations: List[Activation], # Change to Union[Activation, List[Activation]]
+        activations: List["Activation"], # Change to Union[Activation, List[Activation]]
         dropout: Union[float, List[float]] = 0.0,
     ) -> None:
 
@@ -190,7 +188,7 @@ class FeedForwardPredictor(Predictor):
                    input_dim: int,
                    num_layers: int,
                    hidden_dims: List[int],
-                   activations: List[Activation], # TODO: change to Union[Activation, List[Activation]]
+                   activations: List["Activation"], # TODO: change to Union[Activation, List[Activation]]
                    dropout: Union[float, List[float]] = 0.0,
                    ):
         if len(hidden_dims) + 1 != num_layers:
diff --git a/combo/nn/util.py b/combo/nn/utils.py
similarity index 98%
rename from combo/nn/util.py
rename to combo/nn/utils.py
index e2537e24484a7d549c351cc44424c679db27f5a7..4333be0a1fc5736c87a27efb18426eaee1c5d2e7 100644
--- a/combo/nn/util.py
+++ b/combo/nn/utils.py
@@ -8,9 +8,13 @@ import torch
 
 from combo.common.util import int_to_device
 from combo.utils import ConfigurationError
+import torch.nn.functional as F
 
 StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"]
 
+def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+    pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
+    return F.cross_entropy(pred, true, reduction="none") * mask
 
 def tiny_value_of_dtype(dtype: torch.dtype):
     """
diff --git a/combo/predictors/predictor_model.py b/combo/predictors/predictor_model.py
index defcf42e5f27281204882b299f7bcf026ed1ceb7..45723c3442c66055cb7dda354ebba36389118420 100644
--- a/combo/predictors/predictor_model.py
+++ b/combo/predictors/predictor_model.py
@@ -25,7 +25,7 @@ from combo.data.dataset_loaders.dataset_loader import TensorDict
 from combo.data.dataset_readers.dataset_reader import DatasetReader
 from combo.data.instance import JsonDict, Instance
 from combo.modules.model import Model
-from combo.nn import util
+from combo.nn import utils
 
 logger = logging.getLogger(__name__)
 
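
Note: a minimal usage sketch of the two helpers this patch consolidates into combo.nn.utils (masked_cross_entropy and tiny_value_of_dtype, both shown in the hunks above). The import path follows the renamed module; the tensor shapes and the loss normalisation are illustrative assumptions, not part of the patch.

import torch
from combo.nn.utils import masked_cross_entropy, tiny_value_of_dtype

# Flattened token-level logits, gold labels and a padding mask (assumed shapes).
logits = torch.randn(4, 5)                      # (tokens, classes)
labels = torch.tensor([1, 0, 3, 2])             # gold class index per token
mask = torch.tensor([True, True, False, True])  # False marks padded positions

per_token = masked_cross_entropy(logits, labels, mask)  # shape (4,), zeroed where masked
# tiny_value_of_dtype guards the division when every position happens to be masked.
loss = per_token.sum() / (mask.sum() + tiny_value_of_dtype(per_token.dtype))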