From 4f9448a2abdfb6b783b0e73a91ef2662f667afd2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com>
Date: Sun, 2 Apr 2023 12:52:41 +0200
Subject: [PATCH] Clean up base.py

---
 combo/models/base.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/combo/models/base.py b/combo/models/base.py
index 016ef85..7d6cce9 100644
--- a/combo/models/base.py
+++ b/combo/models/base.py
@@ -5,10 +5,7 @@ import torch.nn as nn
 import utils
 import combo.models.combo_nn as combo_nn
 import combo.utils.checks as checks
-
-
-class Model:
-    pass
+from combo import data
 
 
 class Predictor(nn.Module):
@@ -21,7 +18,6 @@ class Predictor(nn.Module):
 
 
 class Linear(nn.Linear):
-
     def __init__(self,
                  in_features: int,
                  out_features: int,
@@ -91,12 +87,12 @@ class FeedForward(torch.nn.Module):
     """
 
     def __init__(
-        self,
-        input_dim: int,
-        num_layers: int,
-        hidden_dims: Union[int, List[int]],
-        activations: Union[combo_nn.Activation, List[combo_nn.Activation]],
-        dropout: Union[float, List[float]] = 0.0,
+            self,
+            input_dim: int,
+            num_layers: int,
+            hidden_dims: Union[int, List[int]],
+            activations: Union[combo_nn.Activation, List[combo_nn.Activation]],
+            dropout: Union[float, List[float]] = 0.0,
     ) -> None:
 
         super().__init__()
@@ -140,14 +136,13 @@ class FeedForward(torch.nn.Module):
         output = inputs
         feature_maps = []
         for layer, activation, dropout in zip(
-            self._linear_layers, self._activations, self._dropout
+                self._linear_layers, self._activations, self._dropout
         ):
             feature_maps.append(output)
             output = dropout(activation(layer(output)))
         return output, feature_maps
 
 
-
 class FeedForwardPredictor(Predictor):
     """Feedforward predictor. Should be used on top of Seq2Seq encoder."""
 
@@ -216,4 +211,3 @@ class FeedForwardPredictor(Predictor):
             hidden_dims=hidden_dims,
             activations=activations,
             dropout=dropout))
-
-- 
GitLab