diff --git a/punctuator/punctuator.py b/punctuator/punctuator.py index 1bea097cce60a7c77d6034cd76c0b26cda703e57..72041e06b7491f140d9cc1c452b73a1800227542 100644 --- a/punctuator/punctuator.py +++ b/punctuator/punctuator.py @@ -35,6 +35,8 @@ def decode(tokens, labels_decoded, tokenizer, bpe=False): for label, token in zip(labels_decoded, tokens): if bpe: token_str = tokenizer.decode(token) + if token_str.startswith(" "): + token_str = token_str[1:] else: token_str = tokenizer.convert_ids_to_tokens([token])[0] if token_str == "[PAD]": @@ -43,8 +45,7 @@ def decode(tokens, labels_decoded, tokenizer, bpe=False): word.append(token_str.replace("##", "")) else: if len(word) > 0: - if not bpe or word_end != ' ': - word.append(word_end) + word.append(word_end) text_recovered.append("".join(word)) word = [] if label.startswith("__ALL_UPPER__"):