# -*- coding: utf-8 -*-

# Copyright (C) 2012 Adam Radziszewski. Part of IOBBER.
# This program is free software; you can redistribute and/or modify it
# under the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.
#
# See the LICENCE, COPYING.LESSER and COPYING files for more details

__doc__ = """The actual chunker implementation."""

# SWIG bug workaround: when multiple SWIG-generated modules were loaded,
# swig::stop_iteration exceptions surfaced unwrapped. As a remedy, the
# RTLD_GLOBAL flag is OR-ed into the interpreter's dlopen flags here, i.e.
# ahead of the corpus2 import below (the order of these statements matters).
import ctypes, sys
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)

import corpus2
# TODO: get back to default dlopen policy?

import os, codecs
import ConfigParser
from operator import itemgetter as ig

import corpio, config, classify

def get_layers(conf):
	"""Reads layer definitions from the config.S_LAYERS section of conf.
	
	Each option in that section defines one layer: the option name is the
	layer name and the option value is a comma-separated list of channel
	names belonging to that layer.
	
	Returns a list of (layer_name, list_of_channel_names) pairs, in the order
	given by conf.items.
	Raises ValueError if any channel name contains a hyphen.
	"""
	layers = [(k, v.split(',')) for (k, v) in conf.items(config.S_LAYERS)]
	for layer_name, chan_names in layers:
		for chan_name in chan_names:
			# the chunker composes class labels as '<IOB tag>-<channel name>',
			# hence a hyphen within a channel name would make them ambiguous
			if '-' in chan_name:
				raise ValueError(
					'hyphens are not allowed in channel names: %r (layer %r)'
					% (chan_name, layer_name))
	return layers

def is_input_tagged(conf):
	"""Returns the boolean value of the config.O_TAGGED option read from
	the config.S_GLOBAL section of the given config."""
	section, option = config.S_GLOBAL, config.O_TAGGED
	tagged = conf.getboolean(section, option)
	return tagged

class Stats:
	"""Counters of processed tokens and sentences, used for progress
	reporting and diagnostics."""
	def __init__(self):
		self.clear()
	
	def clear(self):
		"""Resets both counters to zero."""
		self.num_toks = self.num_sents = 0
	
	def dump(self):
		"""Writes both counters to stderr, one per line."""
		for template, value in (
				('Toks processed: %d\n', self.num_toks),
				('Sents processed: %d\n', self.num_sents)):
			sys.stderr.write(template % value)
	
	def maybe_report(self):
		"""Writes the current token count to stderr, but only when the
		number of sentences is a multiple of 100."""
		if self.num_sents % 100 != 0:
			return
		sys.stderr.write('%d toks...\n' % (self.num_toks))

class Chunker:
	"""The CRF-based chunker. The chunker may add annotations to multiple
	channels during one run, as specified in layer definitions.
	Layers are applied sequentially. A layer defines a set of channels
	that are dealt with at a time. The chunks defined in one layer are
	disjoint.
	A chunker is parametrised with an INI file, defining layers and settings
	and a WCCL file defining features to be used by the underlying classifier.
	On a new chunker object, call either load_model to make it a functional
	chunker or train_and_save to infer a chunking model from the training
	data. NOTE: after training the chunker object is still not
	ready for performance, i.e. the trained model is just saved to disk
	and requires loading.
	"""
	
	def __init__(self, config_path, data_dir, verbose = False):
		"""Creates a chunker from given config (INI) filename. The object
		holds no trained model until load_model is called.
		If config_path points to a config.ini file, there should be an
		accompanying file named config.ccl. Trained chunker model is sought in
		(or written to when training) data_dir. The model is basically a
		trained CRF classifier.
		If verbose is set, progress information is written to stderr."""
		found_config_path = corpio.get_data(config_path)
		self.conf_dir, conf_fname = os.path.split(found_config_path)
		# the model is named after the config file (extension dropped)
		self.model_name = os.path.splitext(conf_fname)[0]
		self.data_dir = corpio.get_data(data_dir)
		self.verbose = verbose
		# load the config file
		with open(found_config_path) as config_file:
			self.conf = ConfigParser.RawConfigParser()
			self.conf.readfp(config_file)
		self.tagset = corpio.get_tagset(self.conf)
		# the chunker may also operate on morphologically analysed
		# but not disambiguated input (no 'disamb' markers)
		self.is_input_tagged = is_input_tagged(self.conf)
		# layers -- sequence of layer names
		# layer_channels -- sequence of channel lists, one per layer
		self.layers, self.layer_channels = zip(*get_layers(self.conf))
		# WCCL operators describing the features used, one op list per layer
		# NOTE: dynamic lexicon generation is currently not supported
		# to make it possible, move op loading to load_model and train_and_save
		self.layer_ops = corpio.get_wccl_ops(
			self.conf, self.model_name,
			self.conf_dir, self.data_dir, self.layers)
		# uninitialised trained model
		self.layer_models = None # layer_name -> trained classifier
		self.stats = Stats()
	
	def load_model(self):
		"""Loads a trained classifier for each layer, which makes the
		chunker ready for tagging."""
		self.layer_models = {}
		for layer in self.layers:
			self.layer_models[layer] = classify.load(
				self.conf, self.model_name, self.data_dir, layer)
	
	def _require_model(self):
		"""Raises RuntimeError unless a model has been loaded."""
		if self.layer_models is None:
			raise RuntimeError(
				'chunker model not loaded; call load_model() first')
	
	@staticmethod
	def _parse_decision(decsn):
		"""Splits a classifier decision into (IOB tag, channel name).
		For the 'O' decision the result is ('O', None)."""
		if decsn == 'O':
			return 'O', None
		# split on the first hyphen only: the IOB tag comes first and the
		# remainder is the channel name, which is kept whole this way
		tag, chan_name = decsn.split('-', 1)
		return tag, chan_name
	
	def _channels_to_iob(self, asent, chans):
		"""Ensures each of the given channels exists in the annotated
		sentence and switches it to IOB2 representation."""
		for chan_name in chans:
			if not asent.has_channel(chan_name):
				asent.create_channel(chan_name)
			chan = asent.get_channel(chan_name)
			chan.make_iob_from_segments()
	
	def _feature_values(self, con, layer_idx):
		"""Applies the WCCL ops of the given layer at the current position
		of the context and returns the resulting values as a list of
		strings decoded from UTF-8."""
		return [op.base_apply(con)
			.to_compact_string(self.tagset).decode('utf-8')
			for op in self.layer_ops[layer_idx]]
	
	def _gold_label(self, asent, chans, tok_idx):
		"""Returns the class label of the given token as found in the
		given channels: B-NP, I-VP etc. or O. If more than one channel has
		a non-O tag there, a warning is written to stderr and the first
		such channel wins."""
		# get IOB2 tags as strings, find non-O IOB2 tag or mark it as O
		# TODO: rename the to_string in corpus2 and fix it here
		non_O_chan = None
		non_O_tag = 'O'
		for chan_name in chans:
			chan = asent.get_channel(chan_name)
			there_iob = corpus2.to_string(chan.get_iob_at(tok_idx))
			if there_iob != 'O':
				if non_O_chan is not None:
					sys.stderr.write(
						'WARNING: overlapping phrases in sentence %s\n' % unicode(asent.id()))
				else:
					non_O_chan = chan_name
					non_O_tag = there_iob
		if non_O_chan is None:
			return 'O'
		return '%s-%s' % (non_O_tag, non_O_chan)
	
	def _write_training_sentence(self, sent, tr_files):
		"""Generates training examples from the given sentence for every
		layer and stores them in the training file of that layer."""
		# wrap the sentence as an AnnotatedSentence
		asent = corpus2.AnnotatedSentence.wrap_sentence(sent)
		for layer_idx, layer in enumerate(self.layers):
			chans = self.layer_channels[layer_idx]
			self._channels_to_iob(asent, chans)
			# prepare WCCL context
			con = corpio.create_context(sent)
			# get file for storing training data
			tr_file = tr_files[layer]
			for tok_idx, tok in enumerate(sent.tokens()):
				con.set_position(tok_idx) # for WCCL ops
				feat_vals = self._feature_values(con, layer_idx)
				class_label = self._gold_label(asent, chans, tok_idx)
				# generate training example and store to file
				classify.write_example(tr_file, feat_vals, class_label)
			classify.write_end_of_sent(tr_file)
	
	def _count_sentence(self, sent):
		"""Updates the statistics with the given sentence and, in verbose
		mode, lets them report progress."""
		self.stats.num_sents += 1
		self.stats.num_toks += sent.tokens().size()
		if self.verbose: self.stats.maybe_report()
	
	def train_and_save(self, in_path, input_format):
		"""Trains the chunker and stores the model to files beginning with
		model_name. Training data is read from in_path in the given
		input_format. The trained model is not loaded into this object."""
		self.layer_models = None # forget any previously trained model
		if self.verbose:
			sys.stderr.write('Generating training data...\n')
		# open files for storing training examples for each layer
		tr_files = classify.open_tr_files(
			self.model_name, self.data_dir, self.layers)
		# make sure the files get closed also when reading the input or
		# generating features fails
		try:
			# set-up the reader and gather feature values for subsequent sentences
			reader = corpio.get_reader(
				in_path, self.tagset, input_format, self.is_input_tagged)
			self.stats.clear()
			
			while True:
				sent = reader.get_next_sentence()
				if not sent:
					break # end of input
				self._write_training_sentence(sent, tr_files)
				self._count_sentence(sent)
		finally:
			classify.close_tr_files(tr_files)
		
		# train the classifier for each layer
		for layer in self.layers:
			if self.verbose:
				sys.stderr.write('Training classifier for %s... ' % layer)
			classify.train_and_save(
				self.conf, self.model_name,
				self.conf_dir, self.data_dir, layer)
			if self.verbose:
				sys.stderr.write('done!\n')
				self.stats.dump()
	
	def tag_sentence(self, sent):
		"""Chunks the given sentence. The sentence is modified in place:
		the channels of each layer are overwritten with the classifier
		decisions.
		Raises RuntimeError if no model has been loaded."""
		self._require_model()
		# wrap the sentence as an AnnotatedSentence
		asent = corpus2.AnnotatedSentence.wrap_sentence(sent)
		
		# iterate over layers
		for layer_idx, layer in enumerate(self.layers):
			# get model for current layer
			model = self.layer_models[layer]
			if model is None:
				continue # nothing to apply for this layer
			chans = self.layer_channels[layer_idx]
			self._channels_to_iob(asent, chans)
			# prepare WCCL context and feed the sentence features
			con = corpio.create_context(sent)
			classify.open_sent(model)
			for tok_idx, tok in enumerate(sent.tokens()):
				con.set_position(tok_idx)
				classify.eat_token(model, self._feature_values(con, layer_idx))
			classify.close_sent(model)
			# decisions are requested only after the whole sentence has
			# been fed to the classifier
			for tok_idx, tok in enumerate(sent.tokens()):
				decsn = classify.classify_token(model, tok_idx)
				non_O_tag, non_O_chan = self._parse_decision(decsn)
				for chan_name in chans:
					chan = asent.get_channel(chan_name)
					# TODO: rename the from_string in corpus2 and fix it here
					tag_to_set = non_O_tag if chan_name == non_O_chan else 'O'
					chan.set_iob_at(tok_idx, corpus2.from_string(tag_to_set))
			# switch back to segments
			for chan_name in chans:
				chan = asent.get_channel(chan_name)
				chan.make_segments_from_iob()
		
		self._count_sentence(sent)
	
	def tag_input(self, in_path, out_path, input_format, output_format,
			preserve_pars):
		"""Chunks the input and writes processed input to out_path or stdout if
		out_path is None. Similarly, if in_path is None, will read stdin.
		If preserve_pars is set, the input is read and written chunk by
		chunk, otherwise sentence by sentence.
		Raises RuntimeError if no model has been loaded."""
		# fail early, before the reader and the writer are set up
		self._require_model()
		# set-up reader and writer and proceed with proper chunking
		reader = corpio.get_reader(
			in_path, self.tagset, input_format, self.is_input_tagged)
		writer = corpio.get_writer(out_path, self.tagset, output_format)
		
		self.stats.clear()
		
		if preserve_pars:
			while True:
				chunk = reader.get_next_chunk()
				if not chunk:
					break # end of input
				# process each sentence separately
				for sent in chunk.sentences():
					self.tag_sentence(sent)
				# save tagged input
				writer.write_chunk(chunk)
		else:
			while True:
				sent = reader.get_next_sentence()
				if not sent:
					break # end of input
				self.tag_sentence(sent)
				writer.write_sentence(sent)
		writer.finish()
		if self.verbose:
			self.stats.dump()
