Skip to content
Snippets Groups Projects
Commit 4f9448a2 authored by Maja Jabłońska's avatar Maja Jabłońska Committed by Martyna Wiącek
Browse files

Clean up base.py

parent 8c4b7eaf
1 merge request!46Merge COMBO 3.0 into master
......@@ -5,10 +5,7 @@ import torch.nn as nn
import utils
import combo.models.combo_nn as combo_nn
import combo.utils.checks as checks
class Model:
pass
from combo import data
class Predictor(nn.Module):
......@@ -21,7 +18,6 @@ class Predictor(nn.Module):
class Linear(nn.Linear):
def __init__(self,
in_features: int,
out_features: int,
......@@ -91,12 +87,12 @@ class FeedForward(torch.nn.Module):
"""
def __init__(
self,
input_dim: int,
num_layers: int,
hidden_dims: Union[int, List[int]],
activations: Union[combo_nn.Activation, List[combo_nn.Activation]],
dropout: Union[float, List[float]] = 0.0,
self,
input_dim: int,
num_layers: int,
hidden_dims: Union[int, List[int]],
activations: Union[combo_nn.Activation, List[combo_nn.Activation]],
dropout: Union[float, List[float]] = 0.0,
) -> None:
super().__init__()
......@@ -140,14 +136,13 @@ class FeedForward(torch.nn.Module):
output = inputs
feature_maps = []
for layer, activation, dropout in zip(
self._linear_layers, self._activations, self._dropout
self._linear_layers, self._activations, self._dropout
):
feature_maps.append(output)
output = dropout(activation(layer(output)))
return output, feature_maps
class FeedForwardPredictor(Predictor):
"""Feedforward predictor. Should be used on top of Seq2Seq encoder."""
......@@ -216,4 +211,3 @@ class FeedForwardPredictor(Predictor):
hidden_dims=hidden_dims,
activations=activations,
dropout=dropout))
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment