Syntactic Tools / combo · Merge request !12

Refactor
Merged · Mateusz Klimaszewski requested to merge refactor into develop, 4 years ago
Merge request reports

Comparing develop (base) and latest version 8fbff648 · 4 commits, 4 years ago · 5 files, +67 −75
combo/models/embeddings.py · +5 −21
@@ -107,18 +107,16 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
     def __init__(self,
                  model_name: str,
-                 projection_dim: int,
+                 projection_dim: int = 0,
                  projection_activation: Optional[allen_nn.Activation] = lambda x: x,
                  projection_dropout_rate: Optional[float] = 0.0,
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
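
The substantive change in this hunk: the hand-rolled freezing block is replaced by the train_parameters flag that the parent class already accepts. A minimal before/after sketch, calling allennlp's PretrainedTransformerMismatchedEmbedder directly ("bert-base-cased" is an arbitrary placeholder model name, not taken from this MR):

from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

freeze_transformer = True

# Old approach: construct first, then put the backbone in eval mode and
# switch off gradients for every transformer parameter by hand.
manual = PretrainedTransformerMismatchedEmbedder("bert-base-cased")
if freeze_transformer:
    manual._matched_embedder.eval()
    for param in manual._matched_embedder.parameters():
        param.requires_grad = False

# New approach: delegate freezing to the embedder's own constructor flag.
delegated = PretrainedTransformerMismatchedEmbedder(
    "bert-base-cased",
    train_parameters=not freeze_transformer,
)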
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
     def get_output_dim(self):
         return self.output_dim
 
-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()
-
 
 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
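
The second hunk drops the train()/eval() overrides, which kept a frozen backbone in eval mode while still letting the projection layer toggle modes. After the refactor, weight updates are prevented by train_parameters=False instead, since the backbone's parameters have requires_grad disabled. A quick hedged check of that part (again with a placeholder model name):

from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

frozen = PretrainedTransformerMismatchedEmbedder("bert-base-cased", train_parameters=False)

# Count transformer parameters that would still receive gradient updates.
trainable = [p for p in frozen._matched_embedder.parameters() if p.requires_grad]
print(f"trainable transformer parameters: {len(trainable)}")  # expected: 0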