#ifndef LIBWCCL_OPS_FUNCTIONS_SETOPS_H
#define LIBWCCL_OPS_FUNCTIONS_SETOPS_H

#include <libwccl/ops/functions/listoperator.h>
#include <libwccl/values/tset.h>
#include <libwccl/values/strset.h>

namespace Wccl {

/**
 * Common base class for the set operators declared in this header
 * (SetUnion and SetIntersection). It adds nothing to ListOperator<T>
 * except a restriction on the value type T.
 *
 * T must be StrSet or TSet; instantiating the template with any other
 * type fails the compile-time assertion below.
 */
template<class T>
class SetListOperator : public ListOperator<T>
{
	// Compile-time guard: T has to occur in the type list (StrSet, TSet).
	BOOST_MPL_ASSERT(( boost::mpl::count<boost::mpl::list<StrSet, TSet>, T> ));
public:
	/**
	 * @param expressions Shared vector of argument expressions; passed
	 *        unchanged to the ListOperator<T> constructor.
	 */
	SetListOperator(const boost::shared_ptr<typename SetListOperator<T>::TFunctionPtrVector>& expressions)
		: ListOperator<T>(expressions)
	{
	}
};

/**
 * Set operation: union.
 *
 * apply_internal is only declared in the class body; this header defines
 * it as explicit specializations for T = TSet and T = StrSet in the
 * implementation section further down.
 */
template<class T>
class SetUnion : public SetListOperator<T>
{
public:
	/**
	 * @param expressions Shared vector of argument expressions; passed
	 *        unchanged to the SetListOperator<T> constructor.
	 */
	SetUnion(const boost::shared_ptr<typename SetListOperator<T>::TFunctionPtrVector>& expressions)
		: SetListOperator<T>(expressions)
	{
	}

	/**
	 * @returns Name of the function: "union".
	 */
	std::string raw_name() const {
		return "union";
	}

protected:
	/**
	 * Computes the union of the argument expressions' values for the
	 * given execution context.
	 * @param context Execution context handed to each argument expression.
	 * @returns The resulting set value.
	 */
	FunctionBase::BaseRetValPtr apply_internal(const FunExecContext& context) const;
};

/**
 * Set operation: intersection.
 *
 * apply_internal is only declared in the class body; this header defines
 * it as explicit specializations for T = TSet and T = StrSet in the
 * implementation section further down.
 */
template<class T>
class SetIntersection : public SetListOperator<T>
{
public:
	/**
	 * @param expressions Shared vector of argument expressions; passed
	 *        unchanged to the SetListOperator<T> constructor.
	 */
	SetIntersection(const boost::shared_ptr<typename SetListOperator<T>::TFunctionPtrVector>& expressions)
		: SetListOperator<T>(expressions)
	{
	}

	/**
	 * @returns Name of the function: "intersection".
	 */
	std::string raw_name() const {
		return "intersection";
	}

protected:
	/**
	 * Computes the intersection of the argument expressions' values for
	 * the given execution context.
	 * @param context Execution context handed to each argument expression.
	 * @returns The resulting set value.
	 */
	FunctionBase::BaseRetValPtr apply_internal(const FunExecContext& context) const;
};

//
// ----- Implementation -----
//

/**
 * Union of tag sets.
 *
 * Starts from a default-constructed Corpus2::Tag and folds the value of
 * every argument expression into it with combine_with, in argument order.
 * With no argument expressions the result wraps the default-constructed tag.
 *
 * @param context Execution context handed to each argument expression.
 * @returns A newly allocated TSet holding the accumulated tag.
 */
template <> inline
FunctionBase::BaseRetValPtr SetUnion<TSet>::apply_internal(const FunExecContext& context) const {
	Corpus2::Tag combined;
	for (size_t idx = 0; idx < expressions_->size(); ++idx) {
		Corpus2::Tag operand = (*expressions_)[idx]->apply(context)->get_value();
		combined.combine_with(operand);
	}
	return boost::make_shared<TSet>(combined);
}

/**
 * Union of string sets.
 *
 * - exactly one argument expression: its value is returned directly,
 *   without copying the strings into a new set;
 * - otherwise: a new StrSet is created and every string of every argument
 *   value is inserted into it (so no arguments yields an empty set).
 *
 * Argument expressions are evaluated once each, in argument order.
 *
 * @param context Execution context handed to each argument expression.
 * @returns The single operand's value, or a newly allocated StrSet.
 */
template <> inline
FunctionBase::BaseRetValPtr SetUnion<StrSet>::apply_internal(const FunExecContext& context) const {
	if (expressions_->size() == 1) {
		return expressions_->front()->apply(context);
	}
	boost::shared_ptr<StrSet> merged = boost::make_shared<StrSet>();
	for (size_t idx = 0; idx < expressions_->size(); ++idx) {
		const boost::shared_ptr<const StrSet>& operand = (*expressions_)[idx]->apply(context);
		foreach (const UnicodeString& str, operand->contents()) {
			merged->insert(str);
		}
	}
	return merged;
}

/**
 * Intersection of tag sets.
 *
 * With no argument expressions the result wraps a default-constructed
 * Corpus2::Tag. Otherwise the first argument's value is taken as the
 * starting tag and narrowed with mask_with by each following argument's
 * value, in argument order.
 *
 * @param context Execution context handed to each argument expression.
 * @returns A newly allocated TSet holding the resulting tag.
 */
template <> inline
FunctionBase::BaseRetValPtr SetIntersection<TSet>::apply_internal(const FunExecContext& context) const {
	Corpus2::Tag common;
	if (expressions_->empty()) {
		return boost::make_shared<TSet>(common);
	}
	common = (*expressions_)[0]->apply(context)->get_value();
	for (size_t idx = 1; idx < expressions_->size(); ++idx) {
		Corpus2::Tag operand = (*expressions_)[idx]->apply(context)->get_value();
		common.mask_with(operand);
	}
	return boost::make_shared<TSet>(common);
}

/**
 * Intersection of string sets.
 *
 * - no argument expressions: returns a new, empty StrSet;
 * - exactly one argument expression: its value is returned directly;
 * - otherwise: the first two values are intersected, and the running
 *   result is then intersected with each remaining value in turn.
 *
 * Every argument expression is evaluated exactly once, in argument order,
 * even when the running result has already become empty.
 *
 * NOTE(review): std::set_intersection requires both input ranges to be
 * sorted by operator<. This assumes StrSet::contents() is an ordered
 * container (e.g. std::set) -- confirm against the StrSet definition.
 *
 * @param context Execution context handed to each argument expression.
 * @returns The single operand's value, or a newly allocated StrSet.
 */
template <> inline
FunctionBase::BaseRetValPtr SetIntersection<StrSet>::apply_internal(const FunExecContext& context) const {
	if (expressions_->size() == 1) return expressions_->front()->apply(context);
	boost::shared_ptr<StrSet> out = boost::make_shared<StrSet>();
	if (expressions_->empty()) return out;
	const boost::shared_ptr<const StrSet>& set1 = (*expressions_)[0]->apply(context);
	const boost::shared_ptr<const StrSet>& set2 = (*expressions_)[1]->apply(context);
	std::set_intersection(set1->contents().begin(), set1->contents().end(),
		set2->contents().begin(), set2->contents().end(),
		std::inserter(out->contents(), out->contents().begin()));

	for (size_t i = 2; i < expressions_->size(); ++i) {
		// Scratch set receiving (running result) ∩ (operand i).
		boost::shared_ptr<StrSet> out2 = boost::make_shared<StrSet>();
		const boost::shared_ptr<const StrSet>& seti = (*expressions_)[i]->apply(context);
		// The destination has to be the scratch set: `out` is one of the
		// input ranges here and must not be modified while it is being
		// read, and the insert hint must belong to the destination container.
		std::set_intersection(seti->contents().begin(), seti->contents().end(),
			out->contents().begin(), out->contents().end(),
			std::inserter(out2->contents(), out2->contents().begin()));
		// Make the narrowed set the new running result.
		out->contents().swap(out2->contents());
	}
	return out;
}

} /* end ns Wccl */

#endif // LIBWCCL_OPS_FUNCTIONS_SETOPS_H