diff --git a/libwccl/values/tset.cpp b/libwccl/values/tset.cpp index 79971c955931ba5358ad27fa7caa749d86d69ac1..499fb5ae9148f2b248045bcc0549f15b291271b5 100644 --- a/libwccl/values/tset.cpp +++ b/libwccl/values/tset.cpp @@ -1,5 +1,6 @@ #include <libwccl/values/tset.h> #include <libpwrutils/foreach.h> +#include <libpwrutils/bitset.h> #include <sstream> namespace Wccl { @@ -34,6 +35,12 @@ int TSet::categories_count(const Corpus2::Tagset& tagset) const return cats; } +int TSet::matching_categories(const Corpus2::Tag& tag) const +{ + const Corpus2::Tag& masked = tag_.get_masked(tag); + return PwrNlp::count_bits_set(masked.get_pos()) + PwrNlp::count_bits_set(masked.get_values()); +} + void TSet::insert_symbol(const Corpus2::Tagset& tagset, const std::string& s) { tag_.combine_with(tagset.parse_symbol(s)); diff --git a/libwccl/values/tset.h b/libwccl/values/tset.h index e0aa99ff9ff61fce6751ae2c758a803e3ceb926a..11f227932f65bee65b6286af079562f93fd202cd 100644 --- a/libwccl/values/tset.h +++ b/libwccl/values/tset.h @@ -89,6 +89,16 @@ public: */ int categories_count(const Corpus2::Tagset& tagset) const; + /** + * @return How many categories present in the supplied tag match with + * this symbol set. + * @warning The underlying assumption is that the supplied tag has at most + * 1 value per category. Otherwise the value will be incorrect. + * @note The symbol set may have partially defined categories. Only values + * present in this symbol set count when matching values in the tag. 
+ */ + int matching_categories(const Corpus2::Tag& tag) const; + void combine_with(const Corpus2::Tag& other) { tag_.combine_with(other); } diff --git a/tests/values.cpp b/tests/values.cpp index 6bcaa81bed53a308f193a23273fc9df67041a2bd..f15247bb2a12808f1f7947bd8d292514100dcfa5 100644 --- a/tests/values.cpp +++ b/tests/values.cpp @@ -68,40 +68,61 @@ BOOST_AUTO_TEST_CASE(tset_ops) { TSet s1, s2; const Corpus2::Tagset& tagset = Corpus2::get_named_tagset("kipi"); + Corpus2::Tag subst_tag = tagset.parse_tag("subst:sg:nom:f", false)[0]; + Corpus2::Tag adj_tag = tagset.parse_tag("adj:pl:acc:m3:pos", false)[0]; + BOOST_CHECK(s1.equals(s2)); BOOST_CHECK(s1.is_subset_of(s2)); BOOST_CHECK(s2.is_subset_of(s1)); BOOST_CHECK(!s1.intersects(s2)); BOOST_CHECK_EQUAL(0, s1.categories_count(tagset)); + BOOST_CHECK_EQUAL(0, s1.matching_categories(subst_tag)); + BOOST_CHECK_EQUAL(0, s1.matching_categories(adj_tag)); s1.insert_symbol(tagset, "subst"); BOOST_CHECK_EQUAL(1, s1.categories_count(tagset)); + BOOST_CHECK_EQUAL(1, s1.matching_categories(subst_tag)); + BOOST_CHECK_EQUAL(0, s1.matching_categories(adj_tag)); BOOST_CHECK(!s1.equals(s2)); BOOST_CHECK(!s1.is_subset_of(s2)); BOOST_CHECK(s2.is_subset_of(s1)); BOOST_CHECK(!s1.intersects(s2)); s2.insert_symbol(tagset, "pl"); BOOST_CHECK_EQUAL(1, s2.categories_count(tagset)); + BOOST_CHECK_EQUAL(0, s2.matching_categories(subst_tag)); + BOOST_CHECK_EQUAL(1, s2.matching_categories(adj_tag)); BOOST_CHECK(!s1.equals(s2)); BOOST_CHECK(!s1.is_subset_of(s2)); BOOST_CHECK(!s2.is_subset_of(s1)); BOOST_CHECK(!s1.intersects(s2)); s2.insert_symbol(tagset, "subst"); BOOST_CHECK_EQUAL(2, s2.categories_count(tagset)); + BOOST_CHECK_EQUAL(1, s2.matching_categories(subst_tag)); + BOOST_CHECK_EQUAL(1, s2.matching_categories(adj_tag)); BOOST_CHECK(!s1.equals(s2)); BOOST_CHECK(s1.is_subset_of(s2)); BOOST_CHECK(!s2.is_subset_of(s1)); BOOST_CHECK(s1.intersects(s2)); s1.insert_symbol(tagset, "pl"); + BOOST_CHECK_EQUAL(2, s1.categories_count(tagset)); + 
+ BOOST_CHECK_EQUAL(1, s1.matching_categories(subst_tag)); + BOOST_CHECK_EQUAL(1, s1.matching_categories(adj_tag)); BOOST_CHECK(s1.equals(s2)); BOOST_CHECK(s1.is_subset_of(s2)); BOOST_CHECK(s2.is_subset_of(s1)); BOOST_CHECK(s1.intersects(s2)); s1.insert_symbol(tagset, "sg"); BOOST_CHECK_EQUAL(2, s1.categories_count(tagset)); + BOOST_CHECK_EQUAL(2, s1.matching_categories(subst_tag)); + BOOST_CHECK_EQUAL(1, s1.matching_categories(adj_tag)); s1.insert_symbol(tagset, "f"); BOOST_CHECK_EQUAL(3, s1.categories_count(tagset)); + BOOST_CHECK_EQUAL(3, s1.matching_categories(subst_tag)); + BOOST_CHECK_EQUAL(1, s1.matching_categories(adj_tag)); s1.insert_symbol(tagset, "adj"); BOOST_CHECK_EQUAL(3, s1.categories_count(tagset)); + BOOST_CHECK_EQUAL(3, s1.matching_categories(subst_tag)); + BOOST_CHECK_EQUAL(2, s1.matching_categories(adj_tag)); + } BOOST_AUTO_TEST_CASE(position_ops)