diff --git a/cclutils/_annotations.py b/cclutils/_annotations.py
index 916dfbdaf81c812dc41c3f5ee3be00f4bd687d89..470efe6daf41f0ac86934793a4a730b138b36aba 100644
--- a/cclutils/_annotations.py
+++ b/cclutils/_annotations.py
@@ -13,6 +13,7 @@ __all__ = [
     'get_annotations',
     'get_annotation',
     'set_annotation_for_token',
+    'set_annotation_for_tokens',
     'is_head_of'
 ]
 
@@ -195,12 +196,34 @@ def set_annotation_for_token(sentence, token, key, value=None, set_head=False):
         value (int, bool): annotation number (convertible to integer)
 
     """
+    idx = 0 if set_head else None
+    set_annotation_for_tokens(sentence, [token], key, value=value, head_index=idx)
+
+
+def set_annotation_for_tokens(sentence, tokens, key, value=None, head_index=None):
+    """
+    Set an annotation for a group of tokens from the same sentence. The tokens
+    are treated as a representation of a single expression and therefore all
+    of them receive the same annotation number (value).
+
+    Args:
+        sentence (Corpus2.Sentence)
+        tokens (list of Corpus2.Token)
+        key (str): a name for the annotation channel
+        value (int, bool): annotation number (convertible to integer)
+        head_index (int): index of the token from the passed list (counting
+                          from 0) that will be marked as the head of the
+                          annotation. If not given, the head will not be set.
+
+    """
+    if not tokens or not isinstance(tokens, list):
+        raise ValueError(f"Tokens must be passed as a non-empty list, got: "
+                         f"{tokens} (type: {type(tokens)})")
 
     ann_sentence = annotate_sentence(sentence)
     if key not in ann_sentence.all_channels():
         ann_sentence.create_channel(key)
     channel = ann_sentence.get_channel(key)
-    token_index = _find_token(sentence, token)
     if value is not None:
         try:
             segment = int(value)
@@ -208,9 +231,16 @@ def set_annotation_for_token(sentence, token, key, value=None, set_head=False):
             raise Exception("Wrong value type - should be an integer.")
     else:
         segment = channel.get_new_segment_index()
-    channel.set_segment_at(token_index, segment)
-    if set_head:
-        channel.set_head_at(token_index, True)
+
+    if head_index is not None and not 0 <= head_index < len(tokens):
+        raise ValueError(f"head_index ({head_index}) does not match "
+                         f"the passed list of tokens.")
+
+    for i, token in enumerate(tokens):
+        token_index = _find_token(sentence, token)
+        channel.set_segment_at(token_index, segment)
+        if i == head_index:
+            channel.set_head_at(token_index, True)
 
 
 def is_head_of(sentence, token, key):
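
Usage note (not part of the patch): a minimal sketch of how the new helper might be called. Only set_annotation_for_tokens and its signature come from this diff; everything else (the cclutils.read loader, the paragraphs()/sentences()/tokens() accessors, the "ne" channel name) is an assumption for illustration.

    from cclutils import read
    from cclutils._annotations import set_annotation_for_tokens

    # Assumed entry points outside this diff: cclutils.read() loads a CCL
    # document, and the underlying Corpus2 objects expose paragraphs(),
    # sentences() and tokens() accessors.
    document = read('example.ccl')
    sentence = document.paragraphs()[0].sentences()[0]
    tokens = list(sentence.tokens())

    # Annotate the first two tokens as a single expression in the "ne"
    # channel, marking the second token as the head of the annotation.
    set_annotation_for_tokens(sentence, tokens[:2], 'ne', head_index=1)

The previous single-token helper, set_annotation_for_token(sentence, token, key, set_head=True), now simply delegates to the new function with a one-element list and head_index=0.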