From 5b9d1cbc2abdd904386a3d72796b5e2cff4b7529 Mon Sep 17 00:00:00 2001 From: Pawel Orlowicz <porlowicz@gmail.com> Date: Fri, 7 Dec 2012 14:22:57 +0100 Subject: [PATCH] relation_eval refactored --- utils/relation_eval.py | 72 ++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/utils/relation_eval.py b/utils/relation_eval.py index 4067051..0daa15a 100755 --- a/utils/relation_eval.py +++ b/utils/relation_eval.py @@ -1,4 +1,5 @@ #!/usr/bin/python +# -*- coding: utf-8 -*- # Copyright (C) 2012 Paweł Orłowicz. # This program is free software; you can redistribute and/or modify it @@ -31,7 +32,7 @@ from optparse import OptionParser import sys import corpus2 -class RelStats : +class RelStats: def __init__(self): self.both_hits = 0 self.head_hits = 0 @@ -39,7 +40,7 @@ class RelStats : self.any_hits = 0 #helper method to get annotation vector from annotated sentence - def get_channel_annotations(self, ann_sent, dir_point) : + def get_channel_annotations(self, ann_sent, dir_point): chann_name = dir_point.channel_name() annotation_number = dir_point.annotation_number() - 1 channel = ann_sent.get_channel(chann_name) @@ -47,7 +48,7 @@ class RelStats : return ann_vec[annotation_number] #helper method to get list of tokens' indices - def get_indices(self, annotated_sentence, direction_point) : + def get_indices(self, annotated_sentence, direction_point): ann_chann = self.get_channel_annotations(annotated_sentence, direction_point) indices = ann_chann.indices #loop to unwrap Integer objects from ann_chann.indices @@ -58,34 +59,38 @@ class RelStats : return inds #helper to get index of the chunk's head - def get_head_index(self, annotated_sentence, direction_point) : + def get_head_index(self, annotated_sentence, direction_point): ann_chann = self.get_channel_annotations(annotated_sentence, direction_point) head_index = ann_chann.head_index return head_index #returns values of hits from one direction point of relation - def verify_relation(self, ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target) : + def verify_relation(self, ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target): both, head, chun = 0,0,0 #if indices from ref chunk and target chunks equals (tokens are the same) then chun hits - if self.get_indices(ref_ann_sent, dir_point_ref) == self.get_indices(target_ann_sent, dir_point_target) : - chun += 1 - #if chun hits and head indices match then head hits - if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target) : - head +=1 + if self.get_indices(ref_ann_sent, dir_point_ref) == self.get_indices(target_ann_sent, dir_point_target): + chun = 1 +# if chun hits and head indices match then head hits +# if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target): +# head =1 #if indices are different (chunks consists of different sets of words) but heads match then head hits - elif self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target) : - head += 1 + if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target): + head = 1 + if chun == 1 and head == 1: + both = 1 return both,chun,head #if there was a hit on both sides of relation (dir_from, dir_to) then update counters - def update_stats(self, both, chun, head) : - if chun == 2 : + def update_stats(self, both, chun, head): + if chun == 2: self.chun_hits+=1 - if head == 2 : + if head == 2: self.head_hits += 1 - if chun == 2 and head == 2 : + if chun == 2 and head == 2: self.both_hits += 1 - if chun == 2 or head == 2: + if both > 0 and chun+head > 2: + self.any_hits+=1 + if both == 0 and chun+head > 1: self.any_hits+=1 def print_stats(self,ref_rels_count, target_rels_count, stat_mode): @@ -114,35 +119,35 @@ class RelStats : print ('Head match:\t') print '%.2f\t%.2f\t%.2f' % (p, r, f) -def compare(rel1, rel2) : +def compare(rel1, rel2): dp1_from = rel1.rel_from() dp2_from = rel2.rel_from() dp1_to = rel1.rel_to() dp2_to = rel2.rel_to() - if cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) < 0 : + if cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) < 0: return -1 - elif cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) > 0 : + elif cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) > 0: return 1 - if cmp(dp1_from.channel_name(), dp2_from.channel_name()) < 0 : + if cmp(dp1_from.channel_name(), dp2_from.channel_name()) < 0: return -1 - elif cmp(dp1_from.channel_name(), dp2_from.channel_name()) > 0 : + elif cmp(dp1_from.channel_name(), dp2_from.channel_name()) > 0: return 1 if cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) < 0: return -1 - elif cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) > 0 : + elif cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) > 0: return 1 - if cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) < 0 : + if cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) < 0: return -1 - elif cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) > 0 : + elif cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) > 0: return 1 - if cmp(dp1_to.channel_name(), dp2_to.channel_name()) < 0 : + if cmp(dp1_to.channel_name(), dp2_to.channel_name()) < 0: return -1 - elif cmp(dp1_to.channel_name(), dp2_to.channel_name()) > 0 : + elif cmp(dp1_to.channel_name(), dp2_to.channel_name()) > 0: return 1 if cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) < 0: return -1 - elif cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) > 0 : + elif cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) > 0: return 1 if rel1.rel_name() < rel2.rel_name(): @@ -169,7 +174,6 @@ def go(): sys.exit(1) batch_ref, batch_target, rel_name = args - rel_stats = RelStats() corpus_type = "document" @@ -182,7 +186,7 @@ def go(): target_file = open(batch_target, "r") line_ref = ref_file.readline() line_target = target_file.readline() - while line_ref and line_target : + while line_ref and line_target: line_ref = line_ref.strip() ref_ccl_filename, ref_rel_filename = line_ref.split(";") @@ -190,7 +194,6 @@ def go(): line_target = line_target.strip() target_ccl_filename, target_rel_filename = line_target.split(";") - ref_ccl_rdr = corpus2.CclRelReader(tagset, ref_ccl_filename, ref_rel_filename) target_ccl_rdr = corpus2.CclRelReader(tagset, target_ccl_filename, target_rel_filename) @@ -205,15 +208,14 @@ def go(): ref_sents = dict([ (s.id(), corpus2.AnnotatedSentence.wrap_sentence(s)) for c in ref_doc.paragraphs() for s in c.sentences()]) target_sents = dict([ (s.id(), corpus2.AnnotatedSentence.wrap_sentence(s)) for c in target_doc.paragraphs() for s in c.sentences()]) - - for pattern in ref_rels : + for pattern in ref_rels: t = filter(lambda x : (compare(x, pattern) == 0) , target_rels) - if len(t) > 0 : + if len(t) > 0: t = t[0] r = pattern both, chun, head = 0,0,0 - for dir_point_ref, dir_point_target in zip([r.rel_from(), r.rel_to()], [t.rel_from(), t.rel_to()]) : + for dir_point_ref, dir_point_target in zip([r.rel_from(), r.rel_to()], [t.rel_from(), t.rel_to()]): ref_ann_sent = ref_sents[dir_point_ref.sentence_id()] target_ann_sent = target_sents[dir_point_target.sentence_id()] b,c,h = rel_stats.verify_relation(ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target) -- GitLab