Project: Syntactic Tools / combo

Commit 5e79a44e, authored Mar 28, 2023 by Maja Jabłońska

    Add metrics and metrics_test

Parent: 89790343
Merge request: !46 "Merge COMBO 3.0 into master"

Changes: 2 changed files, with 587 additions and 5 deletions
    combo/utils/metrics.py       +376  −5
    tests/utils/test_metrics.py  +211  −0

combo/utils/metrics.py  (+376, −5)

from typing import Optional, List, Dict, Iterable

import torch
from overrides import overrides

"""
Class Metric adapted from AllenNLP
https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/training/metrics/metric.py
"""


class Metric:
    """
    A very general abstract class representing a metric which can be
    accumulated.
    """

    supports_distributed = False

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.BoolTensor]):
        """
        # Parameters

        predictions : `torch.Tensor`, required.
            A tensor of predictions.
        gold_labels : `torch.Tensor`, required.
            A tensor corresponding to some gold label to evaluate against.
        mask : `torch.BoolTensor`, optional (default = `None`).
            A mask can be passed, in order to deal with metrics which are
            computed over potentially padded elements, such as sequence labels.
        """
        raise NotImplementedError

    def get_metric(self, reset: bool):
        """
        Compute and return the metric. Optionally also call `self.reset`.
        """
        raise NotImplementedError

    def reset(self) -> None:
        """
        Reset any accumulators or internal state.
        """
        raise NotImplementedError

    @staticmethod
    def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]:
        """
        If you actually passed gradient-tracking Tensors to a Metric, there will be
        a huge memory leak, because it will prevent garbage collection for the computation
        graph. This method ensures the tensors are detached.
        """
        # Check if it's actually a tensor in case something else was passed.
        return (x.detach() if isinstance(x, torch.Tensor) else x for x in tensors)
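
To make the interface above concrete, here is a minimal sketch of a subclass. It is hypothetical and not part of this commit: the class name `ExampleTokenAccuracy` and its logic are illustrative only. It follows the same `__call__` / `get_metric` / `reset` contract and uses `detach_tensors` to avoid holding on to the computation graph.

# Hypothetical illustration only (not part of this commit).
class ExampleTokenAccuracy(Metric):
    def __init__(self) -> None:
        self._correct = 0.0
        self._total = 0.0

    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.BoolTensor] = None):
        # Detach so accumulated counts never keep gradients alive.
        predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)
        if mask is None:
            # Assume mask has the same shape as gold_labels in this sketch.
            mask = torch.ones_like(gold_labels).bool()
        correct = predictions.eq(gold_labels) & mask
        self._correct += correct.sum().item()
        self._total += mask.sum().item()

    def get_metric(self, reset: bool) -> float:
        accuracy = self._correct / self._total if self._total > 0 else 0.0
        if reset:
            self.reset()
        return accuracy

    def reset(self) -> None:
        self._correct = 0.0
        self._total = 0.0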
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
class
LemmaAccuracy
(
Metric
):
class
LemmaAccuracy
(
Metric
):
pass
def
__init__
(
self
):
self
.
_correct_count
=
0.0
self
.
_total_count
=
0.0
self
.
correct_indices
=
torch
.
ones
([])
@overrides
def
__call__
(
self
,
predictions
:
torch
.
Tensor
,
gold_labels
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
BoolTensor
]
=
None
):
if
gold_labels
is
None
:
return
predictions
,
gold_labels
,
mask
=
self
.
detach_tensors
(
predictions
,
gold_labels
,
mask
)
# Some sanity checks.
if
gold_labels
.
size
()
!=
predictions
.
size
():
raise
ValueError
(
f
"
gold_labels must have shape == predictions.size() but
"
f
"
found tensor of shape:
{
gold_labels
.
size
()
}
"
)
if
mask
is
not
None
and
mask
.
size
()
not
in
[
predictions
.
size
()[:
-
1
],
predictions
.
size
()]:
raise
ValueError
(
f
"
mask must have shape in one of [predictions.size()[:-1], predictions.size()] but
"
f
"
found tensor of shape:
{
mask
.
size
()
}
"
)
if
mask
is
None
:
mask
=
predictions
.
new_ones
(
predictions
.
size
()[:
-
1
]).
bool
()
if
mask
.
dim
()
<
predictions
.
dim
():
mask
=
mask
.
unsqueeze
(
-
1
)
padding_mask
=
gold_labels
.
gt
(
0
)
correct
=
predictions
.
eq
(
gold_labels
)
*
padding_mask
correct
=
(
correct
.
int
().
sum
(
-
1
)
==
padding_mask
.
int
().
sum
(
-
1
))
*
mask
.
squeeze
(
-
1
)
correct
=
correct
.
float
()
self
.
correct_indices
=
correct
.
flatten
().
bool
()
self
.
_correct_count
+=
correct
.
sum
()
self
.
_total_count
+=
mask
.
sum
()
@overrides
def
get_metric
(
self
,
reset
:
bool
)
->
float
:
if
self
.
_total_count
>
0
:
accuracy
=
float
(
self
.
_correct_count
)
/
float
(
self
.
_total_count
)
else
:
accuracy
=
0.0
if
reset
:
self
.
reset
()
return
accuracy
@overrides
def
reset
(
self
)
->
None
:
self
.
_correct_count
=
0.0
self
.
_total_count
=
0.0
self
.
correct_indices
=
torch
.
ones
([])
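
A hedged usage sketch with toy tensors (not taken from this commit): `LemmaAccuracy` treats the last dimension as a sequence of character ids with 0 as padding, and a word's lemma only counts as correct when every non-padding character id matches the gold lemma.

# Illustration only; tensor values are invented.
lemma_acc = LemmaAccuracy()
predictions = torch.tensor([[[3, 5, 0], [7, 7, 0]]])   # (batch=1, words=2, chars=3)
gold_labels = torch.tensor([[[3, 5, 0], [7, 8, 0]]])   # second word differs in one character
mask = torch.tensor([[True, True]])
lemma_acc(predictions, gold_labels, mask)
print(lemma_acc.get_metric(reset=True))  # 0.5 — one of the two words is fully correct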

class SequenceBoolAccuracy(Metric):
    """
    BoolAccuracy implementation to handle sequences.
    """

    def __init__(self, prod_last_dim: bool = False):
        self._correct_count = 0.0
        self._total_count = 0.0
        self.prod_last_dim = prod_last_dim
        self.correct_indices = torch.ones([])

    @overrides
    def __call__(self,
                 predictions: torch.Tensor,
                 gold_labels: torch.Tensor,
                 mask: Optional[torch.BoolTensor] = None):
        if gold_labels is None:
            return
        predictions, gold_labels, mask = self.detach_tensors(predictions,
                                                             gold_labels,
                                                             mask)

        # Some sanity checks.
        if gold_labels.size() != predictions.size():
            raise ValueError(f"gold_labels must have shape == predictions.size() but "
                             f"found tensor of shape: {gold_labels.size()}")
        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
            raise ValueError(f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
                             f"found tensor of shape: {mask.size()}")
        if mask is None:
            mask = predictions.new_ones(predictions.size()[:-1]).bool()
        if mask.dim() < predictions.dim():
            mask = mask.unsqueeze(-1)

        correct = predictions.eq(gold_labels) * mask

        if self.prod_last_dim:
            correct = correct.prod(-1).unsqueeze(-1)

        correct = correct.float()

        self.correct_indices = correct.flatten().bool()
        self._correct_count += correct.sum()
        self._total_count += mask.sum()

    @overrides
    def get_metric(self, reset: bool) -> float:
        if self._total_count > 0:
            accuracy = float(self._correct_count) / float(self._total_count)
        else:
            accuracy = 0.0
        if reset:
            self.reset()
        return accuracy

    @overrides
    def reset(self) -> None:
        self._correct_count = 0.0
        self._total_count = 0.0
        self.correct_indices = torch.ones([])
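
A hedged usage sketch of the `prod_last_dim` option (toy tensors, not from this commit): when it is enabled, a position only counts as correct if every value along the last dimension matches, which is how FEATS accuracy is scored by `SemanticMetrics` further below.

# Illustration only; shapes mirror the FEATS test in tests/utils/test_metrics.py.
feats_acc = SequenceBoolAccuracy(prod_last_dim=True)
predictions = torch.tensor([[[1, 4], [0, 2]]])   # (batch=1, words=2, feature ids=2)
gold_labels = torch.tensor([[[1, 4], [0, 3]]])   # second word: one feature id differs
mask = torch.tensor([[True, True]])
feats_acc(predictions, gold_labels, mask)
print(feats_acc.get_metric(reset=True))  # 0.5 — only the first word matches on all feature ids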

class AttachmentScores(Metric):
    """
    Computes labeled and unlabeled attachment scores for a
    dependency parse, as well as sentence level exact match
    for both labeled and unlabeled trees. Note that the input
    to this metric is the sampled predictions, not the distribution
    itself.

    # Parameters

    ignore_classes : `List[int]`, optional (default = None)
        A list of label ids to ignore when computing metrics.
    """

    def __init__(self, ignore_classes: List[int] = None) -> None:
        self._labeled_correct = 0.0
        self._unlabeled_correct = 0.0
        self._exact_labeled_correct = 0.0
        self._exact_unlabeled_correct = 0.0
        self._total_words = 0.0
        self._total_sentences = 0.0
        self.correct_indices = torch.ones([])

        self._ignore_classes: List[int] = ignore_classes or []

    def __call__(  # type: ignore
        self,
        predicted_indices: torch.Tensor,
        predicted_labels: torch.Tensor,
        gold_indices: torch.Tensor,
        gold_labels: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
    ):
        """
        # Parameters

        predicted_indices : `torch.Tensor`, required.
            A tensor of head index predictions of shape (batch_size, timesteps).
        predicted_labels : `torch.Tensor`, required.
            A tensor of arc label predictions of shape (batch_size, timesteps).
        gold_indices : `torch.Tensor`, required.
            A tensor of the same shape as `predicted_indices`.
        gold_labels : `torch.Tensor`, required.
            A tensor of the same shape as `predicted_labels`.
        mask : `torch.BoolTensor`, optional (default = None).
            A tensor of the same shape as `predicted_indices`.
        """
        if gold_labels is None or gold_indices is None:
            return
        detached = self.detach_tensors(
            predicted_indices, predicted_labels, gold_indices, gold_labels, mask
        )
        predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached

        if mask is None:
            mask = torch.ones_like(predicted_indices).bool()

        predicted_indices = predicted_indices.long()
        predicted_labels = predicted_labels.long()
        gold_indices = gold_indices.long()
        gold_labels = gold_labels.long()

        # Multiply by a mask denoting locations of
        # gold labels which we should ignore.
        for label in self._ignore_classes:
            label_mask = gold_labels.eq(label)
            mask = mask & ~label_mask

        correct_indices = predicted_indices.eq(gold_indices).long() * mask
        unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
        if len(correct_indices.size()) > 2:
            unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
        correct_labels = predicted_labels.eq(gold_labels).long() * mask
        correct_labels_and_indices = correct_indices * correct_labels
        self.correct_indices = correct_labels_and_indices.flatten()
        labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
        if len(correct_indices.size()) > 2:
            labeled_exact_match = labeled_exact_match.prod(dim=-1)

        self._unlabeled_correct += correct_indices.sum()
        self._exact_unlabeled_correct += unlabeled_exact_match.sum()
        self._labeled_correct += correct_labels_and_indices.sum()
        self._exact_labeled_correct += labeled_exact_match.sum()
        self._total_sentences += correct_indices.size(0)
        self._total_words += correct_indices.numel() - (~mask).sum()

    def get_metric(self, reset: bool = False):
        """
        # Returns

        The accumulated metrics as a dictionary.
        """
        unlabeled_attachment_score = 0.0
        labeled_attachment_score = 0.0
        unlabeled_exact_match = 0.0
        labeled_exact_match = 0.0
        if self._total_words > 0.0:
            unlabeled_attachment_score = float(self._unlabeled_correct) / float(self._total_words)
            labeled_attachment_score = float(self._labeled_correct) / float(self._total_words)
        if self._total_sentences > 0:
            unlabeled_exact_match = float(self._exact_unlabeled_correct) / float(self._total_sentences)
            labeled_exact_match = float(self._exact_labeled_correct) / float(self._total_sentences)
        if reset:
            self.reset()
        return {
            "UAS": unlabeled_attachment_score,
            "LAS": labeled_attachment_score,
            "UEM": unlabeled_exact_match,
            "LEM": labeled_exact_match,
        }

    @overrides
    def reset(self):
        self._labeled_correct = 0.0
        self._unlabeled_correct = 0.0
        self._exact_labeled_correct = 0.0
        self._exact_unlabeled_correct = 0.0
        self._total_words = 0.0
        self._total_sentences = 0.0
        self.correct_indices = torch.ones([])
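
A hedged usage sketch with toy tensors (not from this commit): UAS counts tokens whose predicted head matches the gold head, LAS additionally requires the arc label to match, and UEM/LEM are the corresponding sentence-level exact matches.

# Illustration only; head indices and label ids are invented.
scores = AttachmentScores()
predicted_heads = torch.tensor([[0, 1, 1]])
predicted_labels = torch.tensor([[3, 4, 5]])
gold_heads = torch.tensor([[0, 1, 2]])      # last head is wrong
gold_labels = torch.tensor([[3, 4, 5]])
mask = torch.tensor([[True, True, True]])
scores(predicted_heads, predicted_labels, gold_heads, gold_labels, mask)
print(scores.get_metric(reset=True))
# {'UAS': 0.666..., 'LAS': 0.666..., 'UEM': 0.0, 'LEM': 0.0}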

class SemanticMetrics(Metric):
    """
    Groups metrics for all predictions.
    """

    def __init__(self) -> None:
        self.upos_score = SequenceBoolAccuracy()
        self.xpos_score = SequenceBoolAccuracy()
        self.semrel_score = SequenceBoolAccuracy()
        self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
        self.lemma_score = LemmaAccuracy()
        self.attachment_scores = AttachmentScores()
        # Ignore PADDING and OOV
        self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
        self.em_score = 0.0

    def __call__(  # type: ignore
        self,
        predictions: Dict[str, torch.Tensor],
        gold_labels: Dict[str, torch.Tensor],
        mask: torch.BoolTensor,
    ):
        self.upos_score(predictions["upostag"], gold_labels["upostag"], mask)
        self.xpos_score(predictions["xpostag"], gold_labels["xpostag"], mask)
        self.semrel_score(predictions["semrel"], gold_labels["semrel"], mask)
        self.feats_score(predictions["feats"], gold_labels["feats"], mask)
        self.lemma_score(predictions["lemma"], gold_labels["lemma"], mask)
        self.attachment_scores(predictions["head"],
                               predictions["deprel"],
                               gold_labels["head"],
                               gold_labels["deprel"],
                               mask)
        self.enhanced_attachment_scores(predictions["enhanced_head"],
                                        predictions["enhanced_deprel"],
                                        gold_labels["enhanced_head"],
                                        gold_labels["enhanced_deprel"],
                                        mask=None)
        enhanced_indices = (
            self.enhanced_attachment_scores.correct_indices.reshape(
                mask.size(0), mask.size(1) + 1, -1)[:, 1:, 1:].sum(-1).reshape(-1).bool()
            if len(self.enhanced_attachment_scores.correct_indices.size()) > 0
            else self.enhanced_attachment_scores.correct_indices
        )
        total = mask.sum()
        correct_indices = (self.upos_score.correct_indices *
                           self.xpos_score.correct_indices *
                           self.semrel_score.correct_indices *
                           self.feats_score.correct_indices *
                           self.lemma_score.correct_indices *
                           self.attachment_scores.correct_indices *
                           enhanced_indices) * mask.flatten()

        total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
        self.em_score = (correct_indices / total).item()

    def get_metric(self, reset: bool) -> Dict[str, float]:
        metrics_dict = {
            "UPOS_ACC": self.upos_score.get_metric(reset),
            "XPOS_ACC": self.xpos_score.get_metric(reset),
            "SEMREL_ACC": self.semrel_score.get_metric(reset),
            "LEMMA_ACC": self.lemma_score.get_metric(reset),
            "FEATS_ACC": self.feats_score.get_metric(reset),
            "EM": self.em_score,
        }
        metrics_dict.update(self.attachment_scores.get_metric(reset))
        enhanced_metrics = {f"E{k}": v
                            for k, v in self.enhanced_attachment_scores.get_metric(reset).items()}
        metrics_dict.update(enhanced_metrics)
        return metrics_dict

    def reset(self) -> None:
        self.upos_score.reset()
        self.xpos_score.reset()
        self.semrel_score.reset()
        self.lemma_score.reset()
        self.feats_score.reset()
        self.attachment_scores.reset()
        self.enhanced_attachment_scores.reset()
        self.em_score = 0.0
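
A hedged usage sketch mirroring the shape conventions of the tests below (toy tensors, not part of this commit): every target is predicted perfectly and the enhanced head/deprel entries are passed as `None`, which the attachment metric skips, so the exact-match score comes out as 1.0. The key list in the final comment follows the insertion order used in `get_metric` above.

# Illustration only; tensor values and shapes are invented.
mask = torch.tensor([[True, True]])
seq = torch.tensor([[1, 2]])
seq3d = seq.reshape(1, 2, 1)
predictions = {
    "upostag": seq, "xpostag": seq, "semrel": seq,
    "feats": seq3d, "lemma": seq3d,
    "head": seq, "deprel": seq,
    "enhanced_head": None, "enhanced_deprel": None,
}
gold_labels = {k: (v.clone() if v is not None else None) for k, v in predictions.items()}
semantic = SemanticMetrics()
semantic(predictions, gold_labels, mask)
print(semantic.em_score)                         # 1.0
print(list(semantic.get_metric(reset=True)))
# ['UPOS_ACC', 'XPOS_ACC', 'SEMREL_ACC', 'LEMMA_ACC', 'FEATS_ACC', 'EM',
#  'UAS', 'LAS', 'UEM', 'LEM', 'EUAS', 'ELAS', 'EUEM', 'ELEM']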

tests/utils/test_metrics.py  (new file, mode 100644; +211, −0)
"""
Metrics tests.
"""
import
unittest
import
torch
from
combo.utils
import
metrics
class
SemanticMetricsTest
(
unittest
.
TestCase
):
def
setUp
(
self
)
->
None
:
self
.
mask
:
torch
.
BoolTensor
=
torch
.
tensor
([
[
True
,
True
,
True
,
True
],
[
True
,
True
,
True
,
False
],
[
True
,
True
,
True
,
False
],
])
pred
=
torch
.
tensor
([
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
])
pred_seq
=
pred
.
reshape
(
3
,
4
,
1
)
gold
=
pred
.
clone
()
gold_seq
=
pred_seq
.
clone
()
self
.
upostag
,
self
.
upostag_l
=
((
"
upostag
"
,
x
)
for
x
in
[
pred
,
gold
])
self
.
xpostag
,
self
.
xpostag_l
=
((
"
xpostag
"
,
x
)
for
x
in
[
pred
,
gold
])
self
.
semrel
,
self
.
semrel_l
=
((
"
semrel
"
,
x
)
for
x
in
[
pred
,
gold
])
self
.
head
,
self
.
head_l
=
((
"
head
"
,
x
)
for
x
in
[
pred
,
gold
])
self
.
deprel
,
self
.
deprel_l
=
((
"
deprel
"
,
x
)
for
x
in
[
pred
,
gold
])
# TODO(mklimasz) Set up an example with size 3x5x5
self
.
enhanced_head
,
self
.
enhanced_head_l
=
((
"
enhanced_head
"
,
x
)
for
x
in
[
None
,
None
])
self
.
enhanced_deprel
,
self
.
enhanced_deprel_l
=
((
"
enhanced_deprel
"
,
x
)
for
x
in
[
None
,
None
])
self
.
feats
,
self
.
feats_l
=
((
"
feats
"
,
x
)
for
x
in
[
pred_seq
,
gold_seq
])
self
.
lemma
,
self
.
lemma_l
=
((
"
lemma
"
,
x
)
for
x
in
[
pred_seq
,
gold_seq
])
self
.
predictions
=
dict
(
[
self
.
upostag
,
self
.
xpostag
,
self
.
semrel
,
self
.
feats
,
self
.
lemma
,
self
.
head
,
self
.
deprel
,
self
.
enhanced_head
,
self
.
enhanced_deprel
])
self
.
gold_labels
=
dict
([
self
.
upostag_l
,
self
.
xpostag_l
,
self
.
semrel_l
,
self
.
feats_l
,
self
.
lemma_l
,
self
.
head_l
,
self
.
deprel_l
,
self
.
enhanced_head_l
,
self
.
enhanced_deprel_l
])
self
.
eps
=
1e-6
def
test_every_prediction_correct
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
1.0
,
metric
.
em_score
)
def
test_missing_predictions_for_one_target
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
self
.
predictions
[
"
upostag
"
]
=
None
self
.
gold_labels
[
"
upostag
"
]
=
None
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
1.0
,
metric
.
em_score
)
def
test_missing_predictions_for_two_targets
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
self
.
predictions
[
"
upostag
"
]
=
None
self
.
gold_labels
[
"
upostag
"
]
=
None
self
.
predictions
[
"
lemma
"
]
=
None
self
.
gold_labels
[
"
lemma
"
]
=
None
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
1.0
,
metric
.
em_score
)
def
test_one_classification_in_one_target_is_wrong
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
self
.
predictions
[
"
upostag
"
][
0
][
0
]
=
100
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertAlmostEqual
(
0.9
,
metric
.
em_score
,
delta
=
self
.
eps
)
def
test_classification_errors_and_target_without_predictions
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
self
.
predictions
[
"
feats
"
]
=
None
self
.
gold_labels
[
"
feats
"
]
=
None
self
.
predictions
[
"
upostag
"
][
0
][
0
]
=
100
self
.
predictions
[
"
upostag
"
][
2
][
0
]
=
100
# should be ignored due to masking
self
.
predictions
[
"
upostag
"
][
1
][
3
]
=
100
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertAlmostEqual
(
0.8
,
metric
.
em_score
,
delta
=
self
.
eps
)
class
SequenceBoolAccuracyTest
(
unittest
.
TestCase
):
def
setUp
(
self
)
->
None
:
self
.
mask
:
torch
.
BoolTensor
=
torch
.
tensor
([
[
True
,
True
,
True
,
True
],
[
True
,
True
,
True
,
False
],
[
True
,
True
,
True
,
False
],
])
def
test_regular_classification_accuracy
(
self
):
# given
metric
=
metrics
.
SequenceBoolAccuracy
()
predictions
=
torch
.
tensor
([
[
1
,
1
,
0
,
8
],
[
1
,
2
,
3
,
4
],
[
9
,
4
,
3
,
9
],
])
gold_labels
=
torch
.
tensor
([
[
11
,
1
,
0
,
8
],
[
14
,
2
,
3
,
14
],
[
9
,
4
,
13
,
9
],
])
# when
metric
(
predictions
,
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
metric
.
_correct_count
.
item
(),
7
)
self
.
assertEqual
(
metric
.
_total_count
.
item
(),
10
)
def
test_feats_classification_accuracy
(
self
):
# given
metric
=
metrics
.
SequenceBoolAccuracy
(
prod_last_dim
=
True
)
# batch_size, sequence_length, classes
predictions
=
torch
.
tensor
([
[[
1
,
4
],
[
0
,
2
],
[
0
,
2
],
[
0
,
3
]],
[[
1
,
4
],
[
0
,
2
],
[
0
,
2
],
[
0
,
3
]],
[[
1
,
4
],
[
0
,
2
],
[
0
,
2
],
[
0
,
3
]],
])
gold_labels
=
torch
.
tensor
([
[[
1
,
14
],
[
0
,
2
],
[
0
,
2
],
[
0
,
3
]],
[[
11
,
4
],
[
0
,
2
],
[
0
,
2
],
[
10
,
3
]],
[[
1
,
4
],
[
0
,
2
],
[
10
,
12
],
[
0
,
3
]],
])
# when
metric
(
predictions
,
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
metric
.
_correct_count
.
item
(),
7
)
self
.
assertEqual
(
metric
.
_total_count
.
item
(),
10
)
class
LemmaAccuracyTest
(
unittest
.
TestCase
):
def
setUp
(
self
)
->
None
:
self
.
mask
:
torch
.
BoolTensor
=
torch
.
tensor
([
[
True
,
True
,
True
,
True
],
[
True
,
True
,
True
,
False
],
])
def
test_prediction_has_error_in_not_padded_place
(
self
):
# given
metric
=
metrics
.
LemmaAccuracy
()
predictions
=
torch
.
tensor
([
[[
1
,
1
,
1
],
[
1
,
1
,
1
],
[
2
,
2
,
0
],
[
1
,
1
,
4
],
],
[[
1
,
1
,
0
],
[
1
,
1000
,
0
],
[
1
,
1
,
0
],
[
1
,
1
,
0
],
],
])
gold_labels
=
torch
.
tensor
([
[[
1
,
1
,
1
],
[
1
,
1
,
1
],
[
2
,
2
,
0
],
[
1
,
1
,
4
],
],
[[
1
,
1
,
0
],
[
1
,
1
,
0
],
[
1
,
1
,
0
],
[
1
,
1
,
0
],
],
])
expected_correct_count
=
6
expected_total_count
=
7
expected_correct_indices
=
torch
.
tensor
([
1
,
1
,
1
,
1
,
1
,
0
,
1
,
0
])
# when
metric
(
predictions
,
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
metric
.
_correct_count
.
item
(),
expected_correct_count
)
self
.
assertEqual
(
metric
.
_total_count
.
item
(),
expected_total_count
)
self
.
assertTrue
(
torch
.
all
(
expected_correct_indices
.
eq
(
metric
.
correct_indices
)))
def
test_prediction_wrong_prediction_in_padding_should_be_ignored
(
self
):
# given
metric
=
metrics
.
LemmaAccuracy
()
predictions
=
torch
.
tensor
([
[[
1
,
1
,
1
],
[
1
,
1
,
1
],
[
2
,
2
,
0
],
[
1
,
1
,
4
],
],
[[
1
,
1
,
1000
],
[
1
,
1
,
0
],
[
1
,
1
,
0
],
[
1
,
1
,
0
],
],
])
gold_labels
=
torch
.
tensor
([
[[
1
,
1
,
1
],
[
1
,
1
,
1
],
[
2
,
2
,
0
],
[
1
,
1
,
4
],
],
[[
1
,
1
,
0
],
[
1
,
1
,
0
],
[
1
,
1
,
0
],
[
1
,
1
,
0
],
],
])
expected_correct_count
=
7
expected_total_count
=
7
expected_correct_indices
=
torch
.
tensor
([
1
,
1
,
1
,
1
,
1
,
1
,
1
,
0
])
# when
metric
(
predictions
,
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
expected_correct_count
,
metric
.
_correct_count
.
item
())
self
.
assertEqual
(
expected_total_count
,
metric
.
_total_count
.
item
())
self
.
assertTrue
(
torch
.
all
(
expected_correct_indices
.
eq
(
metric
.
correct_indices
)))
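
For completeness, a hedged sketch of running the new suite programmatically with the standard library runner. This assumes the repository root is the working directory and `combo` is importable; the project's actual test entry point and CI configuration are not part of this commit.

import unittest

# Discover and run everything under tests/ (assumption: default unittest
# discovery conventions apply to this repository layout).
suite = unittest.defaultTestLoader.discover("tests", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)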