Skip to content
Snippets Groups Projects
Commit 5b9d1cbc authored by Pawel Orlowicz's avatar Pawel Orlowicz
Browse files

relation_eval refactored

parent 924f54e5
Branches
No related merge requests found
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright (C) 2012 Paweł Orłowicz.
# This program is free software; you can redistribute and/or modify it
......@@ -31,7 +32,7 @@ from optparse import OptionParser
import sys
import corpus2
class RelStats :
class RelStats:
def __init__(self):
self.both_hits = 0
self.head_hits = 0
......@@ -39,7 +40,7 @@ class RelStats :
self.any_hits = 0
#helper method to get annotation vector from annotated sentence
def get_channel_annotations(self, ann_sent, dir_point) :
def get_channel_annotations(self, ann_sent, dir_point):
chann_name = dir_point.channel_name()
annotation_number = dir_point.annotation_number() - 1
channel = ann_sent.get_channel(chann_name)
......@@ -47,7 +48,7 @@ class RelStats :
return ann_vec[annotation_number]
#helper method to get list of tokens' indices
def get_indices(self, annotated_sentence, direction_point) :
def get_indices(self, annotated_sentence, direction_point):
ann_chann = self.get_channel_annotations(annotated_sentence, direction_point)
indices = ann_chann.indices
#loop to unwrap Integer objects from ann_chann.indices
......@@ -58,34 +59,38 @@ class RelStats :
return inds
#helper to get index of the chunk's head
def get_head_index(self, annotated_sentence, direction_point) :
def get_head_index(self, annotated_sentence, direction_point):
ann_chann = self.get_channel_annotations(annotated_sentence, direction_point)
head_index = ann_chann.head_index
return head_index
#returns values of hits from one direction point of relation
def verify_relation(self, ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target) :
def verify_relation(self, ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target):
both, head, chun = 0,0,0
#if indices from ref chunk and target chunks equals (tokens are the same) then chun hits
if self.get_indices(ref_ann_sent, dir_point_ref) == self.get_indices(target_ann_sent, dir_point_target) :
chun += 1
#if chun hits and head indices match then head hits
if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target) :
head +=1
if self.get_indices(ref_ann_sent, dir_point_ref) == self.get_indices(target_ann_sent, dir_point_target):
chun = 1
# if chun hits and head indices match then head hits
# if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target):
# head =1
#if indices are different (chunks consists of different sets of words) but heads match then head hits
elif self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target) :
head += 1
if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target):
head = 1
if chun == 1 and head == 1:
both = 1
return both,chun,head
#if there was a hit on both sides of relation (dir_from, dir_to) then update counters
def update_stats(self, both, chun, head) :
if chun == 2 :
def update_stats(self, both, chun, head):
if chun == 2:
self.chun_hits+=1
if head == 2 :
if head == 2:
self.head_hits += 1
if chun == 2 and head == 2 :
if chun == 2 and head == 2:
self.both_hits += 1
if chun == 2 or head == 2:
if both > 0 and chun+head > 2:
self.any_hits+=1
if both == 0 and chun+head > 1:
self.any_hits+=1
def print_stats(self,ref_rels_count, target_rels_count, stat_mode):
......@@ -114,35 +119,35 @@ class RelStats :
print ('Head match:\t')
print '%.2f\t%.2f\t%.2f' % (p, r, f)
def compare(rel1, rel2) :
def compare(rel1, rel2):
dp1_from = rel1.rel_from()
dp2_from = rel2.rel_from()
dp1_to = rel1.rel_to()
dp2_to = rel2.rel_to()
if cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) < 0 :
if cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) < 0:
return -1
elif cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) > 0 :
elif cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) > 0:
return 1
if cmp(dp1_from.channel_name(), dp2_from.channel_name()) < 0 :
if cmp(dp1_from.channel_name(), dp2_from.channel_name()) < 0:
return -1
elif cmp(dp1_from.channel_name(), dp2_from.channel_name()) > 0 :
elif cmp(dp1_from.channel_name(), dp2_from.channel_name()) > 0:
return 1
if cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) < 0:
return -1
elif cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) > 0 :
elif cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) > 0:
return 1
if cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) < 0 :
if cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) < 0:
return -1
elif cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) > 0 :
elif cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) > 0:
return 1
if cmp(dp1_to.channel_name(), dp2_to.channel_name()) < 0 :
if cmp(dp1_to.channel_name(), dp2_to.channel_name()) < 0:
return -1
elif cmp(dp1_to.channel_name(), dp2_to.channel_name()) > 0 :
elif cmp(dp1_to.channel_name(), dp2_to.channel_name()) > 0:
return 1
if cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) < 0:
return -1
elif cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) > 0 :
elif cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) > 0:
return 1
if rel1.rel_name() < rel2.rel_name():
......@@ -169,7 +174,6 @@ def go():
sys.exit(1)
batch_ref, batch_target, rel_name = args
rel_stats = RelStats()
corpus_type = "document"
......@@ -182,7 +186,7 @@ def go():
target_file = open(batch_target, "r")
line_ref = ref_file.readline()
line_target = target_file.readline()
while line_ref and line_target :
while line_ref and line_target:
line_ref = line_ref.strip()
ref_ccl_filename, ref_rel_filename = line_ref.split(";")
......@@ -190,7 +194,6 @@ def go():
line_target = line_target.strip()
target_ccl_filename, target_rel_filename = line_target.split(";")
ref_ccl_rdr = corpus2.CclRelReader(tagset, ref_ccl_filename, ref_rel_filename)
target_ccl_rdr = corpus2.CclRelReader(tagset, target_ccl_filename, target_rel_filename)
......@@ -205,15 +208,14 @@ def go():
ref_sents = dict([ (s.id(), corpus2.AnnotatedSentence.wrap_sentence(s)) for c in ref_doc.paragraphs() for s in c.sentences()])
target_sents = dict([ (s.id(), corpus2.AnnotatedSentence.wrap_sentence(s)) for c in target_doc.paragraphs() for s in c.sentences()])
for pattern in ref_rels :
for pattern in ref_rels:
t = filter(lambda x : (compare(x, pattern) == 0) , target_rels)
if len(t) > 0 :
if len(t) > 0:
t = t[0]
r = pattern
both, chun, head = 0,0,0
for dir_point_ref, dir_point_target in zip([r.rel_from(), r.rel_to()], [t.rel_from(), t.rel_to()]) :
for dir_point_ref, dir_point_target in zip([r.rel_from(), r.rel_to()], [t.rel_from(), t.rel_to()]):
ref_ann_sent = ref_sents[dir_point_ref.sentence_id()]
target_ann_sent = target_sents[dir_point_target.sentence_id()]
b,c,h = rel_stats.verify_relation(ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment