From 5b9d1cbc2abdd904386a3d72796b5e2cff4b7529 Mon Sep 17 00:00:00 2001
From: Pawel Orlowicz <porlowicz@gmail.com>
Date: Fri, 7 Dec 2012 14:22:57 +0100
Subject: [PATCH] relation_eval refactored

---
 utils/relation_eval.py | 72 ++++++++++++++++++++++--------------------
 1 file changed, 37 insertions(+), 35 deletions(-)

diff --git a/utils/relation_eval.py b/utils/relation_eval.py
index 4067051..0daa15a 100755
--- a/utils/relation_eval.py
+++ b/utils/relation_eval.py
@@ -1,4 +1,5 @@
 #!/usr/bin/python
+# -*- coding: utf-8 -*-
 
 # Copyright (C) 2012 Paweł Orłowicz.
 # This program is free software; you can redistribute and/or modify it
@@ -31,7 +32,7 @@ from optparse import OptionParser
 import sys
 import corpus2
 
-class RelStats : 
+class RelStats:
 	def __init__(self):
 		self.both_hits = 0
 		self.head_hits = 0
@@ -39,7 +40,7 @@ class RelStats :
 		self.any_hits = 0
 
 	#helper method to get annotation vector from annotated sentence
-	def get_channel_annotations(self, ann_sent, dir_point) : 
+	def get_channel_annotations(self, ann_sent, dir_point):
 		chann_name = dir_point.channel_name()
 		annotation_number = dir_point.annotation_number() - 1
 		channel = ann_sent.get_channel(chann_name)
@@ -47,7 +48,7 @@ class RelStats :
 		return ann_vec[annotation_number]
 
 	#helper method to get list of tokens' indices
-	def get_indices(self, annotated_sentence, direction_point) :
+	def get_indices(self, annotated_sentence, direction_point):
 		ann_chann  = self.get_channel_annotations(annotated_sentence, direction_point)
 		indices = ann_chann.indices
 		#loop to unwrap Integer objects from ann_chann.indices
@@ -58,34 +59,38 @@ class RelStats :
 		return inds
 
 	#helper to get index of the chunk's head
-	def get_head_index(self, annotated_sentence, direction_point) : 
+	def get_head_index(self, annotated_sentence, direction_point):
 		ann_chann = self.get_channel_annotations(annotated_sentence, direction_point)
 		head_index = ann_chann.head_index
 		return head_index
 
 	#returns values of hits from one direction point of relation
-	def verify_relation(self, ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target) : 
+	def verify_relation(self, ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target):
 		both, head, chun = 0,0,0
 		#if indices from ref chunk and target chunks equals (tokens are the same) then chun hits
-		if self.get_indices(ref_ann_sent, dir_point_ref) == self.get_indices(target_ann_sent, dir_point_target) : 
-			chun += 1
-			#if chun hits and head indices match then head hits
-			if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target) : 
-				head +=1
+		if self.get_indices(ref_ann_sent, dir_point_ref) == self.get_indices(target_ann_sent, dir_point_target):
+			chun = 1
+#			if chun hits and head indices match then head hits
+#			if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target): 
+#				head =1
 		#if indices are different (chunks consists of different sets of words) but heads match then head hits
-		elif self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target) : 
-			head += 1
+		if self.get_head_index(ref_ann_sent, dir_point_ref) == self.get_head_index(target_ann_sent, dir_point_target):
+			head = 1
+		if chun == 1 and head == 1:
+			both = 1
 		return both,chun,head
 
 	#if there was a hit on both sides of relation (dir_from, dir_to) then update counters
-	def update_stats(self, both, chun, head) : 
-		if chun == 2 :
+	def update_stats(self, both, chun, head):
+		if chun == 2:
 			self.chun_hits+=1
-		if head == 2 :
+		if head == 2:
 			self.head_hits += 1
-		if chun == 2 and head == 2 :
+		if chun == 2 and head == 2:
 			self.both_hits += 1
-		if chun == 2 or head == 2:
+		if both > 0 and chun+head > 2:
+			self.any_hits+=1
+		if both == 0 and chun+head > 1:
 			self.any_hits+=1
 
 	def print_stats(self,ref_rels_count, target_rels_count, stat_mode):
@@ -114,35 +119,35 @@ class RelStats :
 			print ('Head match:\t')
 		print '%.2f\t%.2f\t%.2f' % (p, r, f) 
 
-def compare(rel1, rel2) : 
+def compare(rel1, rel2):
 	dp1_from = rel1.rel_from()
 	dp2_from = rel2.rel_from()
 	dp1_to = rel1.rel_to()
 	dp2_to = rel2.rel_to()
-	if cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) < 0 :
+	if cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) < 0:
 		return -1
-	elif cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) > 0 :
+	elif cmp(dp1_from.sentence_id(), dp2_from.sentence_id()) > 0:
 		return 1
-	if cmp(dp1_from.channel_name(), dp2_from.channel_name()) < 0 : 
+	if cmp(dp1_from.channel_name(), dp2_from.channel_name()) < 0:
 		return -1
-	elif cmp(dp1_from.channel_name(), dp2_from.channel_name()) > 0 : 
+	elif cmp(dp1_from.channel_name(), dp2_from.channel_name()) > 0:
 		return 1
 	if cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) < 0:
 		return -1
-	elif cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) > 0 : 
+	elif cmp(dp1_from.annotation_number(), dp2_from.annotation_number()) > 0:
 		return 1
 
-	if cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) < 0 :
+	if cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) < 0:
 		return -1
-	elif cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) > 0 :
+	elif cmp(dp1_to.sentence_id(), dp2_to.sentence_id()) > 0:
 		return 1
-	if cmp(dp1_to.channel_name(), dp2_to.channel_name()) < 0 : 
+	if cmp(dp1_to.channel_name(), dp2_to.channel_name()) < 0:
 		return -1
-	elif cmp(dp1_to.channel_name(), dp2_to.channel_name()) > 0 : 
+	elif cmp(dp1_to.channel_name(), dp2_to.channel_name()) > 0:
 		return 1
 	if cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) < 0:
 		return -1
-	elif cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) > 0 : 
+	elif cmp(dp1_to.annotation_number(), dp2_to.annotation_number()) > 0:
 		return 1
 
 	if rel1.rel_name() < rel2.rel_name():
@@ -169,7 +174,6 @@ def go():
 		sys.exit(1)
 
 	batch_ref, batch_target, rel_name = args
-
 	rel_stats = RelStats()
 
 	corpus_type = "document"
@@ -182,7 +186,7 @@ def go():
 	target_file = open(batch_target, "r")
 	line_ref = ref_file.readline()
 	line_target = target_file.readline()
-	while line_ref and line_target : 
+	while line_ref and line_target:
 
 		line_ref = line_ref.strip()
 		ref_ccl_filename, ref_rel_filename = line_ref.split(";")
@@ -190,7 +194,6 @@ def go():
 		line_target = line_target.strip()
 		target_ccl_filename, target_rel_filename = line_target.split(";")
 	
-
 		ref_ccl_rdr = corpus2.CclRelReader(tagset, ref_ccl_filename, ref_rel_filename)
 		target_ccl_rdr = corpus2.CclRelReader(tagset, target_ccl_filename, target_rel_filename)
 
@@ -205,15 +208,14 @@ def go():
 		ref_sents = dict([ (s.id(), corpus2.AnnotatedSentence.wrap_sentence(s)) for c in ref_doc.paragraphs() for s in c.sentences()])
 		target_sents = dict([ (s.id(), corpus2.AnnotatedSentence.wrap_sentence(s)) for c in target_doc.paragraphs() for s in c.sentences()])
 
-
-		for pattern in ref_rels	:
+		for pattern in ref_rels:
 			t = filter(lambda x : (compare(x, pattern) == 0) , target_rels)
-			if len(t) > 0 : 
+			if len(t) > 0:
 				t = t[0]
 				r = pattern
 			
 				both, chun, head = 0,0,0
-				for dir_point_ref, dir_point_target in zip([r.rel_from(), r.rel_to()], [t.rel_from(), t.rel_to()]) : 
+				for dir_point_ref, dir_point_target in zip([r.rel_from(), r.rel_to()], [t.rel_from(), t.rel_to()]):
 					ref_ann_sent = ref_sents[dir_point_ref.sentence_id()]
 					target_ann_sent = target_sents[dir_point_target.sentence_id()]
 					b,c,h = rel_stats.verify_relation(ref_ann_sent, dir_point_ref, target_ann_sent, dir_point_target)
-- 
GitLab