Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
Syntactic Tools
combo
Commits
a3774ab6
Commit
a3774ab6
authored
Apr 30, 2021
by
Mateusz Klimaszewski
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Tree and graph merging algorithm.
parent
3123cced
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
92 additions
and
5 deletions
+92
-5
combo/predict.py
combo/predict.py
+16
-5
combo/utils/graph.py
combo/utils/graph.py
+76
-0
No files found.
combo/predict.py
View file @
a3774ab6
...
...
@@ -201,11 +201,22 @@ class COMBO(predictor.Predictor):
h
=
np
.
concatenate
((
h
[
-
1
:],
h
[:
-
1
]))
r
=
np
.
array
(
predictions
[
"enhanced_deprel_prob"
])
r
=
np
.
concatenate
((
r
[
-
1
:],
r
[:
-
1
]))
graph
.
sdp_to_dag_deps
(
arc_scores
=
h
,
rel_scores
=
r
,
tree_tokens
=
tree_tokens
,
root_idx
=
self
.
vocab
.
get_token_index
(
"root"
,
"deprel_labels"
),
vocab_index
=
self
.
vocab
.
get_index_to_token_vocabulary
(
"deprel_labels"
))
graph
.
graph_and_tree_merge
(
tree_arc_scores
=
predictions
[
"head"
],
tree_rel_scores
=
predictions
[
"deprel"
],
graph_arc_scores
=
h
,
graph_rel_scores
=
r
,
idx2label
=
self
.
vocab
.
get_index_to_token_vocabulary
(
"deprel_labels"
),
label2idx
=
self
.
vocab
.
get_token_to_index_vocabulary
(
"deprel_labels"
),
tokens
=
tree_tokens
)
# graph.sdp_to_dag_deps(arc_scores=h,
# rel_scores=r,
# tree_tokens=tree_tokens,
# root_idx=self.vocab.get_token_index("root", "deprel_labels"),
# vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels"))
empty_tokens
=
graph
.
restore_collapse_edges
(
tree_tokens
)
tree
.
tokens
.
extend
(
empty_tokens
)
...
...
combo/utils/graph.py
View file @
a3774ab6
...
...
@@ -3,6 +3,82 @@ from typing import List
import
numpy
as
np
_ACL_REL_CL
=
"acl:relcl"
def
graph_and_tree_merge
(
tree_arc_scores
,
tree_rel_scores
,
graph_arc_scores
,
graph_rel_scores
,
label2idx
,
idx2label
,
tokens
):
graph_arc_scores
=
np
.
copy
(
graph_arc_scores
)
# Exclude self-loops, in-place operation.
np
.
fill_diagonal
(
graph_arc_scores
,
0
)
# Connection to root will be handled by tree.
graph_arc_scores
[:,
0
]
=
False
# The same with labels.
root_idx
=
label2idx
[
"root"
]
graph_rel_scores
[:,
:,
root_idx
]
=
-
float
(
'inf'
)
graph_rel_pred
=
graph_rel_scores
.
argmax
(
-
1
)
# Add tree edges to graph
tree_heads
=
[
0
]
+
tree_arc_scores
graph
=
[[]
for
_
in
range
(
len
(
tree_heads
))]
labeled_graph
=
[[]
for
_
in
range
(
len
(
tree_heads
))]
for
d
,
h
in
enumerate
(
tree_heads
):
if
not
d
:
continue
label
=
idx2label
[
tree_rel_scores
[
d
-
1
]]
if
label
!=
_ACL_REL_CL
:
graph
[
h
].
append
(
d
)
labeled_graph
[
h
].
append
((
d
,
label
))
# Debug only
# Extract graph edges
graph_edges
=
np
.
argwhere
(
graph_arc_scores
)
# Add graph edges which aren't creating a cycle
for
(
d
,
h
)
in
graph_edges
:
if
not
d
or
not
h
or
d
in
graph
[
h
]:
continue
try
:
path
=
next
(
_dfs
(
graph
,
d
,
h
))
except
StopIteration
:
# There is not path from d to h
label
=
idx2label
[
graph_rel_pred
[
d
][
h
]]
if
label
!=
_ACL_REL_CL
:
graph
[
h
].
append
(
d
)
labeled_graph
[
h
].
append
((
d
,
label
))
# Add 'acl:relcl' without checking for cycles.
for
d
,
h
in
enumerate
(
tree_heads
):
if
not
d
:
continue
label
=
idx2label
[
tree_rel_scores
[
d
-
1
]]
if
label
==
_ACL_REL_CL
:
graph
[
h
].
append
(
d
)
labeled_graph
[
h
].
append
((
d
,
label
))
assert
len
(
labeled_graph
[
0
])
==
1
d
=
graph
[
0
][
0
]
graph
[
d
].
append
(
0
)
labeled_graph
[
d
].
append
((
0
,
"root"
))
parse_graph
=
[[]
for
_
in
range
(
len
(
tree_heads
))]
for
h
in
range
(
len
(
tree_heads
)):
for
d
,
label
in
labeled_graph
[
h
]:
parse_graph
[
d
].
append
((
h
,
label
))
parse_graph
[
d
]
=
sorted
(
parse_graph
[
d
])
for
i
,
g
in
enumerate
(
parse_graph
):
heads
=
[
x
[
0
]
for
x
in
g
]
rels
=
[
x
[
1
]
for
x
in
g
]
deps
=
'|'
.
join
(
f
'
{
h
}
:
{
r
}
'
for
h
,
r
in
zip
(
heads
,
rels
))
tokens
[
i
-
1
][
"deps"
]
=
deps
return
def
sdp_to_dag_deps
(
arc_scores
,
rel_scores
,
tree_tokens
:
List
,
root_idx
=
0
,
vocab_index
=
None
)
->
None
:
# adding ROOT
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment