Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
combo
Manage
Activity
Members
Labels
Plan
Issues
20
Issue boards
Milestones
Wiki
Redmine
Code
Merge requests
2
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Syntactic Tools
combo
Commits
847f3a90
Commit
847f3a90
authored
4 years ago
by
Mateusz Klimaszewski
Browse files
Options
Downloads
Patches
Plain Diff
Refactor prediction loop.
parent
888e0f11
Branches
Branches containing commit
Tags
Tags containing commit
2 merge requests
!13
Refactor merge develop to master
,
!12
Refactor
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
combo/predict.py
+51
-49
51 additions, 49 deletions
combo/predict.py
with
51 additions
and
49 deletions
combo/predict.py
+
51
−
49
View file @
847f3a90
...
...
@@ -24,7 +24,7 @@ class COMBO(predictor.Predictor):
model
:
models
.
Model
,
dataset_reader
:
allen_data
.
DatasetReader
,
tokenizer
:
allen_data
.
Tokenizer
=
tokenizers
.
WhitespaceTokenizer
(),
batch_size
:
int
=
32
,
batch_size
:
int
=
1024
,
line_to_conllu
:
bool
=
True
)
->
None
:
super
().
__init__
(
model
,
dataset_reader
)
self
.
batch_size
=
batch_size
...
...
@@ -140,54 +140,56 @@ class COMBO(predictor.Predictor):
tree
=
instance
.
fields
[
"
metadata
"
][
"
input
"
]
field_names
=
instance
.
fields
[
"
metadata
"
][
"
field_names
"
]
tree_tokens
=
[
t
for
t
in
tree
if
isinstance
(
t
[
"
id
"
],
int
)]
for
idx
,
token
in
enumerate
(
tree_tokens
):
for
field_name
in
field_names
:
if
field_name
in
predictions
:
if
field_name
in
[
"
xpostag
"
,
"
upostag
"
,
"
semrel
"
,
"
deprel
"
]:
value
=
self
.
vocab
.
get_token_from_index
(
predictions
[
field_name
][
idx
],
field_name
+
"
_labels
"
)
token
[
field_name
]
=
value
elif
field_name
in
[
"
head
"
]:
token
[
field_name
]
=
int
(
predictions
[
field_name
][
idx
])
elif
field_name
==
"
deps
"
:
# Handled after every other decoding
continue
elif
field_name
in
[
"
feats
"
]:
slices
=
self
.
_model
.
morphological_feat
.
slices
features
=
[]
prediction
=
predictions
[
field_name
][
idx
]
for
(
cat
,
cat_indices
),
pred_idx
in
zip
(
slices
.
items
(),
prediction
):
if
cat
not
in
[
"
__PAD__
"
,
"
_
"
]:
value
=
self
.
vocab
.
get_token_from_index
(
cat_indices
[
pred_idx
],
field_name
+
"
_labels
"
)
# Exclude auxiliary values
if
"
=None
"
not
in
value
:
features
.
append
(
value
)
if
len
(
features
)
==
0
:
field_value
=
"
_
"
else
:
lowercase_features
=
[
f
.
lower
()
for
f
in
features
]
arg_indices
=
sorted
(
range
(
len
(
lowercase_features
)),
key
=
lowercase_features
.
__getitem__
)
field_value
=
"
|
"
.
join
(
np
.
array
(
features
)[
arg_indices
].
tolist
())
token
[
field_name
]
=
field_value
elif
field_name
==
"
lemma
"
:
prediction
=
predictions
[
field_name
][
idx
]
word_chars
=
[]
for
char_idx
in
prediction
[
1
:
-
1
]:
pred_char
=
self
.
vocab
.
get_token_from_index
(
char_idx
,
"
lemma_characters
"
)
if
pred_char
==
"
__END__
"
:
break
elif
pred_char
==
"
__PAD__
"
:
continue
elif
"
_
"
in
pred_char
:
pred_char
=
"
?
"
word_chars
.
append
(
pred_char
)
token
[
field_name
]
=
""
.
join
(
word_chars
)
for
field_name
in
field_names
:
if
field_name
not
in
predictions
:
continue
field_predictions
=
predictions
[
field_name
]
for
idx
,
token
in
enumerate
(
tree_tokens
):
if
field_name
in
{
"
xpostag
"
,
"
upostag
"
,
"
semrel
"
,
"
deprel
"
}:
value
=
self
.
vocab
.
get_token_from_index
(
field_predictions
[
idx
],
field_name
+
"
_labels
"
)
token
[
field_name
]
=
value
elif
field_name
==
"
head
"
:
token
[
field_name
]
=
int
(
field_predictions
[
idx
])
elif
field_name
==
"
deps
"
:
# Handled after every other decoding
continue
elif
field_name
==
"
feats
"
:
slices
=
self
.
_model
.
morphological_feat
.
slices
features
=
[]
prediction
=
field_predictions
[
idx
]
for
(
cat
,
cat_indices
),
pred_idx
in
zip
(
slices
.
items
(),
prediction
):
if
cat
not
in
[
"
__PAD__
"
,
"
_
"
]:
value
=
self
.
vocab
.
get_token_from_index
(
cat_indices
[
pred_idx
],
field_name
+
"
_labels
"
)
# Exclude auxiliary values
if
"
=None
"
not
in
value
:
features
.
append
(
value
)
if
len
(
features
)
==
0
:
field_value
=
"
_
"
else
:
raise
NotImplementedError
(
f
"
Unknown field name
{
field_name
}
!
"
)
lowercase_features
=
[
f
.
lower
()
for
f
in
features
]
arg_indices
=
sorted
(
range
(
len
(
lowercase_features
)),
key
=
lowercase_features
.
__getitem__
)
field_value
=
"
|
"
.
join
(
np
.
array
(
features
)[
arg_indices
].
tolist
())
token
[
field_name
]
=
field_value
elif
field_name
==
"
lemma
"
:
prediction
=
field_predictions
[
idx
]
word_chars
=
[]
for
char_idx
in
prediction
[
1
:
-
1
]:
pred_char
=
self
.
vocab
.
get_token_from_index
(
char_idx
,
"
lemma_characters
"
)
if
pred_char
==
"
__END__
"
:
break
elif
pred_char
==
"
__PAD__
"
:
continue
elif
"
_
"
in
pred_char
:
pred_char
=
"
?
"
word_chars
.
append
(
pred_char
)
token
[
field_name
]
=
""
.
join
(
word_chars
)
else
:
raise
NotImplementedError
(
f
"
Unknown field name
{
field_name
}
!
"
)
if
"
enhanced_head
"
in
predictions
and
predictions
[
"
enhanced_head
"
]:
# TODO off-by-one hotfix, refactor
...
...
@@ -212,7 +214,7 @@ class COMBO(predictor.Predictor):
@classmethod
def
from_pretrained
(
cls
,
path
:
str
,
tokenizer
=
tokenizers
.
SpacyTokenizer
(),
batch_size
:
int
=
32
,
batch_size
:
int
=
1024
,
cuda_device
:
int
=
-
1
):
util
.
import_module_and_submodules
(
"
combo.commands
"
)
util
.
import_module_and_submodules
(
"
combo.models
"
)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment