Skip to content
Snippets Groups Projects
Select Git revision
  • 0f1104aece4bd2ef1c3c8083b5ec77f731b685eb
  • master default protected
  • develop protected
  • develop-0.7.x
  • develop-0.8.0
  • dev_czuk
  • loader
  • kgr10_roberta
  • 14-BiLSTM-CRF-RoBERTa
  • 12-handle-long-sequences
  • 13-flair-embeddings
  • BiLSTM
  • v0.7.0
  • v0.6.1
  • v0.5
  • v0.4.1
  • v0.3
17 results

conlleval.py

Blame
  • conlleval.py 7.37 KiB
    """
    Source: https://github.com/sighsmile/conlleval
    
    This script applies to IOB2 or IOBES tagging scheme.
    If you are using a different scheme, please convert to IOB2 or IOBES.
    
    IOB2:
    - B = begin, 
    - I = inside but not the first, 
    - O = outside
    
    e.g. 
    John   lives in New   York  City  .
    B-PER  O     O  B-LOC I-LOC I-LOC O
    
    IOBES:
    - B = begin, 
    - E = end, 
    - S = singleton, 
    - I = inside but not the first or the last, 
    - O = outside
    
    e.g.
    John   lives in New   York  City  .
    S-PER  O     O  B-LOC I-LOC E-LOC O
    
    prefix: IOBES
    chunk_type: PER, LOC, etc.
    """
    from __future__ import division, print_function, unicode_literals
    
    import sys
    from collections import defaultdict
    
    def split_tag(chunk_tag):
        """
        Split a chunk tag into its IOBES prefix and chunk type.

        e.g.
        B-PER -> (B, PER)
        O -> (O, None)
        """
        # 'O' carries no chunk type; anything else is "<prefix>-<type>".
        return ('O', None) if chunk_tag == 'O' else chunk_tag.split('-', maxsplit=1)
    
    def is_chunk_end(prev_tag, tag):
        """
        Check whether the chunk open at the previous word closed before the
        current word.
        e.g.
        (B-PER, I-PER) -> False
        (B-LOC, O)  -> True

        Note: in case of contradicting tags, e.g. (B-PER, I-LOC)
        this is considered as (B-PER, B-LOC)
        """
        prev_prefix, prev_type = (
            ('O', None) if prev_tag == 'O' else prev_tag.split('-', maxsplit=1))
        cur_prefix, cur_type = (
            ('O', None) if tag == 'O' else tag.split('-', maxsplit=1))

        # Nothing was open, so nothing can end.
        if prev_prefix == 'O':
            return False
        # An open chunk always ends before an outside token.
        if cur_prefix == 'O':
            return True
        # A type change forces a boundary.
        if prev_type != cur_type:
            return True
        # Same type: the chunk ends if a new one begins here (B/S) or the
        # previous tag already completed it (E/S).
        return cur_prefix in ('B', 'S') or prev_prefix in ('E', 'S')
    
    def is_chunk_start(prev_tag, tag):
        """
        Check whether a new chunk starts at the current word.
        e.g.
        (O, B-PER) -> True
        (B-PER, I-PER) -> False
        """
        prev_prefix, prev_type = (
            ('O', None) if prev_tag == 'O' else prev_tag.split('-', maxsplit=1))
        cur_prefix, cur_type = (
            ('O', None) if tag == 'O' else tag.split('-', maxsplit=1))

        # An outside token never opens a chunk.
        if cur_prefix == 'O':
            return False
        # Any chunk tag after an outside token opens one.
        if prev_prefix == 'O':
            return True
        # A type change forces a boundary.
        if prev_type != cur_type:
            return True
        # Same type: a chunk starts on an explicit begin (B/S) or right after
        # the previous chunk closed (E/S).
        return cur_prefix in ('B', 'S') or prev_prefix in ('E', 'S')
    
    
    def calc_metrics(tp, p, t, percent=True):
        """
        Compute precision, recall and FB1 from true positives (tp),
        predicted count (p) and true count (t); each defaults to 0.0
        when its denominator is zero.
        If percent is True, the three values are scaled by 100.
        """
        precision = tp / p if p else 0
        recall = tp / t if t else 0
        denom = precision + recall
        fb1 = (2 * precision * recall / denom) if denom else 0
        scale = 100 if percent else 1
        return scale * precision, scale * recall, scale * fb1
    
    
    def count_chunks(true_seqs, pred_seqs):
        """
        Tally gold, predicted and correctly-matched chunks and tags.

        true_seqs: a list of true tags
        pred_seqs: a list of predicted tags

        return:
        correct_chunks: a dict (counter),
                        key = chunk types,
                        value = number of correctly identified chunks per type
        true_chunks:    a dict, number of true chunks per type
        pred_chunks:    a dict, number of identified chunks per type

        correct_counts, true_counts, pred_counts: similar to above, but for tags
        """
        correct_chunks = defaultdict(int)
        true_chunks = defaultdict(int)
        pred_chunks = defaultdict(int)

        correct_counts = defaultdict(int)
        true_counts = defaultdict(int)
        pred_counts = defaultdict(int)

        # Both sequences implicitly start outside any chunk.
        prev_true_tag, prev_pred_tag = 'O', 'O'
        # Type of a chunk currently open in BOTH sequences with matching
        # start position and type; None when no such candidate is open.
        correct_chunk = None

        for true_tag, pred_tag in zip(true_seqs, pred_seqs):
            # Per-tag (token-level) counts.
            if true_tag == pred_tag:
                correct_counts[true_tag] += 1
            true_counts[true_tag] += 1
            pred_counts[pred_tag] += 1

            _, true_type = split_tag(true_tag)
            _, pred_type = split_tag(pred_tag)

            # Resolve an open candidate chunk BEFORE checking for new starts:
            # it is correct only if gold and prediction end at the same spot.
            if correct_chunk is not None:
                true_end = is_chunk_end(prev_true_tag, true_tag)
                pred_end = is_chunk_end(prev_pred_tag, pred_tag)

                if pred_end and true_end:
                    # Both chunks closed together -> a confirmed match.
                    correct_chunks[correct_chunk] += 1
                    correct_chunk = None
                elif pred_end != true_end or true_type != pred_type:
                    # Boundaries or types diverged -> candidate is wrong.
                    correct_chunk = None

            true_start = is_chunk_start(prev_true_tag, true_tag)
            pred_start = is_chunk_start(prev_pred_tag, pred_tag)

            # A chunk starting in both sequences with the same type becomes
            # the next candidate match.
            if true_start and pred_start and true_type == pred_type:
                correct_chunk = true_type
            if true_start:
                true_chunks[true_type] += 1
            if pred_start:
                pred_chunks[pred_type] += 1

            prev_true_tag, prev_pred_tag = true_tag, pred_tag
        # A candidate still open when the input ends counts as correct.
        if correct_chunk is not None:
            correct_chunks[correct_chunk] += 1

        return (correct_chunks, true_chunks, pred_chunks, 
            correct_counts, true_counts, pred_counts)
    
    def get_result(correct_chunks, true_chunks, pred_chunks,
        correct_counts, true_counts, pred_counts, verbose=True):
        """
        Compute overall precision, recall and FB1 (as percentages) from the
        counters produced by count_chunks.

        if verbose, print overall performance, as well as performance per
        chunk type, in the classic conlleval report format;
        otherwise, simply return overall prec, rec, f1 scores
        """
        # sum counts
        sum_correct_chunks = sum(correct_chunks.values())
        sum_true_chunks = sum(true_chunks.values())
        sum_pred_chunks = sum(pred_chunks.values())

        sum_correct_counts = sum(correct_counts.values())
        sum_true_counts = sum(true_counts.values())

        # Token-level accuracy restricted to non-'O' gold tags.
        nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O')
        nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O')

        chunk_types = sorted(set(true_chunks) | set(pred_chunks))

        # compute overall precision, recall and FB1 (default values are 0.0)
        prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks)
        res = (prec, rec, f1)
        if not verbose:
            return res

        # print overall performance, and performance per chunk type

        print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='')
        print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='')

        # Guard the accuracy divisions: empty input (or input with no non-'O'
        # gold tags) previously raised ZeroDivisionError here.
        nonO_acc = 100 * nonO_correct_counts / nonO_true_counts if nonO_true_counts else 0
        overall_acc = 100 * sum_correct_counts / sum_true_counts if sum_true_counts else 0
        print("accuracy: %6.2f%%; (non-O)" % nonO_acc)
        print("accuracy: %6.2f%%; " % overall_acc, end='')
        print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1))

        # for each chunk type, compute precision, recall and FB1 (default values are 0.0)
        for t in chunk_types:
            prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t])
            print("%17s: " % t, end='')
            print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" %
                        (prec, rec, f1), end='')
            print("  %d" % pred_chunks[t])

        return res
        # you can generate LaTeX output for tables like in
        # http://cnts.uia.ac.be/conll2003/ner/example.tex
        # but I'm not implementing this
    
    def evaluate(true_seqs, pred_seqs, verbose=True):
        """
        Compare a gold and a predicted tag sequence and return the overall
        (precision, recall, FB1) scores; print the full report if verbose.
        """
        counts = count_chunks(true_seqs, pred_seqs)
        return get_result(*counts, verbose=verbose)
    
    def evaluate_conll_file(fileIterator):
        """
        Evaluate a CoNLL-format stream: each non-empty line carries the gold
        tag and the predicted tag in its last two columns; blank lines are
        sentence boundaries and are mapped to 'O'/'O' tokens.

        Raises IOError on a non-empty line with fewer than 3 columns.
        """
        true_seqs = []
        pred_seqs = []

        for line in fileIterator:
            cols = line.strip().split()
            if not cols:
                # Sentence separator: behaves like an outside token.
                true_seqs.append('O')
                pred_seqs.append('O')
            elif len(cols) < 3:
                # each non-empty line must contain >= 3 columns
                raise IOError("conlleval: too few columns in line %s\n" % line)
            else:
                # extract tags from last 2 columns
                true_seqs.append(cols[-2])
                pred_seqs.append(cols[-1])
        return evaluate(true_seqs, pred_seqs)
    
    # Script entry point: reads a CoNLL-format file from stdin and prints
    # the evaluation report.
    if __name__ == '__main__':
        """
        usage:     conlleval < file
        """
        evaluate_conll_file(sys.stdin)