Syntactic Tools / combo · Commit c9c0e3c7

Commit c9c0e3c7, authored 4 years ago by Mateusz Klimaszewski

    Simplified transformer word embedder and fix validation loss.

Parent commit: 847f3a90
Part of 2 merge requests: !13 "Refactor merge develop to master" and !12 "Refactor"
Showing 3 changed files with 14 additions and 25 deletions:

  combo/models/embeddings.py   +4  −20
  combo/models/parser.py       +1  −1
  combo/training/trainer.py    +9  −4
combo/models/embeddings.py  (+4 −20)

@@ -113,12 +113,10 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
     def get_output_dim(self):
         return self.output_dim

-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()

 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
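The embeddings change swaps the hand-rolled freezing (putting the wrapped embedder into eval mode and clearing requires_grad on every parameter, plus the custom train()/eval() overrides further down) for the train_parameters flag that AllenNLP's pretrained-transformer embedder already accepts. The sketch below illustrates the idea with a plain torch module standing in for the pretrained transformer; the names are illustrative, and it only shows the requires_grad effect, not AllenNLP's exact behaviour.

import torch
from torch import nn

# Stand-ins for the wrapped transformer and the projection on top; names are illustrative.
encoder = nn.Linear(16, 16)      # "frozen" pretrained part
projection = nn.Linear(16, 8)    # part that should keep training

# What the removed code did by hand, and roughly what train_parameters=False
# delegates to: keep the pretrained weights out of gradient computation.
encoder.eval()
for param in encoder.parameters():
    param.requires_grad = False

# Only parameters that still require gradients go to the optimizer.
trainable = [p for p in projection.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

loss = projection(encoder(torch.randn(4, 16))).sum()
loss.backward()
optimizer.step()

# Frozen parameters receive no gradients at all.
assert all(p.grad is None for p in encoder.parameters())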
combo/models/parser.py  (+1 −1)

@@ -158,7 +158,7 @@ class DependencyRelationModel(base.Predictor):
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
         else:
             # Mask root label whenever head is not 0.
-            relation_prediction_output = relation_prediction[:, 1:]
+            relation_prediction_output = relation_prediction[:, 1:].clone()
             mask = (head_output["prediction"] == 0)
             vocab_size = relation_prediction_output.size(-1)
             root_idx = torch.tensor([self.root_idx], device=device)
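This one-line parser change looks like the "fix validation loss" half of the commit message: relation_prediction[:, 1:] is a view that shares storage with the logits the loss is computed from, so masking the root label in place also mutated relation_prediction itself; .clone() gives the masking code its own buffer. A minimal torch illustration (shapes and names are made up for the example):

import torch

# Slicing with basic indexing returns a view that shares storage with the original,
# so in-place writes to the slice also change the tensor used elsewhere (e.g. for the loss).
relation_prediction = torch.zeros(2, 3, 4)       # (batch, tokens, labels); illustrative shape

view = relation_prediction[:, 1:]                # old code: a view, no copy
view[..., 0] = float("-inf")                     # mask the root label in place
print(relation_prediction[0, 1, 0].item())       # -inf: the original logits were modified too

relation_prediction = torch.zeros(2, 3, 4)
masked = relation_prediction[:, 1:].clone()      # new code: an independent copy
masked[..., 0] = float("-inf")
print(relation_prediction[0, 1, 0].item())       # 0.0: the original logits are untouched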
combo/training/trainer.py  (+9 −4)

@@ -230,22 +230,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             patience: int = None,
             validation_metric: str = "-loss",
             num_epochs: int = 20,
-            cuda_device: int = -1,
+            cuda_device: Optional[Union[int, torch.device]] = -1,
             grad_norm: float = None,
             grad_clipping: float = None,
             distributed: bool = None,
             world_size: int = 1,
             num_gradient_accumulation_steps: int = 1,
             opt_level: Optional[str] = None,
             use_amp: bool = False,
-            optimizer: common.Lazy[optimizers.Optimizer] = None,
             no_grad: List[str] = None,
+            optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
             learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
             momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
             tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
             moving_average: common.Lazy[moving_average.MovingAverage] = None,
-            checkpointer: common.Lazy[training.Checkpointer] = None,
+            checkpointer: common.Lazy[training.Checkpointer] = common.Lazy(training.Checkpointer),
             batch_callbacks: List[training.BatchCallback] = None,
             epoch_callbacks: List[training.EpochCallback] = None,
+            end_callbacks: List[training.EpochCallback] = None,
+            trainer_callbacks: List[training.TrainerCallback] = None,
     ) -> "training.Trainer":
         if tensorboard_writer is None:
             tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
@@ -265,6 +267,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            no_grad=no_grad,
             optimizer=optimizer,
             learning_rate_scheduler=learning_rate_scheduler,
             momentum_scheduler=momentum_scheduler,
@@ -273,4 +276,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             checkpointer=checkpointer,
             batch_callbacks=batch_callbacks,
             epoch_callbacks=epoch_callbacks,
+            end_callbacks=end_callbacks,
+            trainer_callbacks=trainer_callbacks,
         )
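The trainer changes appear to track a newer AllenNLP from_partial_objects signature: cuda_device may now be a torch.device, the new end_callbacks / trainer_callbacks lists are accepted and forwarded, and optimizer / checkpointer get common.Lazy(...) defaults instead of None. The sketch below shows why a Lazy default is convenient, using a toy stand-in class rather than AllenNLP's actual common.Lazy; everything here is illustrative.

from typing import Callable, Generic, Optional, TypeVar

T = TypeVar("T")

class Lazy(Generic[T]):
    """Toy stand-in for a lazy constructor wrapper (not AllenNLP's implementation)."""
    def __init__(self, constructor: Callable[..., T]):
        self._constructor = constructor

    def construct(self, **kwargs) -> T:
        return self._constructor(**kwargs)


class Checkpointer:
    def __init__(self, serialization_dir: str = "checkpoints"):
        self.serialization_dir = serialization_dir


# Old style: a default of None forces a None-check at every call site.
def make_checkpointer_old(checkpointer: Optional[Lazy] = None) -> Checkpointer:
    if checkpointer is None:
        return Checkpointer(serialization_dir="out")
    return checkpointer.construct(serialization_dir="out")

# New style (what the diff moves to): the default is already a Lazy wrapper,
# so the body can call .construct() unconditionally.
def make_checkpointer_new(checkpointer: Lazy = Lazy(Checkpointer)) -> Checkpointer:
    return checkpointer.construct(serialization_dir="out")

print(make_checkpointer_old().serialization_dir)   # "out"
print(make_checkpointer_new().serialization_dir)   # "out"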