Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
combo
Manage
Activity
Members
Labels
Plan
Issues
20
Issue boards
Milestones
Wiki
Redmine
Code
Merge requests
2
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Syntactic Tools
combo
Commits
038c9705
Commit
038c9705
authored
4 years ago
by
Mateusz Klimaszewski
Browse files
Options
Downloads
Patches
Plain Diff
Fix batch predictions for DEPS.
parent
ee349a12
No related merge requests found
Pipeline
#2955
passed with stage
in 4 minutes and 58 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
combo/predict.py
+5
-4
5 additions, 4 deletions
combo/predict.py
scripts/predict_iwpt21.py
+3
-0
3 additions, 0 deletions
scripts/predict_iwpt21.py
with
8 additions
and
4 deletions
combo/predict.py
+
5
−
4
View file @
038c9705
...
@@ -194,14 +194,15 @@ class COMBO(predictor.Predictor):
...
@@ -194,14 +194,15 @@ class COMBO(predictor.Predictor):
if
"
enhanced_head
"
in
predictions
and
predictions
[
"
enhanced_head
"
]:
if
"
enhanced_head
"
in
predictions
and
predictions
[
"
enhanced_head
"
]:
# TODO off-by-one hotfix, refactor
# TODO off-by-one hotfix, refactor
h
=
np
.
array
(
predictions
[
"
enhanced_head
"
])
sentence_length
=
len
(
tree_tokens
)
h
=
np
.
array
(
predictions
[
"
enhanced_head
"
])[:
sentence_length
,
:
sentence_length
]
h
=
np
.
concatenate
((
h
[
-
1
:],
h
[:
-
1
]))
h
=
np
.
concatenate
((
h
[
-
1
:],
h
[:
-
1
]))
r
=
np
.
array
(
predictions
[
"
enhanced_deprel_prob
"
])
r
=
np
.
array
(
predictions
[
"
enhanced_deprel_prob
"
])
[:
sentence_length
,
:
sentence_length
,
:]
r
=
np
.
concatenate
((
r
[
-
1
:],
r
[:
-
1
]))
r
=
np
.
concatenate
((
r
[
-
1
:],
r
[:
-
1
]))
graph
.
graph_and_tree_merge
(
graph
.
graph_and_tree_merge
(
tree_arc_scores
=
predictions
[
"
head
"
],
tree_arc_scores
=
predictions
[
"
head
"
]
[:
sentence_length
]
,
tree_rel_scores
=
predictions
[
"
deprel
"
],
tree_rel_scores
=
predictions
[
"
deprel
"
]
[:
sentence_length
]
,
graph_arc_scores
=
h
,
graph_arc_scores
=
h
,
graph_rel_scores
=
r
,
graph_rel_scores
=
r
,
idx2label
=
self
.
vocab
.
get_index_to_token_vocabulary
(
"
deprel_labels
"
),
idx2label
=
self
.
vocab
.
get_index_to_token_vocabulary
(
"
deprel_labels
"
),
...
...
This diff is collapsed.
Click to expand it.
scripts/predict_iwpt21.py
+
3
−
0
View file @
038c9705
...
@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
...
@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
help
=
"
Cuda device id (-1 for cpu).
"
)
help
=
"
Cuda device id (-1 for cpu).
"
)
flags
.
DEFINE_boolean
(
name
=
"
expect_prefix
"
,
default
=
True
,
flags
.
DEFINE_boolean
(
name
=
"
expect_prefix
"
,
default
=
True
,
help
=
"
Whether to expect allennlp prefix.
"
)
help
=
"
Whether to expect allennlp prefix.
"
)
flags
.
DEFINE_integer
(
name
=
"
batch_size
"
,
default
=
32
,
help
=
"
Batch size.
"
)
def
run
(
_
):
def
run
(
_
):
...
@@ -68,6 +70,7 @@ def run(_):
...
@@ -68,6 +70,7 @@ def run(_):
--input_file
{
test_file
}
--input_file
{
test_file
}
--output_file
{
output_pred
}
--output_file
{
output_pred
}
--cuda_device
{
FLAGS
.
cuda_device
}
--cuda_device
{
FLAGS
.
cuda_device
}
--batch_size
{
FLAGS
.
batch_size
}
--silent
--silent
"""
"""
utils
.
execute_command
(
command
)
utils
.
execute_command
(
command
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment