diff --git a/utils/tagger-eval.py b/utils/tagger-eval.py
index e6209b8228c08ee0d8d9cdeb4d2a2f5c44bdcf50..508e690cc06ccdbac0b83f4fde7cd1407dee1ee9 100755
--- a/utils/tagger-eval.py
+++ b/utils/tagger-eval.py
@@ -5,7 +5,7 @@ from optparse import OptionParser
 import sys
 import corpus2
 from StringIO import StringIO
-
+from collections import defaultdict as dd
 descr = """%prog [options] TAGDFOLD1 ... REFFOLD1 ...
 
 Evaluates tagging of tagged corpus consisting of possibly several folds using
@@ -106,7 +106,41 @@ def tok_seqs(rdr_here, rdr_there, respect_spaces, verbose_mode, debug_mode):
 
     assert buff_here and buff_there
     yield (buff_here, buff_there)
 
+class Feat:
+    WEAK_TAG_HIT = 'weak tag hit' # also includes strong hits
+    STRONG_TAG_HIT = 'strong tag hit'
+    WEAK_POS_HIT = 'weak pos hit' # also includes strong hits
+    STRONG_POS_HIT = 'strong pos hit'
+    ALLPUNC_HIT = 'weak tag allpunc hit' # heur1 only
+    PUNCAROUND_HIT = 'weak tag puncaround hit' # heur2 only
+    SEG_NOCHANGE = 'segmentation unchanged'
+    SEG_CHANGE = 'segmentation change'
+    KNOWN = 'known'
+    UNKNOWN = 'unknown'
+
+class Metric:
+    # lower bounds for correctness, treating all seg changes as failures
+    WC_LOWER = ([Feat.WEAK_TAG_HIT, Feat.SEG_NOCHANGE], None) # lower bound for weak correctness
+    SC_LOWER = ([Feat.STRONG_TAG_HIT, Feat.SEG_NOCHANGE], None) # lower bound for strong correctness
+    # weak and strong corr, disregarding punct-only seg changes
+    WC = ([Feat.WEAK_TAG_HIT], None)
+    SC = ([Feat.STRONG_TAG_HIT], None)
+    # as above but metrics for POS hits
+    POS_WC = ([Feat.WEAK_POS_HIT], None)
+    POS_SC = ([Feat.STRONG_POS_HIT], None)
+    # separate stats for known and unknown forms
+    KN_WC = ([Feat.WEAK_TAG_HIT, Feat.KNOWN], [Feat.KNOWN])
+    UNK_WC = ([Feat.WEAK_TAG_HIT, Feat.UNKNOWN], [Feat.UNKNOWN])
+    KN_SC = ([Feat.STRONG_TAG_HIT, Feat.KNOWN], [Feat.KNOWN])
+    UNK_SC = ([Feat.STRONG_TAG_HIT, Feat.UNKNOWN], [Feat.UNKNOWN])
+    KN_POS_SC = ([Feat.STRONG_POS_HIT, Feat.KNOWN], [Feat.KNOWN])
+    UNK_POS_SC = ([Feat.STRONG_POS_HIT, Feat.UNKNOWN], [Feat.UNKNOWN])
+    # hits recovered by the punctuation heuristics
+    PUNCHIT_PUNCONLY = ([Feat.ALLPUNC_HIT], None)
+    PUNCHIT_AROUND = ([Feat.PUNCAROUND_HIT], None)
+    # percentage of tokens subjected to seg change
+    SEG_CHANGE = ([Feat.SEG_CHANGE], None)
+    SEG_NOCHANGE = ([Feat.SEG_NOCHANGE], None)
 
 class TokComp:
     """Creates a tagger evaluation comparator. The comparator reads two
@@ -137,39 +171,23 @@
         self.expand_optional = expand_optional
         self.debug = debug
         self.ref_toks = 0 # all tokens in ref corpus
-        self.ref_toks_amb = 0 # tokens subjected to segmentation ambiguities
-        # tokens not subjected to s.a. that contribute to weak correctness (WC)
-        self.ref_toks_noamb_weak_hit = 0
-        # not subjected to s.a. that contribute to strong correctness (SC)
-        self.ref_toks_noamb_strong_hit = 0
-        # not subjected to s.a. that contribute to SC on POS level
-        self.ref_toks_noamb_pos_strong_hit = 0 # exactly same sets of POSes
-        # tokens subjected to s.a. that contribute to WC
-        self.ref_toks_amb_weak_hit = 0
-        # tokens subjected to s.a. that contribute to SC
-        self.ref_toks_amb_strong_hit = 0
-        # tokens subjected to s.a. that contribute to SC on POS level
-        self.ref_toks_amb_pos_strong_hit = 0
-        # tokens subjected to s.a. that were weakly hit thanks to punc-only
-        self.ref_toks_amb_weak_punc_hit = 0
-        # tokens subjected to s.a. that were weakly hit thanks to punc-around
-        self.ref_toks_amb_weak_puncplus_hit = 0
-        # ref toks presumably unknown (unk tag among ref tok tags)
-        self.ref_toks_unk = 0
-        # ref toks that contribute to WC, amb + noamb
-        self.ref_toks_unk_weak_hit = 0
-        # ref toks that contribute to SC on POS level, amb + noamb
-        self.ref_toks_unk_pos_strong_hit = 0
-
-        self.tag_toks = 0 # all tokens in tagger output
-        self.tag_toks_amb = 0 # tokens in tagger output subjected to s.a.
+        self.tag_toks = 0 # all tokens in tagger output
+        self.tag_feat = dd(int) # feat frozenset -> count
+
+    def eat_ref_toks(self, feat_set, num_toks):
+        """Classifies num_toks reference tokens as having the given set of
+        features. Also increments the ref_toks counter."""
+        self.ref_toks += num_toks
+        self.tag_feat[frozenset(feat_set)] += num_toks
 
     def is_punc(self, tok):
+        """True if the token's only disamb tags are punctuation."""
         tok_tags = set([self.tagset.tag_to_string(lex.tag()) for lex in tok.lexemes()
             if lex.is_disamb()])
         return tok_tags == set([self.punc_tag])
 
     def is_unk(self, tok):
+        """True if the token has an 'unknown word' interpretation."""
         tok_tags = [self.tagset.tag_to_string(lex.tag())
             for lex in tok.lexemes()]
         return self.unk_tag in tok_tags
@@ -192,179 +210,108 @@
         return set(self.tagset.tag_to_string(tag) for tag in tags)
 
     def cmp_toks(self, tok1, tok2):
-        """Returns a tuple: (hitlevel, poshitlevel), where hitlevel concerns
-        whole tag comparison, while poshitleve concerns POS only. Bot levels
-        are integers: 2 if both tokens have the same sets of disamb tags
-        (POS-es), 1 if they intersect, 0 otherwise."""
+        """Returns the set of features describing strong and weak hits on
+        the tag and POS levels."""
+        hit_feats = set()
         tok1_tags = self.tagstrings_of_token(tok1)
         tok2_tags = self.tagstrings_of_token(tok2)
         tok1_pos = set(t.split(':', 1)[0] for t in tok1_tags)
         tok2_pos = set(t.split(':', 1)[0] for t in tok2_tags)
-        pos_hit = (tok1_pos == tok2_pos)
-
-        taghitlevel, poshitlevel = 0, 0
-
-        if tok1_tags == tok2_tags:
-            taghitlevel = 2
-        elif tok1_tags.intersection(tok2_tags):
-            taghitlevel = 1
 
         if tok1_pos == tok2_pos:
-            poshitlevel = 2
+            hit_feats.add(Feat.STRONG_POS_HIT)
+            hit_feats.add(Feat.WEAK_POS_HIT)
         elif tok1_pos.intersection(tok2_pos):
-            poshitlevel = 1
+            hit_feats.add(Feat.WEAK_POS_HIT)
+
+        if tok1_tags == tok2_tags:
+            hit_feats.add(Feat.STRONG_TAG_HIT)
+            hit_feats.add(Feat.WEAK_TAG_HIT)
+        elif tok1_tags.intersection(tok2_tags):
+            hit_feats.add(Feat.WEAK_TAG_HIT)
 
-        return (taghitlevel, poshitlevel)
+        return hit_feats
 
     def update(self, tag_seq, ref_seq):
         self.tag_toks += len(tag_seq)
-        self.ref_toks += len(ref_seq)
+        # ref_toks is now incremented in eat_ref_toks
 
-        unk_tokens = sum(self.is_unk(ref_tok) for ref_tok in ref_seq)
-        self.ref_toks_unk += unk_tokens
+        # initialise an empty feat set for each ref token
+        pre_feat_sets = [set() for _ in ref_seq]
+        # check if there are any "unknown words"
+        for tok, feats in zip(ref_seq, pre_feat_sets):
+            feats.add(Feat.UNKNOWN if self.is_unk(tok) else Feat.KNOWN)
+
+        # now check for segmentation changes
         # first variant: no segmentation mess
         if len(tag_seq) == 1 and len(ref_seq) == 1:
-            tagval, posval = self.cmp_toks(tag_seq[0], ref_seq[0])
-            if tagval > 0:
-                self.ref_toks_noamb_weak_hit += len(ref_seq)
-                self.ref_toks_unk_weak_hit += unk_tokens
-            if tagval == 2:
-                self.ref_toks_noamb_strong_hit += len(ref_seq)
-            if posval == 2:
-                self.ref_toks_noamb_pos_strong_hit += len(ref_seq)
-                self.ref_toks_unk_pos_strong_hit += unk_tokens
-
-            if self.debug: print '\t\tnormal', tagval, posval
+            # there is only one ref token, hence one feat set;
+            # update it with the hit feats
+            pre_feat_sets[0].add(Feat.SEG_NOCHANGE)
+            pre_feat_sets[0].update(self.cmp_toks(tag_seq[0], ref_seq[0]))
         else:
-            self.ref_toks_amb += len(ref_seq)
-            self.tag_toks_amb += len(tag_seq)
+            # mark all as subjected to segmentation changes
+            for feats in pre_feat_sets: feats.add(Feat.SEG_CHANGE)
+            # check if all ref and tagged toks are punctuation
             all_punc_ref = all(self.is_punc(tok) for tok in ref_seq)
             all_punc_tag = all(self.is_punc(tok) for tok in tag_seq)
             if all_punc_ref and all_punc_tag:
+                for feats in pre_feat_sets: feats.update([Feat.ALLPUNC_HIT, Feat.WEAK_TAG_HIT, Feat.STRONG_TAG_HIT, Feat.WEAK_POS_HIT, Feat.STRONG_POS_HIT])
                 # second variant: PUNC v. PUNC gives hit
-                #print '-'.join(tokx.orth_utf8() for tokx in tag_seq)
-                #print '-'.join(tokx.orth_utf8() for tokx in ref_seq)
-                self.ref_toks_amb_weak_punc_hit += len(ref_seq)
-                self.ref_toks_amb_weak_hit += len(ref_seq)
-                self.ref_toks_amb_strong_hit += len(ref_seq)
-                self.ref_toks_amb_pos_strong_hit += len(ref_seq)
-                self.ref_toks_unk_weak_hit += unk_tokens # unlikely that unk_tokens > 0
-                self.ref_toks_unk_pos_strong_hit += unk_tokens # as above
                 if self.debug: print '\t\tpunc hit, ref len', len(ref_seq)
             else:
                 nonpunc_ref = [tok for tok in ref_seq if not self.is_punc(tok)]
                 nonpunc_tag = [tok for tok in tag_seq if not self.is_punc(tok)]
-
-                tagval, posval = 0, 0
                 if len(nonpunc_ref) == len(nonpunc_tag) == 1:
-                    tagval, posval = self.cmp_toks(nonpunc_tag[0], nonpunc_ref[0])
-                    if tagval > 0:
-                        self.ref_toks_amb_weak_hit += len(ref_seq)
-                        self.ref_toks_amb_weak_puncplus_hit += len(ref_seq)
-                        self.ref_toks_unk_weak_hit += unk_tokens
-                        if self.debug: print '\t\tpuncPLUS weak hit, ref len', len(ref_seq)
-                    if tagval == 2:
-                        self.ref_toks_amb_strong_hit += len(ref_seq)
-                        if self.debug: print '\t\tpuncPLUS strong hit, ref len', len(ref_seq)
-                    if tagval == 0:
-                        # miss
-                        if self.debug: print '\t\tMISS, ref len', len(ref_seq)
-                    if posval == 2:
-                        self.ref_toks_amb_pos_strong_hit += len(ref_seq)
-                        self.ref_toks_unk_pos_strong_hit += unk_tokens
-
-    def weak_lower_bound(self):
-        """Returns weak correctness percentage counting only hits where
-        segmentation did not change. That is, lower bound of real WC."""
-        return 100.0 * self.ref_toks_noamb_weak_hit / self.ref_toks
-
-    def strong_lower_bound(self):
-        """Returns strong correctness percentage counting only hits where
-        segmentation did not change. That is, lower bound of real SC."""
-        return 100.0 * self.ref_toks_noamb_strong_hit / self.ref_toks
+                    # third variant: each seq contains exactly one non-punc
+                    # token; if those non-punc tokens match, the hit features
+                    # are taken for the whole ref sequence
+                    hit_feats = self.cmp_toks(nonpunc_tag[0], nonpunc_ref[0])
+                    for feats in pre_feat_sets:
+                        feats.update(hit_feats)
+                    if hit_feats:
+                        for feats in pre_feat_sets:
+                            feats.add(Feat.PUNCAROUND_HIT)
+                        if self.debug: print '\t\tpuncPLUS weak hit, ref len', len(ref_seq)
+        for feats in pre_feat_sets:
+            self.eat_ref_toks(feats, 1)
+            if self.debug:
+                print ' - ', ', '.join(sorted(feats))
 
-    def weak_corr(self):
-        """Returns weak correctness, counting changes in segmentation
-        as failure unless they concern punctuation (see above for two
-        rules)."""
-        all_weak_hits = self.ref_toks_amb_weak_hit + self.ref_toks_noamb_weak_hit
-        return 100.0 * all_weak_hits / self.ref_toks
+    def count_all(self): # TODO remove
+        """Returns the number of all reference tokens."""
+        return sum(self.tag_feat.values())
 
-    def strong_corr(self):
-        """As above but SC."""
-        all_strong_hits = self.ref_toks_amb_strong_hit + self.ref_toks_noamb_strong_hit
-        return 100.0 * all_strong_hits / self.ref_toks
+    def count_with(self, feats):
+        """Returns the number of reference tokens having all the given
+        features."""
+        satisfying = [key for key in self.tag_feat if key.issuperset(set(feats))]
+        return sum(self.tag_feat[key] for key in satisfying)
 
-    def pos_strong_corr(self):
-        """POS-only SC."""
-        all_pos_strong_hits = self.ref_toks_amb_pos_strong_hit + self.ref_toks_noamb_pos_strong_hit
-        return 100.0 * all_pos_strong_hits / self.ref_toks
+    def percentage_with(self, feats, wrt_to = None):
+        """Returns the percentage of reference tokens that have all the given
+        features. By default all the reference tokens form the denominator;
+        if wrt_to is given, only tokens having those features are counted."""
+        if wrt_to:
+            return 100.0 * self.count_with(feats) / self.count_with(wrt_to)
+        else:
+            return 100.0 * self.count_with(feats) / self.ref_toks
 
-    def weak_upper_bound(self):
-        """Upper bound for weak correctness, i.e. counting every reference
-        token subjected to segmentation change as hit."""
-        upper_weak_hits = self.ref_toks_noamb_weak_hit + self.ref_toks_amb
-        return 100.0 * upper_weak_hits / self.ref_toks
-
-    def strong_upper_bound(self):
-        """Upper bound for SC."""
-        upper_strong_hits = self.ref_toks_noamb_strong_hit + self.ref_toks_amb
-        return 100.0 * upper_strong_hits / self.ref_toks
-
-    def unk_weak_corr(self):
-        return 100.0 * self.ref_toks_unk_weak_hit / self.ref_toks_unk
-
-    def unk_pos_strong_corr(self):
-        return 100.0 * self.ref_toks_unk_pos_strong_hit / self.ref_toks_unk
+    def value_of(self, metric):
+        """Calculates the value of the given metric. A metric is a tuple:
+        features a token must have to be counted as a hit, and features a
+        token must have to be counted at all (None counts all tokens).
+        The first feat set should be a superset of the second one."""
+        return self.percentage_with(metric[0], metric[1])
 
     def dump(self):
         print '----'
         print 'REF-toks\t%d' % self.ref_toks
-        print 'REF-toks-unk\t%d\t%.4f%%' % (self.ref_toks_unk, 100.0 * self.ref_toks_unk / self.ref_toks)
-        print 'TAGGER-toks\t%d' % self.tag_toks
-        print 'REF-amb-toks\t%d\t%.4f%%' % (self.ref_toks_amb, 100.0 * self.ref_toks_amb / self.ref_toks)
-        print 'TAGGER-amb-toks\t%d\t%.4f%%' % (self.tag_toks_amb, 100.0 * self.tag_toks_amb / self.tag_toks)
-        print
-        print 'REF-weak-hits-noamb (lower bound)\t%d\t%.4f%%' % \
-            (self.ref_toks_noamb_weak_hit, self.weak_lower_bound())
-        print 'REF-strong-hits-noamb (lower bound)\t%d\t%.4f%%' % \
-            (self.ref_toks_noamb_strong_hit, self.strong_lower_bound())
-        print
-        print 'REF-weak-hits-amb (heur recover)\t%d\t%.4f%%' % \
-            (self.ref_toks_amb_weak_hit, \
-            100.0 * self.ref_toks_amb_weak_hit / self.ref_toks)
-        print 'REF-weak-punc-amb-hits (heur1 recover)\t%d\t%.4f%%' % \
-            (self.ref_toks_amb_weak_punc_hit, \
-            100.0 * self.ref_toks_amb_weak_punc_hit / self.ref_toks)
-        print 'REF-weak-puncPLUS-amb-hits (heur2 recover)\t%d\t%.4f%%' % \
-            (self.ref_toks_amb_weak_puncplus_hit, \
-            100.0 * self.ref_toks_amb_weak_puncplus_hit / self.ref_toks)
-        print 'REF-strong-POS-hits-amb (heur recover)\t%d\t%.4f%%' % \
-            (self.ref_toks_amb_pos_strong_hit, \
-            100.0 * self.ref_toks_amb_pos_strong_hit / self.ref_toks)
-        print
-
-        print 'REF-weak-hits-all (heur)\t%d\t%.4f%%' % \
-            (self.ref_toks_amb_weak_hit + self.ref_toks_noamb_weak_hit, \
-            self.weak_corr())
-
-        print 'REF-strong-hits-all (heur)\t%d\t%.4f%%' % \
-            (self.ref_toks_amb_strong_hit + self.ref_toks_noamb_strong_hit, \
-            self.strong_corr())
-
-        print 'REF-POS-strong-hits-all (heur)\t%d\t%.4f%%' % \
-            (self.ref_toks_amb_pos_strong_hit + self.ref_toks_noamb_pos_strong_hit, \
-            self.pos_strong_corr())
-        print
-        # all amb as hits
-        print 'REF-weak-hits-hitamb (upper bound)\t%d\t%.4f%%' % \
-            (self.ref_toks_noamb_weak_hit + self.ref_toks_amb, \
-            self.weak_upper_bound())
-
-        print 'REF-strong-hits-hitamb (upper bound)\t%d\t%.4f%%' % \
-            (self.ref_toks_noamb_strong_hit + self.ref_toks_amb, \
-            self.strong_upper_bound())
+        for m_name in dir(Metric):
+            if not m_name.startswith('_'):
+                metric = getattr(Metric, m_name)
+                print '%s\t%.4f%%' % (m_name, self.value_of(metric))
+        # upper bound for WC: count every seg-change token as a hit
+        upbound = self.value_of(Metric.WC_LOWER) + self.value_of(Metric.SEG_CHANGE)
+        print 'WC_UPPER\t%.4f%%' % upbound
 
 def go():
     parser = OptionParser(usage=descr)
@@ -423,12 +370,12 @@
             res.update(tag_seq, ref_seq)
         if options.verbose:
             res.dump()
-        weak_lower_bound += res.weak_lower_bound()
-        weak_upper_bound += res.weak_upper_bound()
-        weak += res.weak_corr()
-        strong_pos += res.pos_strong_corr()
-        unk_weak += res.unk_weak_corr()
-        unk_strong_pos += res.unk_pos_strong_corr()
+        weak_lower_bound += res.value_of(Metric.WC_LOWER)
+        weak_upper_bound += res.value_of(Metric.WC_LOWER) + res.value_of(Metric.SEG_CHANGE)
+        weak += res.value_of(Metric.WC)
+        strong_pos += res.value_of(Metric.POS_SC)
+        unk_weak += res.value_of(Metric.UNK_WC)
+        unk_strong_pos += res.value_of(Metric.UNK_POS_SC)
 
     print 'AVG weak corr lower bound\t%.4f%%' % (weak_lower_bound / num_folds)
     print 'AVG weak corr upper bound\t%.4f%%' % (weak_upper_bound / num_folds)
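
A note on the hit semantics above: cmp_toks reduces the weak/strong distinction
to non-empty intersection versus equality of the disamb tag-string sets, and
derives the POS sets by splitting off the leading segment of each tag string.
A toy illustration (the tag strings are made up for the example, not taken from
any particular corpus):

    ref_tags = set(['subst:sg:nom:m1', 'subst:sg:acc:m1'])   # ref disamb tags
    tag_tags = set(['subst:sg:nom:m1'])                      # tagger's choice

    strong_tag_hit = (tag_tags == ref_tags)                  # False: sets differ
    weak_tag_hit = bool(tag_tags.intersection(ref_tags))     # True: they overlap

    # POS level: only the leading segment of each tag string is compared
    ref_pos = set(t.split(':', 1)[0] for t in ref_tags)      # set(['subst'])
    tag_pos = set(t.split(':', 1)[0] for t in tag_tags)      # set(['subst'])
    strong_pos_hit = (tag_pos == ref_pos)                    # True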
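
The bookkeeping itself needs nothing from corpus2: eat_ref_toks, count_with,
percentage_with and value_of only manipulate a defaultdict keyed by frozensets
of feature labels. A minimal self-contained sketch of the scheme, with an
invented token distribution (90 clean strong hits, 7 clean misses, 3 unknown
tokens weakly hit despite a segmentation change):

    from collections import defaultdict as dd

    tag_feat = dd(int)  # feat frozenset -> token count

    def eat_ref_toks(feat_set, num_toks):
        tag_feat[frozenset(feat_set)] += num_toks

    def count_with(feats):
        return sum(num for key, num in tag_feat.items()
            if key.issuperset(set(feats)))

    def percentage_with(feats, wrt_to=None):
        denom = count_with(wrt_to) if wrt_to else sum(tag_feat.values())
        return 100.0 * count_with(feats) / denom

    def value_of(metric):
        return percentage_with(metric[0], metric[1])

    eat_ref_toks(['strong tag hit', 'weak tag hit',
        'segmentation unchanged', 'known'], 90)
    eat_ref_toks(['segmentation unchanged', 'known'], 7)
    eat_ref_toks(['weak tag hit', 'segmentation change', 'unknown'], 3)

    print value_of((['weak tag hit'], None))                    # WC: 93.0
    print value_of((['strong tag hit'], None))                  # SC: 90.0
    print value_of((['weak tag hit', 'unknown'], ['unknown']))  # UNK_WC: 100.0

This also makes the WC_UPPER line in dump transparent: SEG_NOCHANGE and
SEG_CHANGE partition the reference tokens, so value_of(Metric.WC_LOWER) +
value_of(Metric.SEG_CHANGE) counts every segmentation-change token as a hit,
which is exactly the upper bound on weak correctness.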