Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
combo
Manage
Activity
Members
Labels
Plan
Issues
20
Issue boards
Milestones
Wiki
Redmine
Code
Merge requests
2
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Syntactic Tools
combo
Commits
826e57a7
Commit
826e57a7
authored
4 years ago
by
Mateusz Klimaszewski
Committed by
Mateusz Klimaszewski
4 years ago
Browse files
Options
Downloads
Patches
Plain Diff
Hotfix off by one in enhanced graphs.
parent
e295e563
Branches
Branches containing commit
Tags
Tags containing commit
2 merge requests
!9
Enhanced dependency parsing develop to master
,
!8
Enhanced dependency parsing
This commit is part of merge request
!8
. Comments created here will be created in the context of that merge request.
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
combo/models/graph_parser.py
+3
-3
3 additions, 3 deletions
combo/models/graph_parser.py
combo/predict.py
+7
-2
7 additions, 2 deletions
combo/predict.py
with
10 additions
and
5 deletions
combo/models/graph_parser.py
+
3
−
3
Edit
View file @
826e57a7
...
...
@@ -164,9 +164,9 @@ class GraphDependencyRelationModel(base.Predictor):
heads_true
:
torch
.
Tensor
,
mask
:
torch
.
BoolTensor
,
sample_weights
:
torch
.
Tensor
)
->
torch
.
Tensor
:
true
=
true
[
true
.
long
()
>
0
]
pred
=
pred
[
heads_true
.
long
()
==
1
]
correct_heads_mask
=
heads_true
.
long
()
==
1
true
=
true
[
correct_heads_mask
]
pred
=
pred
[
correct_heads_mask
]
loss
=
F
.
cross_entropy
(
pred
,
true
.
long
())
return
loss
.
sum
()
/
pred
.
size
(
0
)
...
...
This diff is collapsed.
Click to expand it.
combo/predict.py
+
7
−
2
Edit
View file @
826e57a7
...
...
@@ -197,8 +197,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
raise
NotImplementedError
(
f
"
Unknown field name
{
field_name
}
!
"
)
if
"
enhanced_head
"
in
predictions
and
predictions
[
"
enhanced_head
"
]:
graph
.
sdp_to_dag_deps
(
arc_scores
=
np
.
array
(
predictions
[
"
enhanced_head
"
]),
rel_scores
=
np
.
array
(
predictions
[
"
enhanced_deprel_prob
"
]),
# TODO off-by-one hotfix, refactor
h
=
np
.
array
(
predictions
[
"
enhanced_head
"
])
h
=
np
.
concatenate
((
h
[
-
1
:],
h
[:
-
1
]))
r
=
np
.
array
(
predictions
[
"
enhanced_deprel_prob
"
])
r
=
np
.
concatenate
((
r
[
-
1
:],
r
[:
-
1
]))
graph
.
sdp_to_dag_deps
(
arc_scores
=
h
,
rel_scores
=
r
,
tree_tokens
=
tree_tokens
,
root_idx
=
self
.
vocab
.
get_token_index
(
"
root
"
,
"
deprel_labels
"
),
vocab_index
=
self
.
vocab
.
get_index_to_token_vocabulary
(
"
deprel_labels
"
))
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment