Skip to content
Snippets Groups Projects
Commit af005af1 authored by Adam Radziszewski's avatar Adam Radziszewski
Browse files

overhaul the tagger-eval script

parent 26937cc9
Branches
No related tags found
No related merge requests found
......@@ -5,7 +5,7 @@ from optparse import OptionParser
import sys
import corpus2
from StringIO import StringIO
from collections import defaultdict as dd
descr = """%prog [options] TAGDFOLD1 ... REFFOLD1 ...
Evaluates tagging of tagged corpus consisting of possibly several folds using
......@@ -106,7 +106,41 @@ def tok_seqs(rdr_here, rdr_there, respect_spaces, verbose_mode, debug_mode):
assert buff_here and buff_there
yield (buff_here, buff_there)
# Feat: flat namespace of string labels. During comparison each reference
# token is classified with a subset of these (hit level, segmentation
# status, known/unknown); Metric combines them into reportable figures.
class Feat:
WEAK_TAG_HIT = 'weak tag hit' # also includes strong hits
STRONG_TAG_HIT = 'strong tag hit' # disamb tag sets of both tokens are equal
WEAK_POS_HIT = 'weak pos hit' # also includes strong hits
STRONG_POS_HIT = 'strong pos hit' # disamb POS sets of both tokens are equal
ALLPUNC_HIT = 'weak tag allpunc hit' # heur1 only
PUNCAROUND_HIT = 'weak tag puncaround hit' # heur2 only
SEG_NOCHANGE = 'segmentation unchanged'
SEG_CHANGE = 'segmentation change'
KNOWN = 'known' # no 'unknown word' interpretation on the ref token
UNKNOWN = 'unknown' # ref token carries the unk_tag interpretation
# Metric: each entry is a (hit_feats, base_feats) pair consumed by
# TokComp.value_of — hit_feats lists the features a token must carry to
# count as a hit, base_feats defines the denominator population
# (None means all reference tokens).
class Metric:
# lower bounds for correctness, treating all segchanges as failures
WC_LOWER = ([Feat.WEAK_TAG_HIT, Feat.SEG_NOCHANGE], None) # lower bound for weak correctness
SC_LOWER = ([Feat.STRONG_TAG_HIT, Feat.SEG_NOCHANGE], None) # lower bound for strong correctness
# weak and strong corr, disregarding punct-only seg changes
WC = ([Feat.WEAK_TAG_HIT], None)
SC = ([Feat.STRONG_TAG_HIT], None)
# as above but metric for POS hits
POS_WC = ([Feat.WEAK_POS_HIT], None)
POS_SC = ([Feat.STRONG_POS_HIT], None)
# separate stats for known and unknown forms
KN_WC = ([Feat.WEAK_TAG_HIT, Feat.KNOWN], [Feat.KNOWN])
UNK_WC = ([Feat.WEAK_TAG_HIT, Feat.UNKNOWN], [Feat.UNKNOWN])
KN_SC = ([Feat.STRONG_TAG_HIT, Feat.KNOWN], [Feat.KNOWN])
UNK_SC = ([Feat.STRONG_TAG_HIT, Feat.UNKNOWN], [Feat.UNKNOWN])
KN_POS_SC = ([Feat.STRONG_POS_HIT, Feat.KNOWN], [Feat.KNOWN])
UNK_POS_SC = ([Feat.STRONG_POS_HIT, Feat.UNKNOWN], [Feat.UNKNOWN])
# heur recover (tokens rescued by the two punctuation heuristics)
PUNCHIT_PUNCONLY = ([Feat.ALLPUNC_HIT], None)
PUNCHIT_AROUND = ([Feat.PUNCAROUND_HIT], None)
# percentage of tokens subjected to seg change
SEG_CHANGE = ([Feat.SEG_CHANGE], None)
SEG_NOCHANGE = ([Feat.SEG_NOCHANGE], None)
class TokComp:
"""Creates a tagger evaluation comparator. The comparator reads two
......@@ -137,39 +171,22 @@ class TokComp:
self.expand_optional = expand_optional
self.debug = debug
self.ref_toks = 0 # all tokens in ref corpus
self.ref_toks_amb = 0 # tokens subjected to segmentation ambiguities
# tokens not subjected to s.a. that contribute to weak correctness (WC)
self.ref_toks_noamb_weak_hit = 0
# not subjected to s.a. that contribute to strong correctness (SC)
self.ref_toks_noamb_strong_hit = 0
# not subjected to s.a. that contribute to SC on POS level
self.ref_toks_noamb_pos_strong_hit = 0 # exactly same sets of POSes
# tokens subjected to s.a. that contribute to WC
self.ref_toks_amb_weak_hit = 0
# tokens subjected to s.a. that contribute to SC
self.ref_toks_amb_strong_hit = 0
# tokens subjected to s.a. that contribute to SC on POS level
self.ref_toks_amb_pos_strong_hit = 0
# tokens subjected to s.a. that were weakly hit thanks to punc-only
self.ref_toks_amb_weak_punc_hit = 0
# tokens subjected to s.a. that were weakly hit thanks to punc-around
self.ref_toks_amb_weak_puncplus_hit = 0
# ref toks presumably unknown (unk tag among ref tok tags)
self.ref_toks_unk = 0
# ref toks that contribute to WC, amb + noamb
self.ref_toks_unk_weak_hit = 0
# ref toks that contribute to SC on POS level, amb + noamb
self.ref_toks_unk_pos_strong_hit = 0
self.tag_toks = 0 # all tokens in tagger output
self.tag_toks_amb = 0 # tokens in tagger output subjected to s.a.
self.tag_feat = dd(int) # feat frozenset -> count
def eat_ref_toks(self, feat_set, num_toks):
    """Record num_toks reference tokens as bearing the given feature set.

    Also advances the overall ref_toks counter by the same amount."""
    bucket = frozenset(feat_set)
    self.tag_feat[bucket] += num_toks
    self.ref_toks += num_toks
def is_punc(self, tok):
    """The only DISAMB tags are punctuation."""
    disamb_tags = set()
    for lex in tok.lexemes():
        if lex.is_disamb():
            disamb_tags.add(self.tagset.tag_to_string(lex.tag()))
    return disamb_tags == set([self.punc_tag])
def is_unk(self, tok):
    """There is an 'unknown word' interpretation."""
    for lex in tok.lexemes():
        if self.tagset.tag_to_string(lex.tag()) == self.unk_tag:
            return True
    return False
......@@ -192,179 +209,108 @@ class TokComp:
return set(self.tagset.tag_to_string(tag) for tag in tags)
def cmp_toks(self, tok1, tok2):
    """Return the set of Feat hit features describing how tok1's disamb
    tags relate to tok2's, on whole-tag and on POS level.

    A strong hit means the two disamb tag (resp. POS) string sets are
    equal; a weak hit only requires a non-empty intersection, so every
    strong hit is also reported as the corresponding weak hit.

    NOTE: the previous revision returned a (taghitlevel, poshitlevel)
    tuple; leftover lines of that version made the set-returning code
    unreachable, breaking callers that .update() feature sets with the
    result. The dead pre-refactor lines have been removed.
    """
    hit_feats = set()
    tok1_tags = self.tagstrings_of_token(tok1)
    tok2_tags = self.tagstrings_of_token(tok2)
    # POS is the first colon-separated segment of a tag string
    tok1_pos = set(t.split(':', 1)[0] for t in tok1_tags)
    tok2_pos = set(t.split(':', 1)[0] for t in tok2_tags)
    if tok1_pos == tok2_pos:
        hit_feats.add(Feat.STRONG_POS_HIT)
        hit_feats.add(Feat.WEAK_POS_HIT)
    elif tok1_pos.intersection(tok2_pos):
        hit_feats.add(Feat.WEAK_POS_HIT)
    if tok1_tags == tok2_tags:
        hit_feats.add(Feat.STRONG_TAG_HIT)
        hit_feats.add(Feat.WEAK_TAG_HIT)
    elif tok1_tags.intersection(tok2_tags):
        hit_feats.add(Feat.WEAK_TAG_HIT)
    return hit_feats
# NOTE(review): this method, as rendered here, interleaves lines from the
# pre-refactor (per-counter) version with the new feature-set version; the
# per-counter increments and the (tagval, posval) tuple unpacking below look
# like diff-merge remnants — cmp_toks now returns a feature set, so the
# unpacking would fail. Marked inline; confirm against the repository.
def update(self, tag_seq, ref_seq):
self.tag_toks += len(tag_seq)
# NOTE(review): eat_ref_toks() also increments ref_toks; together with the
# commented-out TODO copy below, this line would double-count ref tokens.
self.ref_toks += len(ref_seq)
#self.ref_toks += len(ref_seq) TODO
unk_tokens = sum(self.is_unk(ref_tok) for ref_tok in ref_seq)
self.ref_toks_unk += unk_tokens
# initialise empty feat set for each ref token
pre_feat_sets = [set() for _ in ref_seq]
# check if there are any "unknown words"
for tok, feats in zip(ref_seq, pre_feat_sets):
feats.add(Feat.UNKNOWN if self.is_unk(tok) else Feat.KNOWN)
# now check for segmentation changes
# first variant: no segmentation mess
if len(tag_seq) == 1 and len(ref_seq) == 1:
# NOTE(review): pre-refactor remnant — tuple-unpacks cmp_toks' result
tagval, posval = self.cmp_toks(tag_seq[0], ref_seq[0])
if tagval > 0:
self.ref_toks_noamb_weak_hit += len(ref_seq)
self.ref_toks_unk_weak_hit += unk_tokens
if tagval == 2:
self.ref_toks_noamb_strong_hit += len(ref_seq)
if posval == 2:
self.ref_toks_noamb_pos_strong_hit += len(ref_seq)
self.ref_toks_unk_pos_strong_hit += unk_tokens
if self.debug: print '\t\tnormal', tagval, posval
# there is only one token, hence one feat set
# update it with hit feats
pre_feat_sets[0].add(Feat.SEG_NOCHANGE)
pre_feat_sets[0].update(self.cmp_toks(tag_seq[0], ref_seq[0]))
else:
# NOTE(review): the two amb counters below appear to be pre-refactor
self.ref_toks_amb += len(ref_seq)
self.tag_toks_amb += len(tag_seq)
# mark all as subjected to segmentation changes
for feats in pre_feat_sets: feats.add(Feat.SEG_CHANGE)
# check if all ref and tagged toks are punctuation
all_punc_ref = all(self.is_punc(tok) for tok in ref_seq)
all_punc_tag = all(self.is_punc(tok) for tok in tag_seq)
if all_punc_ref and all_punc_tag:
for feats in pre_feat_sets: feats.update([Feat.ALLPUNC_HIT, Feat.STRONG_POS_HIT, Feat.STRONG_TAG_HIT])
# second variant: PUNC v. PUNC gives hit
#print '-'.join(tokx.orth_utf8() for tokx in tag_seq)
#print '-'.join(tokx.orth_utf8() for tokx in ref_seq)
# NOTE(review): remaining increments in this branch look pre-refactor
self.ref_toks_amb_weak_punc_hit += len(ref_seq)
self.ref_toks_amb_weak_hit += len(ref_seq)
self.ref_toks_amb_strong_hit += len(ref_seq)
self.ref_toks_amb_pos_strong_hit += len(ref_seq)
self.ref_toks_unk_weak_hit += unk_tokens # unlikely that unk_tokens > 0
self.ref_toks_unk_pos_strong_hit += unk_tokens # as above
if self.debug: print '\t\tpunc hit, ref len', len(ref_seq)
else:
nonpunc_ref = [tok for tok in ref_seq if not self.is_punc(tok)]
nonpunc_tag = [tok for tok in tag_seq if not self.is_punc(tok)]
tagval, posval = 0, 0
if len(nonpunc_ref) == len(nonpunc_tag) == 1:
# NOTE(review): pre-refactor tuple unpacking again
tagval, posval = self.cmp_toks(nonpunc_tag[0], nonpunc_ref[0])
if tagval > 0:
self.ref_toks_amb_weak_hit += len(ref_seq)
self.ref_toks_amb_weak_puncplus_hit += len(ref_seq)
self.ref_toks_unk_weak_hit += unk_tokens
# perhaps third variant: both seqs have one non-punc token
# if the non-punc tokens match, will take the hit features
# for the whole ref
hit_feats = self.cmp_toks(nonpunc_tag[0], nonpunc_ref[0])
for feats in pre_feat_sets:
feats.update(hit_feats)
if hit_feats:
for feats in pre_feat_sets:
feats.add(Feat.PUNCAROUND_HIT)
if self.debug: print '\t\tpuncPLUS weak hit, ref len', len(ref_seq)
if tagval == 2:
self.ref_toks_amb_strong_hit += len(ref_seq)
if self.debug: print '\t\tpuncPLUS strong hit, ref len', len(ref_seq)
if tagval == 0:
# miss
if self.debug: print '\t\tMISS, ref len', len(ref_seq)
if posval == 2:
self.ref_toks_amb_pos_strong_hit += len(ref_seq)
self.ref_toks_unk_pos_strong_hit += unk_tokens
def weak_lower_bound(self):
    """Weak correctness (percent) counting only hits with unchanged
    segmentation, i.e. a lower bound on the real WC."""
    noamb_hits = self.ref_toks_noamb_weak_hit
    return 100.0 * noamb_hits / self.ref_toks
def strong_lower_bound(self):
    """Strong correctness (percent) counting only hits with unchanged
    segmentation, i.e. a lower bound on the real SC."""
    noamb_hits = self.ref_toks_noamb_strong_hit
    return 100.0 * noamb_hits / self.ref_toks
def weak_corr(self):
    """Weak correctness (percent): segmentation changes count as failures
    unless recovered by the punctuation heuristics (see update)."""
    hits = self.ref_toks_noamb_weak_hit + self.ref_toks_amb_weak_hit
    return 100.0 * hits / self.ref_toks
def strong_corr(self):
    """Strong-correctness counterpart of weak_corr."""
    hits = self.ref_toks_noamb_strong_hit + self.ref_toks_amb_strong_hit
    return 100.0 * hits / self.ref_toks
def pos_strong_corr(self):
    """Strong correctness (percent) on the POS level only."""
    hits = self.ref_toks_noamb_pos_strong_hit + self.ref_toks_amb_pos_strong_hit
    return 100.0 * hits / self.ref_toks
def weak_upper_bound(self):
    """Upper bound for WC: every reference token subjected to a
    segmentation change is optimistically counted as a hit."""
    optimistic_hits = self.ref_toks_noamb_weak_hit + self.ref_toks_amb
    return 100.0 * optimistic_hits / self.ref_toks
def strong_upper_bound(self):
    """Upper bound for SC: all segmentation-changed tokens counted as hits."""
    optimistic_hits = self.ref_toks_noamb_strong_hit + self.ref_toks_amb
    return 100.0 * optimistic_hits / self.ref_toks
def unk_weak_corr(self):
    """Weak correctness (percent) over presumably unknown tokens only."""
    hits = self.ref_toks_unk_weak_hit
    return 100.0 * hits / self.ref_toks_unk
def unk_pos_strong_corr(self):
    """POS-level strong correctness (percent) over unknown tokens only."""
    hits = self.ref_toks_unk_pos_strong_hit
    return 100.0 * hits / self.ref_toks_unk
# NOTE(review): the four lines below appear to be the displaced tail of
# update() — classifying each ref token's accumulated feature set — which
# this diff-rendered view placed after the removed per-counter methods.
for feats in pre_feat_sets:
self.eat_ref_toks(feats, 1)
if self.debug:
print ' - ', ', '.join(sorted(feats))
# Sums counts over every feature-set bucket; should equal self.ref_toks.
def count_all(self): # TODO remove
"""Returns the number of all reference tokens."""
return sum(self.tag_feat.values())
def count_with(self, feats):
    """Returns the number of reference tokens whose feature set contains
    every one of the given features."""
    wanted = set(feats)
    total = 0
    for bucket, num in self.tag_feat.items():
        if bucket.issuperset(wanted):
            total += num
    return total
def percentage_with(self, feats, wrt_to = None):
    """Returns the percentage of reference tokens having the given
    features. The denominator defaults to all reference tokens; when
    wrt_to is given, it is the count of tokens having those features."""
    numerator = self.count_with(feats)
    if wrt_to:
        denominator = self.count_with(wrt_to)
    else:
        denominator = self.ref_toks
    return 100.0 * numerator / denominator
def value_of(self, metric):
    """Calculates a Metric value: metric is a (hit feats, base feats)
    pair — the features required of a hit token, and the features defining
    the denominator population (None means count all tokens). The hit
    feats should be a subset of the base feats."""
    hit_feats, base_feats = metric
    return self.percentage_with(hit_feats, base_feats)
# NOTE(review): as rendered here, dump() interleaves the removed
# per-counter report (the explicit REF-* prints) with the new
# Metric-driven loop at the bottom; the per-counter prints reference
# attributes the refactor removed. Marked inline; confirm upstream.
def dump(self):
print '----'
print 'REF-toks\t%d' % self.ref_toks
print 'REF-toks-unk\t%d\t%.4f%%' % (self.ref_toks_unk, 100.0 * self.ref_toks_unk / self.ref_toks)
print 'TAGGER-toks\t%d' % self.tag_toks
print 'REF-amb-toks\t%d\t%.4f%%' % (self.ref_toks_amb, 100.0 * self.ref_toks_amb / self.ref_toks)
print 'TAGGER-amb-toks\t%d\t%.4f%%' % (self.tag_toks_amb, 100.0 * self.tag_toks_amb / self.tag_toks)
print
print 'REF-weak-hits-noamb (lower bound)\t%d\t%.4f%%' % \
(self.ref_toks_noamb_weak_hit, self.weak_lower_bound())
print 'REF-strong-hits-noamb (lower bound)\t%d\t%.4f%%' % \
(self.ref_toks_noamb_strong_hit, self.strong_lower_bound())
print
print 'REF-weak-hits-amb (heur recover)\t%d\t%.4f%%' % \
(self.ref_toks_amb_weak_hit, \
100.0 * self.ref_toks_amb_weak_hit / self.ref_toks)
print 'REF-weak-punc-amb-hits (heur1 recover)\t%d\t%.4f%%' % \
(self.ref_toks_amb_weak_punc_hit, \
100.0 * self.ref_toks_amb_weak_punc_hit / self.ref_toks)
print 'REF-weak-puncPLUS-amb-hits (heur2 recover)\t%d\t%.4f%%' % \
(self.ref_toks_amb_weak_puncplus_hit, \
100.0 * self.ref_toks_amb_weak_puncplus_hit / self.ref_toks)
print 'REF-strong-POS-hits-amb (heur recover)\t%d\t%.4f%%' % \
(self.ref_toks_amb_pos_strong_hit, \
100.0 * self.ref_toks_amb_pos_strong_hit / self.ref_toks)
print
print 'REF-weak-hits-all (heur)\t%d\t%.4f%%' % \
(self.ref_toks_amb_weak_hit + self.ref_toks_noamb_weak_hit, \
self.weak_corr())
print 'REF-strong-hits-all (heur)\t%d\t%.4f%%' % \
(self.ref_toks_amb_strong_hit + self.ref_toks_noamb_strong_hit, \
self.strong_corr())
print 'REF-POS-strong-hits-all (heur)\t%d\t%.4f%%' % \
(self.ref_toks_amb_pos_strong_hit + self.ref_toks_noamb_pos_strong_hit, \
self.pos_strong_corr())
print
# all amb as hits
print 'REF-weak-hits-hitamb (upper bound)\t%d\t%.4f%%' % \
(self.ref_toks_noamb_weak_hit + self.ref_toks_amb, \
self.weak_upper_bound())
print 'REF-strong-hits-hitamb (upper bound)\t%d\t%.4f%%' % \
(self.ref_toks_noamb_strong_hit + self.ref_toks_amb, \
self.strong_upper_bound())
# new-style report: every public Metric, evaluated via value_of()
for m_name in dir(Metric):
if not m_name.startswith('_'):
metric = getattr(Metric, m_name)
print '%s\t%.4f%%' % (m_name, self.value_of(metric))
# calc upper bound
upbound = self.value_of(Metric.WC_LOWER) + self.value_of(Metric.SEG_CHANGE)
print 'WC_UPPER\t%.4f%%' % upbound
def go():
parser = OptionParser(usage=descr)
......@@ -423,12 +369,12 @@ def go():
res.update(tag_seq, ref_seq)
if options.verbose:
res.dump()
weak_lower_bound += res.weak_lower_bound()
weak_upper_bound += res.weak_upper_bound()
weak += res.weak_corr()
strong_pos += res.pos_strong_corr()
unk_weak += res.unk_weak_corr()
unk_strong_pos += res.unk_pos_strong_corr()
weak_lower_bound += res.value_of(Metric.WC_LOWER)
weak_upper_bound += res.value_of(Metric.WC_LOWER) + res.value_of(Metric.SEG_CHANGE)
weak += res.value_of(Metric.WC)
strong_pos += res.value_of(Metric.POS_SC)
unk_weak += res.value_of(Metric.UNK_WC)
unk_strong_pos += res.value_of(Metric.UNK_POS_SC)
print 'AVG weak corr lower bound\t%.4f%%' % (weak_lower_bound / num_folds)
print 'AVG weak corr upper bound\t%.4f%%' % (weak_upper_bound / num_folds)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment