Skip to content
Snippets Groups Projects
Commit 26937cc9 authored by Adam Radziszewski's avatar Adam Radziszewski
Browse files

fix tagger eval script and add OOV stats

parent 9a030bad
Branches
No related merge requests found
......@@ -21,6 +21,7 @@ changelog = """
* higher frac precision in output
* extract measures to functions for averaging
* averaging over folds
* separate stats for unknown forms
"""
def text(tok_seq, respect_spaces):
......@@ -153,6 +154,13 @@ class TokComp:
self.ref_toks_amb_weak_punc_hit = 0
# tokens subjected to s.a. that were weakly hit thanks to punc-around
self.ref_toks_amb_weak_puncplus_hit = 0
# ref toks presumably unknown (unk tag among ref tok tags)
self.ref_toks_unk = 0
# ref toks that contribute to WC, amb + noamb
self.ref_toks_unk_weak_hit = 0
# ref toks that contribute to SC on POS level, amb + noamb
self.ref_toks_unk_pos_strong_hit = 0
self.tag_toks = 0 # all tokens in tagger output
self.tag_toks_amb = 0 # tokens in tagger output subjected to s.a.
......@@ -161,9 +169,9 @@ class TokComp:
tok_tags = set([self.tagset.tag_to_string(lex.tag()) for lex in tok.lexemes() if lex.is_disamb()])
return tok_tags == set([self.punc_tag])
def is_unknown(self, tok):
def is_unk(self, tok):
tok_tags = [self.tagset.tag_to_string(lex.tag()) for lex in tok.lexemes()]
return unk_tag in tok_tags
return self.unk_tag in tok_tags
def tagstrings_of_token(self, tok):
"""Returns a set of strings, corresponding to disamb tags
......@@ -212,17 +220,20 @@ class TokComp:
self.tag_toks += len(tag_seq)
self.ref_toks += len(ref_seq)
unk_tokens = sum(self.is_unk(ref_tok) for ref_tok in
unk_tokens = sum(self.is_unk(ref_tok) for ref_tok in ref_seq)
self.ref_toks_unk += unk_tokens
# first variant: no segmentation mess
if len(tag_seq) == 1 and len(ref_seq) == 1:
tagval, posval = self.cmp_toks(tag_seq[0], ref_seq[0])
if tagval > 0:
self.ref_toks_noamb_weak_hit += len(ref_seq)
self.ref_toks_unk_weak_hit += unk_tokens
if tagval == 2:
self.ref_toks_noamb_strong_hit += len(ref_seq)
if posval == 2:
self.ref_toks_noamb_pos_strong_hit += len(ref_seq)
self.ref_toks_unk_pos_strong_hit += unk_tokens
if self.debug: print '\t\tnormal', tagval, posval
else:
......@@ -238,6 +249,8 @@ class TokComp:
self.ref_toks_amb_weak_hit += len(ref_seq)
self.ref_toks_amb_strong_hit += len(ref_seq)
self.ref_toks_amb_pos_strong_hit += len(ref_seq)
self.ref_toks_unk_weak_hit += unk_tokens # unlikely that unk_tokens > 0
self.ref_toks_unk_pos_strong_hit += unk_tokens # as above
if self.debug: print '\t\tpunc hit, ref len', len(ref_seq)
else:
nonpunc_ref = [tok for tok in ref_seq if not self.is_punc(tok)]
......@@ -249,6 +262,7 @@ class TokComp:
if tagval > 0:
self.ref_toks_amb_weak_hit += len(ref_seq)
self.ref_toks_amb_weak_puncplus_hit += len(ref_seq)
self.ref_toks_unk_weak_hit += unk_tokens
if self.debug: print '\t\tpuncPLUS weak hit, ref len', len(ref_seq)
if tagval == 2:
self.ref_toks_amb_strong_hit += len(ref_seq)
......@@ -258,6 +272,7 @@ class TokComp:
if self.debug: print '\t\tMISS, ref len', len(ref_seq)
if posval == 2:
self.ref_toks_amb_pos_strong_hit += len(ref_seq)
self.ref_toks_unk_pos_strong_hit += unk_tokens
def weak_lower_bound(self):
"""Returns weak correctness percentage counting only hits where
......@@ -296,10 +311,17 @@ class TokComp:
"""Upper bound for SC."""
upper_strong_hits = self.ref_toks_noamb_strong_hit + self.ref_toks_amb
return 100.0 * upper_strong_hits / self.ref_toks
def unk_weak_corr(self):
return 100.0 * self.ref_toks_unk_weak_hit / self.ref_toks_unk
def unk_pos_strong_corr(self):
return 100.0 * self.ref_toks_unk_pos_strong_hit / self.ref_toks_unk
def dump(self):
print '----'
print 'REF-toks\t%d' % self.ref_toks
print 'REF-toks-unk\t%d\t%.4f%%' % (self.ref_toks_unk, 100.0 * self.ref_toks_unk / self.ref_toks)
print 'TAGGER-toks\t%d' % self.tag_toks
print 'REF-amb-toks\t%d\t%.4f%%' % (self.ref_toks_amb, 100.0 * self.ref_toks_amb / self.ref_toks)
print 'TAGGER-amb-toks\t%d\t%.4f%%' % (self.tag_toks_amb, 100.0 * self.tag_toks_amb / self.tag_toks)
......@@ -356,6 +378,9 @@ def go():
parser.add_option('-p', '--punc-tag', type='string', action='store',
dest='punc_tag', default='interp',
help='set the tag used for punctuation; default: interp')
parser.add_option('-u', '--unk-tag', type='string', action='store',
dest='unk_tag', default='ign',
help='set the tag used for unknown forms; default: ign')
parser.add_option('-k', '--keep-optional', action='store_false',
default=True, dest='expand_optional',
help='do not expand unspecified optional attributes to multiple tags')
......@@ -380,6 +405,9 @@ def go():
weak_upper_bound = 0.0
weak = 0.0
strong_pos = 0.0
unk_weak = 0.0
unk_strong_pos = 0.0
for fold_idx in range(num_folds):
tag_fn = args[fold_idx] # filename of tagged fold @ fold_idx
......@@ -389,7 +417,7 @@ def go():
tag_rdr = corpus2.TokenReader.create_path_reader(options.input_format, tagset, tag_fn)
ref_rdr = corpus2.TokenReader.create_path_reader(options.input_format, tagset, ref_fn)
res = TokComp(tagset, options.punc_tag,
res = TokComp(tagset, options.punc_tag, options.unk_tag,
options.expand_optional, options.debug_mode)
for tag_seq, ref_seq in tok_seqs(tag_rdr, ref_rdr, options.respect_spaces, options.verbose, options.debug_mode):
res.update(tag_seq, ref_seq)
......@@ -399,11 +427,15 @@ def go():
weak_upper_bound += res.weak_upper_bound()
weak += res.weak_corr()
strong_pos += res.pos_strong_corr()
unk_weak += res.unk_weak_corr()
unk_strong_pos += res.unk_pos_strong_corr()
print 'AVG weak corr lower bound\t%.4f%%' % (weak_lower_bound / num_folds)
print 'AVG weak corr upper bound\t%.4f%%' % (weak_upper_bound / num_folds)
print 'AVG weak corr (heur)\t%.4f%%' % (weak / num_folds)
print 'AVG POS strong corr\t%.4f%%' % (strong_pos / num_folds)
print 'AVG UNK weak corr (heur)\t%.4f%%' % (unk_weak / num_folds)
print 'AVG UNK POS strong corr\t%.4f%%' % (unk_strong_pos / num_folds)
if __name__ == '__main__':
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment