From 4f9448a2abdfb6b783b0e73a91ef2662f667afd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com> Date: Sun, 2 Apr 2023 12:52:41 +0200 Subject: [PATCH] Clean up base.py --- combo/models/base.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/combo/models/base.py b/combo/models/base.py index 016ef85..7d6cce9 100644 --- a/combo/models/base.py +++ b/combo/models/base.py @@ -5,10 +5,7 @@ import torch.nn as nn import utils import combo.models.combo_nn as combo_nn import combo.utils.checks as checks - - -class Model: - pass +from combo import data class Predictor(nn.Module): @@ -21,7 +18,6 @@ class Predictor(nn.Module): class Linear(nn.Linear): - def __init__(self, in_features: int, out_features: int, @@ -91,12 +87,12 @@ class FeedForward(torch.nn.Module): """ def __init__( - self, - input_dim: int, - num_layers: int, - hidden_dims: Union[int, List[int]], - activations: Union[combo_nn.Activation, List[combo_nn.Activation]], - dropout: Union[float, List[float]] = 0.0, + self, + input_dim: int, + num_layers: int, + hidden_dims: Union[int, List[int]], + activations: Union[combo_nn.Activation, List[combo_nn.Activation]], + dropout: Union[float, List[float]] = 0.0, ) -> None: super().__init__() @@ -140,14 +136,13 @@ class FeedForward(torch.nn.Module): output = inputs feature_maps = [] for layer, activation, dropout in zip( - self._linear_layers, self._activations, self._dropout + self._linear_layers, self._activations, self._dropout ): feature_maps.append(output) output = dropout(activation(layer(output))) return output, feature_maps - class FeedForwardPredictor(Predictor): """Feedforward predictor. Should be used on top of Seq2Seq encoder.""" @@ -216,4 +211,3 @@ class FeedForwardPredictor(Predictor): hidden_dims=hidden_dims, activations=activations, dropout=dropout)) - -- GitLab