Commit 69f9be60, authored Apr 6, 2023 by Maja Jabłońska
Add HeadPredictionModel
Parent: 74797922
1 merge request: !46 Merge COMBO 3.0 into master
Showing 1 changed file: combo/models/parser.py (+219 additions, −5 deletions)
Removed:

from combo.models.base import Predictor


class HeadPredictionModel(Predictor):
    pass


class DependencyRelationModel(Predictor):
    pass

Added:

"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
from typing import Tuple, Dict, Optional, Union, List

import numpy as np
import torch
import torch.nn.functional as F

from combo import data
from combo.models import base, utils
from combo.nn import chu_liu_edmonds

class HeadPredictionModel(base.Predictor):
    """
    Head prediction model.
    """

    def __init__(self,
                 head_projection_layer: base.Linear,
                 dependency_projection_layer: base.Linear,
                 cycle_loss_n: int = 0):
        super().__init__()
        self.head_projection_layer = head_projection_layer
        self.dependency_projection_layer = dependency_projection_layer
        self.cycle_loss_n = cycle_loss_n
    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        if mask is None:
            mask = x.new_ones(x.size()[-1])

        head_arc_emb = self.head_projection_layer(x)
        dep_arc_emb = self.dependency_projection_layer(x)
        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))

        if self.training:
            pred = x.argmax(-1)
        else:
            pred = []
            # Add the ROOT token (not present in the mask) to the lengths.
            lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1
            for idx, length in enumerate(lengths):
                probs = x[idx, :].softmax(dim=-1).cpu().numpy()

                # We do not want any word to be the parent of the root node (ROOT, 0).
                # Setting the scores to -1 instead of 0 also fixes an edge case where
                # softmax makes all but the ROOT prediction EXACTLY 0.0, which could
                # result in many ROOT -> word edges.
                probs[:, 0] = -1
                heads, _ = chu_liu_edmonds.decode_mst(probs.T, length=length, has_labels=False)
                heads[0] = 0
                pred.append(heads)
            pred = torch.from_numpy(np.stack(pred)).to(x.device)

        output = {
            "prediction": pred[:, 1:],
            "probability": x
        }

        if labels is not None:
            if sample_weights is None:
                sample_weights = labels.new_ones([mask.size(0)])
            output["loss"], output["cycle_loss"] = self._loss(x, labels, mask, sample_weights)

        return output
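
    # The chu_liu_edmonds.decode_mst call above searches for the maximum
    # spanning arborescence of the per-sentence arc-score matrix, so inference
    # decodes a well-formed (acyclic, single-root) tree rather than the greedy
    # per-token argmax used during training.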
    def _cycle_loss(self, pred: torch.Tensor):
        BATCH_SIZE, _, _ = pred.size()
        loss = pred.new_zeros(BATCH_SIZE)
        # Index from 1, as only non-__ROOT__ tokens are used.
        pred = pred.softmax(-1)[:, 1:, 1:]
        x = pred
        for i in range(self.cycle_loss_n):
            loss += self._batch_trace(x)

            # Don't multiply on the last iteration.
            if i < self.cycle_loss_n - 1:
                x = x.bmm(pred)

        return loss
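
    # Since trace(A^n) sums the weights of all length-n closed walks in a
    # graph with adjacency-like matrix A, accumulating _batch_trace over
    # successive powers of `pred` penalises probability mass assigned to
    # cyclic head structures.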
    @staticmethod
    def _batch_trace(x: torch.Tensor) -> torch.Tensor:
        assert len(x.size()) == 3
        BATCH_SIZE, N, M = x.size()
        assert N == M

        identity = x.new_tensor(torch.eye(N))
        identity = identity.reshape((1, N, N))
        batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
        return (x * batch_identity).sum((-1, -2))
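
    # Equivalent to torch.diagonal(x, dim1=-2, dim2=-1).sum(-1); the masked
    # identity product keeps the per-batch trace explicit.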
    def _loss(self, pred: torch.Tensor, true: torch.Tensor,
              mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        BATCH_SIZE, N, M = pred.size()
        assert N == M
        SENTENCE_LENGTH = N

        valid_positions = mask.sum()

        result = []
        # Ignore first pred dimension as it is ROOT token prediction
        for i in range(SENTENCE_LENGTH - 1):
            pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH)
            true_i = true[:, i].reshape(-1)
            mask_i = mask[:, i]
            cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i)
            result.append(cross_entropy_loss)

        cycle_loss = self._cycle_loss(pred)
        loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()

@base.Predictor.register("combo_dependency_parsing_from_vocab", constructor="from_vocab")
class DependencyRelationModel(base.Predictor):
    """
    Dependency relation parsing model.
    """

    def __init__(self,
                 root_idx: int,
                 head_predictor: HeadPredictionModel,
                 head_projection_layer: base.Linear,
                 dependency_projection_layer: base.Linear,
                 relation_prediction_layer: base.Linear):
        super().__init__()
        self.root_idx = root_idx
        self.head_predictor = head_predictor
        self.head_projection_layer = head_projection_layer
        self.dependency_projection_layer = dependency_projection_layer
        self.relation_prediction_layer = relation_prediction_layer
    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        device = x.device
        if mask is not None:
            mask = mask[:, 1:]

        relations_labels, head_labels = None, None
        if labels is not None and labels[0] is not None:
            relations_labels, head_labels = labels
            if mask is None:
                mask = head_labels.new_ones(head_labels.size())

        head_output = self.head_predictor(x, mask, head_labels, sample_weights)
        head_pred = head_output["probability"]
        head_pred_soft = F.softmax(head_pred, dim=-1)

        head_rel_emb = self.head_projection_layer(x)
        dep_rel_emb = self.dependency_projection_layer(x)

        dep_rel_pred = head_pred_soft.bmm(head_rel_emb)
        dep_rel_pred = torch.cat((dep_rel_pred, dep_rel_emb), dim=-1)
        relation_prediction = self.relation_prediction_layer(dep_rel_pred)
        output = head_output
        output["embedding"] = dep_rel_pred

        if self.training:
            output["prediction"] = (relation_prediction.argmax(-1)[:, 1:],
                                    head_output["prediction"])
        else:
            # Mask root label whenever head is not 0.
            relation_prediction_output = relation_prediction[:, 1:].clone()
            mask = (head_output["prediction"] == 0)
            vocab_size = relation_prediction_output.size(-1)
            root_idx = torch.tensor([self.root_idx], device=device)
            relation_prediction_output[mask] = (relation_prediction_output
                                                .masked_select(mask.unsqueeze(-1))
                                                .reshape(-1, vocab_size)
                                                .index_fill(-1, root_idx, 10e10))
            relation_prediction_output[~mask] = (relation_prediction_output
                                                 .masked_select(~(mask.unsqueeze(-1)))
                                                 .reshape(-1, vocab_size)
                                                 .index_fill(-1, root_idx, -10e10))
            output["prediction"] = (relation_prediction_output.argmax(-1),
                                    head_output["prediction"])

        if labels is not None and labels[0] is not None:
            if sample_weights is None:
                sample_weights = labels.new_ones([mask.size(0)])
            loss = self._loss(relation_prediction[:, 1:], relations_labels, mask, sample_weights)
            output["loss"] = (loss, head_output["loss"])

        return output
    @staticmethod
    def _loss(pred: torch.Tensor, true: torch.Tensor,
              mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        valid_positions = mask.sum()

        BATCH_SIZE, _, DEPENDENCY_RELATIONS = pred.size()
        pred = pred.reshape(-1, DEPENDENCY_RELATIONS)
        true = true.reshape(-1)
        mask = mask.reshape(-1)
        loss = utils.masked_cross_entropy(pred, true, mask)
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions
    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   vocab_namespace: str,
                   head_predictor: HeadPredictionModel,
                   head_projection_layer: base.Linear,
                   dependency_projection_layer: base.Linear):
        """
        Creates parser combining model configuration and vocabulary data.
        """
        assert vocab_namespace in vocab.get_namespaces()
        relation_prediction_layer = base.Linear(
            in_features=head_projection_layer.get_output_dim() +
                        dependency_projection_layer.get_output_dim(),
            out_features=vocab.get_vocab_size(vocab_namespace)
        )
        return cls(
            head_predictor=head_predictor,
            head_projection_layer=head_projection_layer,
            dependency_projection_layer=dependency_projection_layer,
            relation_prediction_layer=relation_prediction_layer,
            root_idx=vocab.get_token_index("root", vocab_namespace)
        )
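
For orientation, here is a minimal usage sketch; it is not part of the commit. It assumes base.Predictor is a torch.nn.Module subclass, that base.Linear accepts the in_features/out_features arguments seen in from_vocab above, and that a populated data.Vocabulary is available; the layer sizes and the "deprel_labels" namespace name are illustrative, not taken from the repository.

import torch

from combo import data
from combo.models import base
from combo.models.parser import DependencyRelationModel, HeadPredictionModel

hidden_dim, arc_dim, rel_dim = 768, 512, 128  # hypothetical dimensions

# Arc scorer: projects token features into head/dependent spaces, scores every
# (dependent, head) pair, and optionally adds the cycle-penalty term.
head_predictor = HeadPredictionModel(
    head_projection_layer=base.Linear(in_features=hidden_dim, out_features=arc_dim),
    dependency_projection_layer=base.Linear(in_features=hidden_dim, out_features=arc_dim),
    cycle_loss_n=3,
)

vocab = data.Vocabulary()  # assumed: constructed and populated with "deprel_labels" elsewhere
parser = DependencyRelationModel.from_vocab(
    vocab=vocab,
    vocab_namespace="deprel_labels",  # hypothetical namespace name
    head_predictor=head_predictor,
    head_projection_layer=base.Linear(in_features=hidden_dim, out_features=rel_dim),
    dependency_projection_layer=base.Linear(in_features=hidden_dim, out_features=rel_dim),
)

# Batch of 2 sentences, 10 positions each (position 0 is ROOT).
x = torch.randn(2, 10, hidden_dim)
mask = torch.ones(2, 10, dtype=torch.bool)

parser.eval()  # inference path: Chu-Liu/Edmonds decoding plus root-label masking
output = parser(x, mask=mask)
relations, heads = output["prediction"]  # (batch, sentence length) each, ROOT excluded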