Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
combo
Manage
Activity
Members
Labels
Plan
Issues
20
Issue boards
Milestones
Wiki
Redmine
Code
Merge requests
2
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Syntactic Tools
combo
Commits
4f9448a2
There was an error fetching the commit references. Please try again later.
Commit
4f9448a2
authored
2 years ago
by
Maja Jabłońska
Committed by
Martyna Wiącek
2 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Clean up base.py
parent
8c4b7eaf
1 merge request
!46
Merge COMBO 3.0 into master
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
combo/models/base.py
+8
-14
8 additions, 14 deletions
combo/models/base.py
with
8 additions
and
14 deletions
combo/models/base.py
+
8
−
14
View file @
4f9448a2
...
...
@@ -5,10 +5,7 @@ import torch.nn as nn
import
utils
import
combo.models.combo_nn
as
combo_nn
import
combo.utils.checks
as
checks
class
Model
:
pass
from
combo
import
data
class
Predictor
(
nn
.
Module
):
...
...
@@ -21,7 +18,6 @@ class Predictor(nn.Module):
class
Linear
(
nn
.
Linear
):
def
__init__
(
self
,
in_features
:
int
,
out_features
:
int
,
...
...
@@ -91,12 +87,12 @@ class FeedForward(torch.nn.Module):
"""
def
__init__
(
self
,
input_dim
:
int
,
num_layers
:
int
,
hidden_dims
:
Union
[
int
,
List
[
int
]],
activations
:
Union
[
combo_nn
.
Activation
,
List
[
combo_nn
.
Activation
]],
dropout
:
Union
[
float
,
List
[
float
]]
=
0.0
,
self
,
input_dim
:
int
,
num_layers
:
int
,
hidden_dims
:
Union
[
int
,
List
[
int
]],
activations
:
Union
[
combo_nn
.
Activation
,
List
[
combo_nn
.
Activation
]],
dropout
:
Union
[
float
,
List
[
float
]]
=
0.0
,
)
->
None
:
super
().
__init__
()
...
...
@@ -140,14 +136,13 @@ class FeedForward(torch.nn.Module):
output
=
inputs
feature_maps
=
[]
for
layer
,
activation
,
dropout
in
zip
(
self
.
_linear_layers
,
self
.
_activations
,
self
.
_dropout
self
.
_linear_layers
,
self
.
_activations
,
self
.
_dropout
):
feature_maps
.
append
(
output
)
output
=
dropout
(
activation
(
layer
(
output
)))
return
output
,
feature_maps
class
FeedForwardPredictor
(
Predictor
):
"""
Feedforward predictor. Should be used on top of Seq2Seq encoder.
"""
...
...
@@ -216,4 +211,3 @@ class FeedForwardPredictor(Predictor):
hidden_dims
=
hidden_dims
,
activations
=
activations
,
dropout
=
dropout
))
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment