Commit 4d8f38ad authored by Adam Radziszewski's avatar Adam Radziszewski
test script for iobber: comparing two outputs

parent 3fb07c54
#!/usr/bin/python
# -*- coding: utf-8 -*-
from optparse import OptionParser
import sys
import corpus2
from StringIO import StringIO
from collections import defaultdict as dd
descr = """%prog [options] OUTPUT1 OUTPUT2
Compares two iobber_txt outputs and counts differences.
"""
class SentRepr:
    def __init__(self, sent, tagset, expand_optional):
        self.toks = []
        self.tag_reprs = []
        self.chunk_reprs = []
        # get the disamb tags of each token
        for tok in sent.tokens():
            self.toks.append(tok)
            tags = [lex.tag() for lex in tok.lexemes() if lex.is_disamb()]
            if expand_optional:
                # create a multivalue tag wherever a value of an optional attr is unspecified
                tags = [tagset.expand_optional_attrs(tag) for tag in tags]
                # now expand multivalue tags to singular tags
                newtags = []
                for tag in tags:
                    newtags.extend(tagset.split_tag(tag))
                tags = newtags
            self.tag_reprs.append(set(tagset.tag_to_string(tag) for tag in tags))
        asent = corpus2.AnnotatedSentence.wrap_sentence(sent)
        chan_names = sorted(asent.all_channels())
        # convert each channel to IOB representation so the IOB reprs are ready to compare
        for chan_name in chan_names:
            asent.get_channel(chan_name).make_iob_from_segments()
        for idx, tag_r in enumerate(self.tag_reprs):
            iob_str = ''
            for chan_name in chan_names:
                chan = asent.get_channel(chan_name)
                iob_str += '%s-%s-%d,' % (
                    chan_name, corpus2.to_string(chan.get_iob_at(idx)),
                    chan.is_head_at(idx))
            self.chunk_reprs.append(iob_str)

    def tokens(self):
        return zip(self.toks, self.tag_reprs, self.chunk_reprs)
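# Illustration of the chunk repr built above (hypothetical channel names,
# assuming corpus2 renders IOB values as B/I/O): with channels
# ['chunk_np', 'chunk_vp'], a token that opens an NP chunk and is marked as its
# head gets the string 'chunk_np-B-1,chunk_vp-O-0,'.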
def text(tok_seq, respect_spaces, mark_boundaries = False):
    """Extracts text from a sequence of tokens. If respect_spaces, will append
    spaces between tokens where no no-space markers are present."""
    buff = StringIO()
    nonfirst = False
    for item in tok_seq:
        tok = item[0]
        if nonfirst and respect_spaces and tok.after_space():
            buff.write(' ')
        if mark_boundaries:
            buff.write('[')
        buff.write(tok.orth_utf8().decode('utf-8'))
        if mark_boundaries:
            buff.write(']')
        nonfirst = True
    return buff.getvalue()
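# For example, for two tokens with orths 'Hello' and 'world' separated by a
# space, text(seq, True) returns u'Hello world', while
# text(seq, True, mark_boundaries=True) returns u'[Hello] [world]'.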
def tokens(rdr, tagset, expand_optional):
    while True:
        sent = rdr.get_next_sentence()
        if not sent:
            break
        sent_repr = SentRepr(sent, tagset, expand_optional)
        for item in sent_repr.tokens():
            yield item
    yield None # a guard at the end
def tok_seqs(rdr_here, rdr_there, respect_spaces, verbose_mode, debug_mode, tagset, expand_optional):
    """Generates subsequent aligned token sequences from the two readers.
    Alignment is performed on the text level. Shortest token sequences with
    matching text are generated.
    """
    toks_here = tokens(rdr_here, tagset, expand_optional)
    toks_there = tokens(rdr_there, tagset, expand_optional)

    tok_here = toks_here.next()
    tok_there = toks_there.next()
    assert tok_here and tok_there, 'no input'
    buff_here = [tok_here]
    buff_there = [tok_there]

    LIMIT = 30 # maximum number of tokens buffered on either side

    num_iter = 0
    while True:
        num_iter += 1
        if num_iter % 10000 == 0: print num_iter, 'iterations...'
        if len(buff_here) > LIMIT or len(buff_there) > LIMIT:
            raise IOError('limit exceeded')
        text_here = text(buff_here, respect_spaces)
        text_there = text(buff_there, respect_spaces)
        if debug_mode:
            print '%s (%d) v. %s (%d)' % (text_here, len(text_here),
                text_there, len(text_there))
        if len(text_here) == len(text_there):
            # already aligned
            if text_here != text_there:
                # the same text length but different chars
                # won't recover, flush it as it is
                print 'WARNING: mismatch', text_here, '|', text_there
            yield (buff_here, buff_there)
            buff_here = []
            buff_there = []
            tok_here = toks_here.next()
            tok_there = toks_there.next()
            if not tok_here or not tok_there:
                assert not tok_here and not tok_there, 'one ended prematurely'
                break
            buff_here = [tok_here]
            buff_there = [tok_there]
        elif len(text_here) < len(text_there):
            # try fetching here
            tok_here = toks_here.next()
            assert tok_here, 'first ended prematurely'
            buff_here.append(tok_here)
        else: # len(text_here) > len(text_there)
            # try fetching there
            tok_there = toks_there.next()
            assert tok_there, 'second ended prematurely'
            buff_there.append(tok_there)

    if buff_here or buff_there:
        assert buff_here and buff_there
        yield (buff_here, buff_there)
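# Worked example of the alignment above (spaces ignored for brevity): if one
# reader yields tokens 'abc', 'de' and the other 'ab', 'cde', the buffers grow
# until both texts agree on 'abcde'; the pair (['abc', 'de'], ['ab', 'cde'])
# is then yielded and fresh buffers are started.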
class Feat:
    TAG_HIT = 'tag hit'
    POS_HIT = 'POS hit'
    LEMMA_HIT = 'lemma hit'
    CHUNKS_HIT = 'chunks hit'
    SEG_NOCHANGE = 'segmentation unchanged'
    SEG_CHANGE = 'segmentation change'

class Metric:
    TAG_ACCURACY = ([Feat.TAG_HIT, Feat.SEG_NOCHANGE], None) # lower bound for strong correctness
    POS_ACCURACY = ([Feat.POS_HIT, Feat.SEG_NOCHANGE], None) # lower bound for strong POS correctness
    LEMMA_ACCURACY = ([Feat.LEMMA_HIT, Feat.SEG_NOCHANGE], None) # lower bound for lemma correctness
    CHUNK_ACCURACY = ([Feat.CHUNKS_HIT, Feat.SEG_NOCHANGE], None) # lower bound for chunk-tag accuracy
    # percentage of tokens not subjected to segmentation change
    SEG_ACCURACY = ([Feat.SEG_NOCHANGE], None)
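# Each metric is a pair (hit_feats, denom_feats): a reference token counts as
# a hit if it bears all features in hit_feats, while denom_feats selects the
# denominator (None means all reference tokens). E.g. TAG_ACCURACY is the
# percentage of all reference tokens that are tag hits and were not subjected
# to segmentation changes.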
class TokComp:
    """Creates a tagger evaluation comparator. The comparator reads two
    annotated texts: tagged text (output of a tagger being evaluated) and
    reference tagging (the gold standard). Most of the figures reported concern
    whether a reference token has got some sort of coverage in the tagger
    output.
    """
    def __init__(self, tagset, expand_optional, debug = False):
        self.tagset = tagset
        self.expand_optional = expand_optional
        self.debug = debug
        self.ref_toks = 0 # all tokens in ref corpus
        self.tag_toks = 0 # all tokens in tagger output
        self.tag_feat = dd(int) # feat frozenset -> count

    def eat_ref_toks(self, feat_set, num_toks):
        """Classifies num_toks reference tokens as having the given set of
        features. Will also increment the ref_toks counter."""
        self.ref_toks += num_toks
        self.tag_feat[frozenset(feat_set)] += num_toks
    def cmp_toks(self, item1, item2):
        """Returns a set of features describing hits on the tag, POS, lemma
        and chunk level."""
        hit_feats = set()
        tok1_tags = item1[1]
        tok2_tags = item2[1]
        tok1_chu = item1[2]
        tok2_chu = item2[2]
        tok1_pos = set(t.split(':', 1)[0] for t in tok1_tags)
        tok2_pos = set(t.split(':', 1)[0] for t in tok2_tags)
        tok1_lem = set(unicode(lex.lemma()) for lex in item1[0].lexemes())
        tok2_lem = set(unicode(lex.lemma()) for lex in item2[0].lexemes())
        if tok1_pos == tok2_pos:
            hit_feats.add(Feat.POS_HIT)
        if tok1_tags == tok2_tags:
            hit_feats.add(Feat.TAG_HIT)
        if tok1_lem == tok2_lem:
            hit_feats.add(Feat.LEMMA_HIT)
        if tok1_chu == tok2_chu:
            hit_feats.add(Feat.CHUNKS_HIT)
        #print tok1_chu, tok2_chu
        return hit_feats
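    # E.g. with NKJP-style colon-separated tags, 'subst:sg:nom:m1' contributes
    # 'subst' to the POS set, so two tokens whose tags differ only in
    # attribute values (here: number, case, gender) still score a POS_HIT.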
    def update(self, tag_seq, ref_seq):
        self.tag_toks += len(tag_seq)
        # initialise an empty feat set for each ref token
        pre_feat_sets = [set() for _ in ref_seq]
        # now check for segmentation changes
        # first variant: no segmentation mess
        if len(tag_seq) == 1 and len(ref_seq) == 1:
            # there is only one token, hence one feat set;
            # update it with hit feats
            pre_feat_sets[0].add(Feat.SEG_NOCHANGE)
            pre_feat_sets[0].update(self.cmp_toks(tag_seq[0], ref_seq[0]))
        else:
            if self.debug:
                print 'SEGCHANGE\t%s\t%s' % (text(tag_seq, True, True), text(ref_seq, True, True))
            # mark all as subjected to segmentation changes
            for feats in pre_feat_sets:
                feats.add(Feat.SEG_CHANGE)
        for feats in pre_feat_sets:
            self.eat_ref_toks(feats, 1)
            if self.debug:
                print ' - ', ', '.join(sorted(feats))
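    # E.g. a 1:1 aligned pair yields one reference token marked SEG_NOCHANGE
    # plus whatever hit features cmp_toks found, while a 2:3 pair
    # (segmentation mismatch) yields three reference tokens marked SEG_CHANGE
    # only.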
    def count_all(self): # TODO remove
        """Returns the number of all reference tokens."""
        return sum(self.tag_feat.values())

    def count_with(self, feats):
        """Returns the number of reference tokens having all the given
        features."""
        satisfying = [key for key in self.tag_feat if key.issuperset(set(feats))]
        return sum(self.tag_feat[key] for key in satisfying)

    def percentage_with(self, feats, wrt_to = None):
        """Returns the percentage of reference tokens that have the given
        features. By default all the reference tokens are treated as the
        denominator. If wrt_to is given, it will be used as a reference feat
        set."""
        count = self.count_with(feats)
        if wrt_to is not None:
            denom = self.count_with(wrt_to)
        else:
            denom = self.ref_toks
        if denom == 0:
            return 0.0 # what else can we do? should be labelled N/A
        return 100.0 * count / denom

    def value_of(self, metric):
        """Calculates the value of the given metric, being a tuple of
        features for tokens counted as hit and features for tokens counted
        at all (or None if all the tokens should be counted). The first feats
        should be a subset of the second one."""
        return self.percentage_with(metric[0], metric[1])

    def dump(self):
        print '----'
        print 'Tokens in reference data\t%d' % self.ref_toks
        for m_name in dir(Metric):
            if not m_name.startswith('_'):
                metric = getattr(Metric, m_name)
                print '%s\t%.4f%%' % (m_name, self.value_of(metric))
        # calc upper bound
        #upbound = self.value_of(Metric.WC_LOWER) + self.value_of(Metric.SEG_CHANGE)
        #print 'WC_UPPER\t%.4f%%' % upbound
def go():
    parser = OptionParser(usage=descr)
    parser.add_option('-i', '--input-format', type='string', action='store',
        dest='input_format', default='ccl',
        help='set the input format; default: ccl')
    parser.add_option('-t', '--tagset', type='string', action='store',
        dest='tagset', default='nkjp',
        help='set the tagset used in input; default: nkjp')
    parser.add_option('-q', '--quiet', action='store_false',
        default=True, dest='verbose',
        help='suppress messages and the final report')
    parser.add_option('-k', '--keep-optional', action='store_false',
        default=True, dest='expand_optional',
        help='do not expand unspecified optional attributes to multiple tags')
    parser.add_option('-s', '--ignore-spaces', action='store_false',
        default=True, dest='respect_spaces',
        help='ignore spaces between tokens when comparing')
    parser.add_option('-d', '--debug', action='store_true', dest='debug_mode')
    (options, args) = parser.parse_args()

    if len(args) != 2:
        print 'You need to provide two output files'
        print 'See --help for details.'
        print
        sys.exit(1)

    tagset = corpus2.get_named_tagset(options.tagset)

    tag_fn = args[0] # filename of the first output (treated as tagger output)
    ref_fn = args[1] # filename of the second output (treated as reference)

    if options.verbose:
        print 'Comparing: %s v. %s' % (tag_fn, ref_fn)

    tag_rdr = corpus2.TokenReader.create_path_reader(options.input_format, tagset, tag_fn)
    ref_rdr = corpus2.TokenReader.create_path_reader(options.input_format, tagset, ref_fn)

    res = TokComp(tagset, options.expand_optional, options.debug_mode)
    for tag_seq, ref_seq in tok_seqs(
            tag_rdr, ref_rdr, options.respect_spaces, options.verbose,
            options.debug_mode, tagset, options.expand_optional):
        res.update(tag_seq, ref_seq)
    if options.verbose:
        res.dump()

if __name__ == '__main__':
    go()