From 5aef689563be61a0fb97826750d74643d58888c3 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Mon, 2 Oct 2023 00:06:39 +1100 Subject: [PATCH] Remove circular dependencies --- .../dataset_loaders/simple_data_loader.py | 2 +- combo/main.py | 11 +------- combo/models/combo_model.py | 2 +- combo/models/encoder.py | 2 -- combo/models/utils.py | 27 ------------------- combo/modules/augmented_lstm.py | 2 +- combo/modules/encoder.py | 2 +- combo/modules/feedforward_predictor.py | 7 +++-- combo/modules/model.py | 2 +- combo/modules/module.py | 2 +- combo/modules/scalar_mix.py | 2 +- combo/modules/token_embedders/embedding.py | 2 +- .../pretrained_transformer_embedder.py | 2 +- ...trained_transformer_mismatched_embedder.py | 2 +- .../modules/token_embedders/token_embedder.py | 3 +-- combo/nn/base.py | 12 ++++----- combo/nn/{util.py => utils.py} | 4 +++ combo/predictors/predictor_model.py | 2 +- 18 files changed, 25 insertions(+), 63 deletions(-) delete mode 100644 combo/models/utils.py rename combo/nn/{util.py => utils.py} (98%) diff --git a/combo/data/dataset_loaders/simple_data_loader.py b/combo/data/dataset_loaders/simple_data_loader.py index f272589..36690da 100644 --- a/combo/data/dataset_loaders/simple_data_loader.py +++ b/combo/data/dataset_loaders/simple_data_loader.py @@ -16,7 +16,7 @@ from combo.data.dataset_loaders import DefaultDataCollator from combo.data.dataset_readers import DatasetReader from combo.data.instance import Instance from combo.data.vocabulary import Vocabulary -import combo.nn.util as nn_util +import combo.nn.utils as nn_util from combo.data.dataset_loaders.dataset_loader import DataLoader, TensorDict diff --git a/combo/main.py b/combo/main.py index d12b5f4..5b98f23 100755 --- a/combo/main.py +++ b/combo/main.py @@ -5,7 +5,7 @@ from typing import Dict from absl import app from absl import flags -from combo.nn.base import Predictor +from combo.predictors import Predictor from combo.utils import checks logger = logging.getLogger(__name__) @@ -82,7 +82,6 @@ flags.DEFINE_enum(name="predictor_name", default="combo-spacy", def run(_): - print("COMBO") pass @@ -90,14 +89,6 @@ def _get_predictor() -> Predictor: # Check for GPU # allen_checks.check_for_gpu(FLAGS.cuda_device) checks.file_exists(FLAGS.model_path) - # load model from archive - # archive = models.load_archive( - # FLAGS.model_path, - # cuda_device=FLAGS.cuda_device, - # ) - # return predictors.Predictor.from_archive( - # archive, FLAGS.predictor_name - # ) return Predictor() diff --git a/combo/models/combo_model.py b/combo/models/combo_model.py index d2ada67..249b241 100644 --- a/combo/models/combo_model.py +++ b/combo/models/combo_model.py @@ -13,7 +13,7 @@ from combo.modules import TextFieldEmbedder from combo.modules.model import Model from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder from combo.nn import RegularizerApplicator, base -from combo.nn.util import get_text_field_mask +from combo.nn.utils import get_text_field_mask from combo.utils import metrics diff --git a/combo/models/encoder.py b/combo/models/encoder.py index 12932ad..b9cd26d 100644 --- a/combo/models/encoder.py +++ b/combo/models/encoder.py @@ -6,14 +6,12 @@ and COMBO (Author: Mateusz Klimaszewski) from typing import Optional, Tuple, List import torch import torch.nn.utils.rnn as rnn -from overrides import overrides from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence from combo.config import FromParameters, Registry from combo.modules import input_variational_dropout from combo.modules.augmented_lstm import AugmentedLstm from combo.modules.input_variational_dropout import InputVariationalDropout -from combo.modules.module import Module from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder from combo.utils import ConfigurationError diff --git a/combo/models/utils.py b/combo/models/utils.py deleted file mode 100644 index 8366a7b..0000000 --- a/combo/models/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch -import torch.nn.functional as F - - -def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: - pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log() - return F.cross_entropy(pred, true, reduction="none") * mask - - -""" -Adapted from AllenNLP -""" -def tiny_value_of_dtype(dtype: torch.dtype): - """ - Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical - issues such as division by zero. - This is different from `info_value_of_dtype(dtype).tiny` because it causes some NaN bugs. - Only supports floating point dtypes. - """ - if not dtype.is_floating_point: - raise TypeError("Only supports floating point dtypes.") - if dtype == torch.float or dtype == torch.double: - return 1e-13 - elif dtype == torch.half: - return 1e-4 - else: - raise TypeError("Does not support dtype " + str(dtype)) \ No newline at end of file diff --git a/combo/modules/augmented_lstm.py b/combo/modules/augmented_lstm.py index 2131c95..86e2fa6 100644 --- a/combo/modules/augmented_lstm.py +++ b/combo/modules/augmented_lstm.py @@ -9,7 +9,7 @@ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_ from combo.config import FromParameters, Registry from combo.modules.module import Module -from combo.nn.util import get_dropout_mask +from combo.nn.utils import get_dropout_mask from combo.nn.initializers import block_orthogonal from combo.utils import ConfigurationError diff --git a/combo/modules/encoder.py b/combo/modules/encoder.py index 4d053a8..e213847 100644 --- a/combo/modules/encoder.py +++ b/combo/modules/encoder.py @@ -8,7 +8,7 @@ import torch from torch.nn.utils.rnn import pack_padded_sequence, PackedSequence from combo.modules.module import Module -from combo.nn.util import get_lengths_from_binary_sequence_mask, sort_batch_by_length +from combo.nn.utils import get_lengths_from_binary_sequence_mask, sort_batch_by_length # We have two types here for the state, because storing the state in something # which is Iterable (like a tuple, below), is helpful for internal manipulation diff --git a/combo/modules/feedforward_predictor.py b/combo/modules/feedforward_predictor.py index 22c742d..3c931d7 100644 --- a/combo/modules/feedforward_predictor.py +++ b/combo/modules/feedforward_predictor.py @@ -1,6 +1,5 @@ from combo.data import Vocabulary -from combo.models.utils import masked_cross_entropy -from combo.nn import Activation +from combo.nn.utils import masked_cross_entropy from combo.nn.base import FeedForward from combo.predictors import Predictor import torch @@ -12,7 +11,7 @@ from combo.utils import ConfigurationError class FeedForwardPredictor(Predictor): """Feedforward predictor. Should be used on top of Seq2Seq encoder.""" - def __init__(self, feedforward_network: FeedForward): + def __init__(self, feedforward_network: "FeedForward"): super().__init__() self.feedforward_network = feedforward_network @@ -59,7 +58,7 @@ class FeedForwardPredictor(Predictor): input_dim: int, num_layers: int, hidden_dims: List[int], - activations: Union[Activation, List[Activation]], + activations: Union["Activation", List["Activation"]], dropout: Union[float, List[float]] = 0.0, ): if len(hidden_dims) + 1 != num_layers: diff --git a/combo/modules/model.py b/combo/modules/model.py index 25a637f..ee520bb 100644 --- a/combo/modules/model.py +++ b/combo/modules/model.py @@ -18,7 +18,7 @@ from combo.data import Vocabulary, Instance from combo.data.batch import Batch from combo.data.dataset_loaders.dataset_loader import TensorDict from combo.modules.module import Module -from combo.nn import util, RegularizerApplicator +from combo.nn import utils, RegularizerApplicator from combo.utils import ConfigurationError logger = logging.getLogger(__name__) diff --git a/combo/modules/module.py b/combo/modules/module.py index 4ba1d1e..84311ac 100644 --- a/combo/modules/module.py +++ b/combo/modules/module.py @@ -7,7 +7,7 @@ from typing import List, Optional, Tuple import pytorch_lightning as pl import torch -from combo.nn.util import ( +from combo.nn.utils import ( _check_incompatible_keys, _IncompatibleKeys, StateDictType, diff --git a/combo/modules/scalar_mix.py b/combo/modules/scalar_mix.py index b4e894e..e31acf1 100644 --- a/combo/modules/scalar_mix.py +++ b/combo/modules/scalar_mix.py @@ -8,7 +8,7 @@ from typing import List import torch from torch.nn import ParameterList, Parameter -from combo.nn import util +from combo.nn import utils from combo.utils import ConfigurationError diff --git a/combo/modules/token_embedders/embedding.py b/combo/modules/token_embedders/embedding.py index 7a628e1..84ee2c8 100644 --- a/combo/modules/token_embedders/embedding.py +++ b/combo/modules/token_embedders/embedding.py @@ -24,7 +24,7 @@ from cached_path import is_url_or_existing_file from combo.data.vocabulary import Vocabulary from combo.modules.time_distributed import TimeDistributed from combo.modules.token_embedders.token_embedder import TokenEmbedder -from combo.nn import util +from combo.nn import utils with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) diff --git a/combo/modules/token_embedders/pretrained_transformer_embedder.py b/combo/modules/token_embedders/pretrained_transformer_embedder.py index d325490..3867f2f 100644 --- a/combo/modules/token_embedders/pretrained_transformer_embedder.py +++ b/combo/modules/token_embedders/pretrained_transformer_embedder.py @@ -15,7 +15,7 @@ from combo.config import Registry from combo.data.tokenizers import PretrainedTransformerTokenizer from combo.modules.scalar_mix import ScalarMix from combo.modules.token_embedders.token_embedder import TokenEmbedder -from combo.nn.util import batched_index_select +from combo.nn.utils import batched_index_select from transformers import XLNetConfig logger = logging.getLogger(__name__) diff --git a/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py b/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py index d6f7a6f..11bfeac 100644 --- a/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py +++ b/combo/modules/token_embedders/pretrained_transformer_mismatched_embedder.py @@ -9,7 +9,7 @@ from typing import Optional, Dict, Any import torch from combo.modules.token_embedders import PretrainedTransformerEmbedder, TokenEmbedder -from combo.nn import util +from combo.nn import utils from combo.config import Registry from combo.utils import ConfigurationError diff --git a/combo/modules/token_embedders/token_embedder.py b/combo/modules/token_embedders/token_embedder.py index ed8b619..cf42e81 100644 --- a/combo/modules/token_embedders/token_embedder.py +++ b/combo/modules/token_embedders/token_embedder.py @@ -2,12 +2,11 @@ from typing import Optional, List import torch from overrides import overrides -from torch import nn from torchtext.vocab import Vectors, GloVe, FastText, CharNGram from combo.config import FromParameters, Registry from combo.data import Vocabulary -from combo.models.utils import tiny_value_of_dtype +from combo.nn.utils import tiny_value_of_dtype from combo.modules.module import Module from combo.modules.time_distributed import TimeDistributed from combo.utils import ConfigurationError diff --git a/combo/nn/base.py b/combo/nn/base.py index c752066..ff67631 100644 --- a/combo/nn/base.py +++ b/combo/nn/base.py @@ -5,11 +5,9 @@ import torch.nn as nn from combo.config.registry import Registry from combo.config.from_parameters import FromParameters -from combo.nn.activations import Activation import combo.utils.checks as checks from combo.data.vocabulary import Vocabulary -from combo.models.utils import masked_cross_entropy -from combo.modules.module import Module +from combo.nn.utils import masked_cross_entropy from combo.predictors.predictor import Predictor @@ -18,7 +16,7 @@ class Linear(nn.Linear, FromParameters): def __init__(self, in_features: int, out_features: int, - activation: Activation = None, + activation: "Activation" = None, dropout_rate: Optional[float] = 0.0): super().__init__(in_features, out_features) self.activation = activation if activation else self.identity @@ -37,7 +35,7 @@ class Linear(nn.Linear, FromParameters): return x -class FeedForward(Module): +class FeedForward(torch.nn.Module): """ Modified copy of allennlp.modules.feedforward.FeedForward @@ -88,7 +86,7 @@ class FeedForward(Module): input_dim: int, num_layers: int, hidden_dims: Union[int, List[int]], - activations: List[Activation], # Change to Union[Activation, List[Activation]] + activations: List["Activation"], # Change to Union[Activation, List[Activation]] dropout: Union[float, List[float]] = 0.0, ) -> None: @@ -190,7 +188,7 @@ class FeedForwardPredictor(Predictor): input_dim: int, num_layers: int, hidden_dims: List[int], - activations: List[Activation], # TODO: change to Union[Activation, List[Activation]] + activations: List["Activation"], # TODO: change to Union[Activation, List[Activation]] dropout: Union[float, List[float]] = 0.0, ): if len(hidden_dims) + 1 != num_layers: diff --git a/combo/nn/util.py b/combo/nn/utils.py similarity index 98% rename from combo/nn/util.py rename to combo/nn/utils.py index e2537e2..4333be0 100644 --- a/combo/nn/util.py +++ b/combo/nn/utils.py @@ -8,9 +8,13 @@ import torch from combo.common.util import int_to_device from combo.utils import ConfigurationError +import torch.nn.functional as F StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"] +def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor: + pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log() + return F.cross_entropy(pred, true, reduction="none") * mask def tiny_value_of_dtype(dtype: torch.dtype): """ diff --git a/combo/predictors/predictor_model.py b/combo/predictors/predictor_model.py index defcf42..45723c3 100644 --- a/combo/predictors/predictor_model.py +++ b/combo/predictors/predictor_model.py @@ -25,7 +25,7 @@ from combo.data.dataset_loaders.dataset_loader import TensorDict from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.instance import JsonDict, Instance from combo.modules.model import Model -from combo.nn import util +from combo.nn import utils logger = logging.getLogger(__name__) -- GitLab