Commit 5e79a44e, authored 2 years ago by Maja Jabłońska

Add metrics and metrics_test

Parent: 89790343
Part of 1 merge request: !46 "Merge COMBO 3.0 into master"

Showing 2 changed files with 587 additions and 5 deletions:
- combo/utils/metrics.py (+376, -5)
- tests/utils/test_metrics.py (+211, -0)
combo/utils/metrics.py (+376, -5)
from typing import Optional, List, Dict, Iterable

import torch
from overrides import overrides

"""
Class Metric adapted from AllenNLP
https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/training/metrics/metric.py
"""


class Metric:
    """
    A very general abstract class representing a metric which can be
    accumulated.
    """

    supports_distributed = False

    def __call__(
        self, predictions: torch.Tensor, gold_labels: torch.Tensor, mask: Optional[torch.BoolTensor]
    ):
        """
        # Parameters

        predictions : `torch.Tensor`, required.
            A tensor of predictions.
        gold_labels : `torch.Tensor`, required.
            A tensor corresponding to some gold label to evaluate against.
        mask : `torch.BoolTensor`, optional (default = `None`).
            A mask can be passed, in order to deal with metrics which are
            computed over potentially padded elements, such as sequence labels.
        """
        raise NotImplementedError

    def get_metric(self, reset: bool):
        """
        Compute and return the metric. Optionally also call `self.reset`.
        """
        raise NotImplementedError

    def reset(self) -> None:
        """
        Reset any accumulators or internal state.
        """
        raise NotImplementedError

    @staticmethod
    def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]:
        """
        If you actually passed gradient-tracking Tensors to a Metric, there will be
        a huge memory leak, because it will prevent garbage collection for the computation
        graph. This method ensures the tensors are detached.
        """
        # Check if it's actually a tensor in case something else was passed.
        return (x.detach() if isinstance(x, torch.Tensor) else x for x in tensors)
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
class
LemmaAccuracy
(
Metric
):
class
LemmaAccuracy
(
Metric
):
pass
def
__init__
(
self
):
self
.
_correct_count
=
0.0
self
.
_total_count
=
0.0
self
.
correct_indices
=
torch
.
ones
([])
@overrides
def
__call__
(
self
,
predictions
:
torch
.
Tensor
,
gold_labels
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
BoolTensor
]
=
None
):
if
gold_labels
is
None
:
return
predictions
,
gold_labels
,
mask
=
self
.
detach_tensors
(
predictions
,
gold_labels
,
mask
)
# Some sanity checks.
if
gold_labels
.
size
()
!=
predictions
.
size
():
raise
ValueError
(
f
"
gold_labels must have shape == predictions.size() but
"
f
"
found tensor of shape:
{
gold_labels
.
size
()
}
"
)
if
mask
is
not
None
and
mask
.
size
()
not
in
[
predictions
.
size
()[:
-
1
],
predictions
.
size
()]:
raise
ValueError
(
f
"
mask must have shape in one of [predictions.size()[:-1], predictions.size()] but
"
f
"
found tensor of shape:
{
mask
.
size
()
}
"
)
if
mask
is
None
:
mask
=
predictions
.
new_ones
(
predictions
.
size
()[:
-
1
]).
bool
()
if
mask
.
dim
()
<
predictions
.
dim
():
mask
=
mask
.
unsqueeze
(
-
1
)
padding_mask
=
gold_labels
.
gt
(
0
)
correct
=
predictions
.
eq
(
gold_labels
)
*
padding_mask
correct
=
(
correct
.
int
().
sum
(
-
1
)
==
padding_mask
.
int
().
sum
(
-
1
))
*
mask
.
squeeze
(
-
1
)
correct
=
correct
.
float
()
self
.
correct_indices
=
correct
.
flatten
().
bool
()
self
.
_correct_count
+=
correct
.
sum
()
self
.
_total_count
+=
mask
.
sum
()
@overrides
def
get_metric
(
self
,
reset
:
bool
)
->
float
:
if
self
.
_total_count
>
0
:
accuracy
=
float
(
self
.
_correct_count
)
/
float
(
self
.
_total_count
)
else
:
accuracy
=
0.0
if
reset
:
self
.
reset
()
return
accuracy
@overrides
def
reset
(
self
)
->
None
:
self
.
_correct_count
=
0.0
self
.
_total_count
=
0.0
self
.
correct_indices
=
torch
.
ones
([])
class SequenceBoolAccuracy(Metric):
    """
    BoolAccuracy implementation to handle sequences.
    """

    def __init__(self, prod_last_dim: bool = False):
        self._correct_count = 0.0
        self._total_count = 0.0
        self.prod_last_dim = prod_last_dim
        self.correct_indices = torch.ones([])

    @overrides
    def __call__(
        self,
        predictions: torch.Tensor,
        gold_labels: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
    ):
        if gold_labels is None:
            return

        predictions, gold_labels, mask = self.detach_tensors(predictions, gold_labels, mask)

        # Some sanity checks.
        if gold_labels.size() != predictions.size():
            raise ValueError(
                f"gold_labels must have shape == predictions.size() but "
                f"found tensor of shape: {gold_labels.size()}"
            )
        if mask is not None and mask.size() not in [predictions.size()[:-1], predictions.size()]:
            raise ValueError(
                f"mask must have shape in one of [predictions.size()[:-1], predictions.size()] but "
                f"found tensor of shape: {mask.size()}"
            )
        if mask is None:
            mask = predictions.new_ones(predictions.size()[:-1]).bool()
        if mask.dim() < predictions.dim():
            mask = mask.unsqueeze(-1)

        correct = predictions.eq(gold_labels) * mask
        if self.prod_last_dim:
            correct = correct.prod(-1).unsqueeze(-1)
        correct = correct.float()

        self.correct_indices = correct.flatten().bool()
        self._correct_count += correct.sum()
        self._total_count += mask.sum()

    @overrides
    def get_metric(self, reset: bool) -> float:
        if self._total_count > 0:
            accuracy = float(self._correct_count) / float(self._total_count)
        else:
            accuracy = 0.0
        if reset:
            self.reset()
        return accuracy

    @overrides
    def reset(self) -> None:
        self._correct_count = 0.0
        self._total_count = 0.0
        self.correct_indices = torch.ones([])
class AttachmentScores(Metric):
    """
    Computes labeled and unlabeled attachment scores for a
    dependency parse, as well as sentence level exact match
    for both labeled and unlabeled trees. Note that the input
    to this metric is the sampled predictions, not the distribution
    itself.

    # Parameters

    ignore_classes : `List[int]`, optional (default = None)
        A list of label ids to ignore when computing metrics.
    """

    def __init__(self, ignore_classes: List[int] = None) -> None:
        self._labeled_correct = 0.0
        self._unlabeled_correct = 0.0
        self._exact_labeled_correct = 0.0
        self._exact_unlabeled_correct = 0.0
        self._total_words = 0.0
        self._total_sentences = 0.0
        self.correct_indices = torch.ones([])

        self._ignore_classes: List[int] = ignore_classes or []

    def __call__(  # type: ignore
        self,
        predicted_indices: torch.Tensor,
        predicted_labels: torch.Tensor,
        gold_indices: torch.Tensor,
        gold_labels: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
    ):
        """
        # Parameters

        predicted_indices : `torch.Tensor`, required.
            A tensor of head index predictions of shape (batch_size, timesteps).
        predicted_labels : `torch.Tensor`, required.
            A tensor of arc label predictions of shape (batch_size, timesteps).
        gold_indices : `torch.Tensor`, required.
            A tensor of the same shape as `predicted_indices`.
        gold_labels : `torch.Tensor`, required.
            A tensor of the same shape as `predicted_labels`.
        mask : `torch.BoolTensor`, optional (default = None).
            A tensor of the same shape as `predicted_indices`.
        """
        if gold_labels is None or gold_indices is None:
            return
        detached = self.detach_tensors(
            predicted_indices, predicted_labels, gold_indices, gold_labels, mask
        )
        predicted_indices, predicted_labels, gold_indices, gold_labels, mask = detached

        if mask is None:
            mask = torch.ones_like(predicted_indices).bool()

        predicted_indices = predicted_indices.long()
        predicted_labels = predicted_labels.long()
        gold_indices = gold_indices.long()
        gold_labels = gold_labels.long()

        # Multiply by a mask denoting locations of
        # gold labels which we should ignore.
        for label in self._ignore_classes:
            label_mask = gold_labels.eq(label)
            mask = mask & ~label_mask

        correct_indices = predicted_indices.eq(gold_indices).long() * mask
        unlabeled_exact_match = (correct_indices + ~mask).prod(dim=-1)
        if len(correct_indices.size()) > 2:
            unlabeled_exact_match = unlabeled_exact_match.prod(dim=-1)
        correct_labels = predicted_labels.eq(gold_labels).long() * mask
        correct_labels_and_indices = correct_indices * correct_labels
        self.correct_indices = correct_labels_and_indices.flatten()
        labeled_exact_match = (correct_labels_and_indices + ~mask).prod(dim=-1)
        if len(correct_indices.size()) > 2:
            labeled_exact_match = labeled_exact_match.prod(dim=-1)

        self._unlabeled_correct += correct_indices.sum()
        self._exact_unlabeled_correct += unlabeled_exact_match.sum()
        self._labeled_correct += correct_labels_and_indices.sum()
        self._exact_labeled_correct += labeled_exact_match.sum()
        self._total_sentences += correct_indices.size(0)
        self._total_words += correct_indices.numel() - (~mask).sum()

    def get_metric(self, reset: bool = False):
        """
        # Returns

        The accumulated metrics as a dictionary.
        """
        unlabeled_attachment_score = 0.0
        labeled_attachment_score = 0.0
        unlabeled_exact_match = 0.0
        labeled_exact_match = 0.0
        if self._total_words > 0.0:
            unlabeled_attachment_score = float(self._unlabeled_correct) / float(self._total_words)
            labeled_attachment_score = float(self._labeled_correct) / float(self._total_words)
        if self._total_sentences > 0:
            unlabeled_exact_match = float(self._exact_unlabeled_correct) / float(self._total_sentences)
            labeled_exact_match = float(self._exact_labeled_correct) / float(self._total_sentences)
        if reset:
            self.reset()
        return {
            "UAS": unlabeled_attachment_score,
            "LAS": labeled_attachment_score,
            "UEM": unlabeled_exact_match,
            "LEM": labeled_exact_match,
        }

    @overrides
    def reset(self):
        self._labeled_correct = 0.0
        self._unlabeled_correct = 0.0
        self._exact_labeled_correct = 0.0
        self._exact_unlabeled_correct = 0.0
        self._total_words = 0.0
        self._total_sentences = 0.0
        self.correct_indices = torch.ones([])
class SemanticMetrics(Metric):
    """
    Groups metrics for all predictions.
    """

    def __init__(self) -> None:
        self.upos_score = SequenceBoolAccuracy()
        self.xpos_score = SequenceBoolAccuracy()
        self.semrel_score = SequenceBoolAccuracy()
        self.feats_score = SequenceBoolAccuracy(prod_last_dim=True)
        self.lemma_score = LemmaAccuracy()
        self.attachment_scores = AttachmentScores()
        # Ignore PADDING and OOV
        self.enhanced_attachment_scores = AttachmentScores(ignore_classes=[0, 1])
        self.em_score = 0.0

    def __call__(  # type: ignore
        self,
        predictions: Dict[str, torch.Tensor],
        gold_labels: Dict[str, torch.Tensor],
        mask: torch.BoolTensor,
    ):
        self.upos_score(predictions["upostag"], gold_labels["upostag"], mask)
        self.xpos_score(predictions["xpostag"], gold_labels["xpostag"], mask)
        self.semrel_score(predictions["semrel"], gold_labels["semrel"], mask)
        self.feats_score(predictions["feats"], gold_labels["feats"], mask)
        self.lemma_score(predictions["lemma"], gold_labels["lemma"], mask)
        self.attachment_scores(
            predictions["head"],
            predictions["deprel"],
            gold_labels["head"],
            gold_labels["deprel"],
            mask,
        )
        self.enhanced_attachment_scores(
            predictions["enhanced_head"],
            predictions["enhanced_deprel"],
            gold_labels["enhanced_head"],
            gold_labels["enhanced_deprel"],
            mask=None,
        )
        enhanced_indices = (
            self.enhanced_attachment_scores.correct_indices.reshape(
                mask.size(0), mask.size(1) + 1, -1
            )[:, 1:, 1:].sum(-1).reshape(-1).bool()
            if len(self.enhanced_attachment_scores.correct_indices.size()) > 0
            else self.enhanced_attachment_scores.correct_indices
        )
        total = mask.sum()
        correct_indices = (
            self.upos_score.correct_indices
            * self.xpos_score.correct_indices
            * self.semrel_score.correct_indices
            * self.feats_score.correct_indices
            * self.lemma_score.correct_indices
            * self.attachment_scores.correct_indices
            * enhanced_indices
        ) * mask.flatten()

        total, correct_indices = self.detach_tensors(total, correct_indices.float().sum())
        self.em_score = (correct_indices / total).item()

    def get_metric(self, reset: bool) -> Dict[str, float]:
        metrics_dict = {
            "UPOS_ACC": self.upos_score.get_metric(reset),
            "XPOS_ACC": self.xpos_score.get_metric(reset),
            "SEMREL_ACC": self.semrel_score.get_metric(reset),
            "LEMMA_ACC": self.lemma_score.get_metric(reset),
            "FEATS_ACC": self.feats_score.get_metric(reset),
            "EM": self.em_score,
        }
        metrics_dict.update(self.attachment_scores.get_metric(reset))
        enhanced_metrics = {
            f"E{k}": v for k, v in self.enhanced_attachment_scores.get_metric(reset).items()
        }
        metrics_dict.update(enhanced_metrics)
        return metrics_dict

    def reset(self) -> None:
        self.upos_score.reset()
        self.xpos_score.reset()
        self.semrel_score.reset()
        self.lemma_score.reset()
        self.feats_score.reset()
        self.attachment_scores.reset()
        self.enhanced_attachment_scores.reset()
        self.em_score = 0.0
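All of the classes above follow the same accumulate-then-read lifecycle inherited from Metric: call the metric object on each batch to update its internal counts, then call get_metric (optionally with reset) to read the aggregated value. The sketch below illustrates that lifecycle with toy tensors and assumes the module is importable as combo.utils.metrics (the same import the tests below use); it is an editorial illustration, not part of this commit.

# Minimal usage sketch of the Metric lifecycle (illustrative only, not part of this commit).
import torch
from combo.utils import metrics

# Two sentences of four tokens; the last token of sentence 2 is padding.
mask = torch.tensor([[True, True, True, True],
                     [True, True, True, False]])
pred = torch.tensor([[1, 2, 3, 4],
                     [1, 2, 9, 0]])
gold = torch.tensor([[1, 2, 3, 4],
                     [1, 2, 3, 0]])

upos_acc = metrics.SequenceBoolAccuracy()
upos_acc(pred, gold, mask)                # accumulate counts for this batch
print(upos_acc.get_metric(reset=True))    # 6 correct / 7 masked-in tokens ~= 0.857

Calling get_metric(reset=True) is what a training loop would typically do at the end of an epoch, so the counters start fresh for the next one.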
tests/utils/test_metrics.py (new file, mode 100644; +211, -0)
"""
Metrics tests.
"""
import
unittest
import
torch
from
combo.utils
import
metrics
class
SemanticMetricsTest
(
unittest
.
TestCase
):
def
setUp
(
self
)
->
None
:
self
.
mask
:
torch
.
BoolTensor
=
torch
.
tensor
([
[
True
,
True
,
True
,
True
],
[
True
,
True
,
True
,
False
],
[
True
,
True
,
True
,
False
],
])
pred
=
torch
.
tensor
([
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
[
0
,
1
,
2
,
3
],
])
pred_seq
=
pred
.
reshape
(
3
,
4
,
1
)
gold
=
pred
.
clone
()
gold_seq
=
pred_seq
.
clone
()
self
.
upostag
,
self
.
upostag_l
=
((
"
upostag
"
,
x
)
for
x
in
[
pred
,
gold
])
self
.
xpostag
,
self
.
xpostag_l
=
((
"
xpostag
"
,
x
)
for
x
in
[
pred
,
gold
])
self
.
semrel
,
self
.
semrel_l
=
((
"
semrel
"
,
x
)
for
x
in
[
pred
,
gold
])
self
.
head
,
self
.
head_l
=
((
"
head
"
,
x
)
for
x
in
[
pred
,
gold
])
self
.
deprel
,
self
.
deprel_l
=
((
"
deprel
"
,
x
)
for
x
in
[
pred
,
gold
])
# TODO(mklimasz) Set up an example with size 3x5x5
self
.
enhanced_head
,
self
.
enhanced_head_l
=
((
"
enhanced_head
"
,
x
)
for
x
in
[
None
,
None
])
self
.
enhanced_deprel
,
self
.
enhanced_deprel_l
=
((
"
enhanced_deprel
"
,
x
)
for
x
in
[
None
,
None
])
self
.
feats
,
self
.
feats_l
=
((
"
feats
"
,
x
)
for
x
in
[
pred_seq
,
gold_seq
])
self
.
lemma
,
self
.
lemma_l
=
((
"
lemma
"
,
x
)
for
x
in
[
pred_seq
,
gold_seq
])
self
.
predictions
=
dict
(
[
self
.
upostag
,
self
.
xpostag
,
self
.
semrel
,
self
.
feats
,
self
.
lemma
,
self
.
head
,
self
.
deprel
,
self
.
enhanced_head
,
self
.
enhanced_deprel
])
self
.
gold_labels
=
dict
([
self
.
upostag_l
,
self
.
xpostag_l
,
self
.
semrel_l
,
self
.
feats_l
,
self
.
lemma_l
,
self
.
head_l
,
self
.
deprel_l
,
self
.
enhanced_head_l
,
self
.
enhanced_deprel_l
])
self
.
eps
=
1e-6
def
test_every_prediction_correct
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
1.0
,
metric
.
em_score
)
def
test_missing_predictions_for_one_target
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
self
.
predictions
[
"
upostag
"
]
=
None
self
.
gold_labels
[
"
upostag
"
]
=
None
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
1.0
,
metric
.
em_score
)
def
test_missing_predictions_for_two_targets
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
self
.
predictions
[
"
upostag
"
]
=
None
self
.
gold_labels
[
"
upostag
"
]
=
None
self
.
predictions
[
"
lemma
"
]
=
None
self
.
gold_labels
[
"
lemma
"
]
=
None
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertEqual
(
1.0
,
metric
.
em_score
)
def
test_one_classification_in_one_target_is_wrong
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
self
.
predictions
[
"
upostag
"
][
0
][
0
]
=
100
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertAlmostEqual
(
0.9
,
metric
.
em_score
,
delta
=
self
.
eps
)
def
test_classification_errors_and_target_without_predictions
(
self
):
# given
metric
=
metrics
.
SemanticMetrics
()
self
.
predictions
[
"
feats
"
]
=
None
self
.
gold_labels
[
"
feats
"
]
=
None
self
.
predictions
[
"
upostag
"
][
0
][
0
]
=
100
self
.
predictions
[
"
upostag
"
][
2
][
0
]
=
100
# should be ignored due to masking
self
.
predictions
[
"
upostag
"
][
1
][
3
]
=
100
# when
metric
(
self
.
predictions
,
self
.
gold_labels
,
self
.
mask
)
# then
self
.
assertAlmostEqual
(
0.8
,
metric
.
em_score
,
delta
=
self
.
eps
)
class SequenceBoolAccuracyTest(unittest.TestCase):

    def setUp(self) -> None:
        self.mask: torch.BoolTensor = torch.tensor([
            [True, True, True, True],
            [True, True, True, False],
            [True, True, True, False],
        ])

    def test_regular_classification_accuracy(self):
        # given
        metric = metrics.SequenceBoolAccuracy()
        predictions = torch.tensor([
            [1, 1, 0, 8],
            [1, 2, 3, 4],
            [9, 4, 3, 9],
        ])
        gold_labels = torch.tensor([
            [11, 1, 0, 8],
            [14, 2, 3, 14],
            [9, 4, 13, 9],
        ])

        # when
        metric(predictions, gold_labels, self.mask)

        # then
        self.assertEqual(metric._correct_count.item(), 7)
        self.assertEqual(metric._total_count.item(), 10)

    def test_feats_classification_accuracy(self):
        # given
        metric = metrics.SequenceBoolAccuracy(prod_last_dim=True)
        # batch_size, sequence_length, classes
        predictions = torch.tensor([
            [[1, 4], [0, 2], [0, 2], [0, 3]],
            [[1, 4], [0, 2], [0, 2], [0, 3]],
            [[1, 4], [0, 2], [0, 2], [0, 3]],
        ])
        gold_labels = torch.tensor([
            [[1, 14], [0, 2], [0, 2], [0, 3]],
            [[11, 4], [0, 2], [0, 2], [10, 3]],
            [[1, 4], [0, 2], [10, 12], [0, 3]],
        ])

        # when
        metric(predictions, gold_labels, self.mask)

        # then
        self.assertEqual(metric._correct_count.item(), 7)
        self.assertEqual(metric._total_count.item(), 10)
class LemmaAccuracyTest(unittest.TestCase):

    def setUp(self) -> None:
        self.mask: torch.BoolTensor = torch.tensor([
            [True, True, True, True],
            [True, True, True, False],
        ])

    def test_prediction_has_error_in_not_padded_place(self):
        # given
        metric = metrics.LemmaAccuracy()
        predictions = torch.tensor([
            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
            [[1, 1, 0], [1, 1000, 0], [1, 1, 0], [1, 1, 0]],
        ])
        gold_labels = torch.tensor([
            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]],
        ])
        expected_correct_count = 6
        expected_total_count = 7
        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 0, 1, 0])

        # when
        metric(predictions, gold_labels, self.mask)

        # then
        self.assertEqual(metric._correct_count.item(), expected_correct_count)
        self.assertEqual(metric._total_count.item(), expected_total_count)
        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))

    def test_prediction_wrong_prediction_in_padding_should_be_ignored(self):
        # given
        metric = metrics.LemmaAccuracy()
        predictions = torch.tensor([
            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
            [[1, 1, 1000], [1, 1, 0], [1, 1, 0], [1, 1, 0]],
        ])
        gold_labels = torch.tensor([
            [[1, 1, 1], [1, 1, 1], [2, 2, 0], [1, 1, 4]],
            [[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]],
        ])
        expected_correct_count = 7
        expected_total_count = 7
        expected_correct_indices = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0])

        # when
        metric(predictions, gold_labels, self.mask)

        # then
        self.assertEqual(expected_correct_count, metric._correct_count.item())
        self.assertEqual(expected_total_count, metric._total_count.item())
        self.assertTrue(torch.all(expected_correct_indices.eq(metric.correct_indices)))
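The tests above cover SemanticMetrics, SequenceBoolAccuracy and LemmaAccuracy, but AttachmentScores is only exercised indirectly through SemanticMetrics. As a rough guide to what its UAS/LAS/UEM/LEM outputs mean, here is a sketch with toy head/label tensors; it is an editorial illustration under the same combo.utils.metrics import assumption, not part of this commit or its tests.

# Illustrative AttachmentScores sketch (not part of this commit).
import torch
from combo.utils import metrics

scores = metrics.AttachmentScores()
mask = torch.tensor([[True, True, True]])
gold_heads = torch.tensor([[0, 1, 1]])
gold_deprels = torch.tensor([[3, 4, 5]])
pred_heads = torch.tensor([[0, 1, 2]])     # one head wrong
pred_deprels = torch.tensor([[3, 4, 5]])   # all labels right

scores(pred_heads, pred_deprels, gold_heads, gold_deprels, mask)
print(scores.get_metric(reset=True))
# UAS = 2/3 (tokens with the correct head), LAS = 2/3 (correct head and label),
# UEM = LEM = 0.0 (the single sentence is not an exact match).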