Skip to content
Snippets Groups Projects

Punctuator v2

Merged
Michał Pogoda requested to merge
punctuator_v2 into master
7 open threads
9 files
+ 183
105
Compare changes
  • Side-by-side
  • Inline

Files

+ 56
22
@@ -63,8 +63,11 @@ def decode(tokens, labels_decoded, tokenizer):
return "".join(text_recovered)
def inference_masks(num_tokens: int, max_len: int, overlap: int) -> Tuple[List[List[bool]], List[List[bool]]]:
""" Splits text that is to long for predicting. The function provide list of masks for each prediction chunk
def inference_masks(
num_tokens: int, max_len: int, overlap: int
) -> Tuple[List[List[bool]], List[List[bool]]]:
"""Splits text that is too long for predicting. The function provides a list
of masks for each prediction chunk
Args:
num_tokens (int): Number of tokens, including CLS & SEP
@@ -72,10 +75,14 @@ def inference_masks(num_tokens: int, max_len: int, overlap: int) -> Tuple[List[L
overlap (int): Amount of overlap between chunking windows
Returns:
Tuple[List[List[bool]], List[List[bool]]]: Masks for tokens provided for inference & for result of inference
Tuple[List[List[bool]], List[List[bool]]]: Masks for tokens provided
for inference & for result of inference
"""
if max_len >= num_tokens:
return [[True] * num_tokens], [[False] + [True] * (num_tokens - 2) + [False]]
return (
[[True] * num_tokens],
[[False] + [True] * (num_tokens - 2) + [False]],
)
# Account for CLS & SEP tokens
real_max_len = max_len - 2
@@ -88,22 +95,44 @@ def inference_masks(num_tokens: int, max_len: int, overlap: int) -> Tuple[List[L
for start_id in range(0, real_num_tokens, step_size):
stop = False
if start_id == 0:
entry = [True] + [True] * real_max_len + [False] * \
(real_num_tokens - real_max_len) + [True]
mask = [False] + [True] * \
(real_max_len - overlap) + [False] * (overlap + 1)
entry = (
[True]
+ [True] * real_max_len
+ [False] * (real_num_tokens - real_max_len)
+ [True]
)
mask = (
[False]
+ [True] * (real_max_len - overlap)
+ [False] * (overlap + 1)
)
elif start_id + real_max_len >= real_num_tokens:
offset_start = real_num_tokens - real_max_len
entry = [True] + [False] * \
(offset_start) + [True] * real_max_len + [True]
mask = [False] * (overlap + 1 + (start_id - offset_start)) + [True] * \
(real_max_len - overlap - (start_id - offset_start)) + [False]
entry = (
[True]
+ [False] * (offset_start)
+ [True] * real_max_len
+ [True]
)
mask = (
[False] * (overlap + 1 + (start_id - offset_start))
+ [True] * (real_max_len - overlap - (start_id - offset_start))
+ [False]
)
stop = True
else:
entry = [True] + [False] * start_id + [True] * real_max_len + \
[False] * (real_num_tokens - (start_id + real_max_len)) + [True]
mask = [False] * (overlap + 1) + [True] * \
(real_max_len - 2 * overlap) + [False] * (overlap + 1)
entry = (
[True]
+ [False] * start_id
+ [True] * real_max_len
+ [False] * (real_num_tokens - (start_id + real_max_len))
+ [True]
)
mask = (
[False] * (overlap + 1)
+ [True] * (real_max_len - 2 * overlap)
+ [False] * (overlap + 1)
)
masks.append(mask)
entries.append(entry)
@@ -114,8 +143,11 @@ def inference_masks(num_tokens: int, max_len: int, overlap: int) -> Tuple[List[L
return entries, masks
def combine_masks(num_tokens: int, max_len: int, overlap: int) -> List[List[bool]]:
"""Provides mask which tokens to take for each prediction. It makes sure that each token is only taken once & scored by best chunk.
def combine_masks(
num_tokens: int, max_len: int, overlap: int
) -> List[List[bool]]:
"""Provides mask which tokens to take for each prediction. It makes sure
that each token is only taken once & scored by best chunk.
Args:
num_tokens (int): Number of tokens, including CLS & SEP
@@ -135,10 +167,12 @@ def combine_masks(num_tokens: int, max_len: int, overlap: int) -> List[List[bool
stop = False
if start + max_len - 2 - overlap < num_tokens - 2:
entry = [False] + [False] * \
(start) + [True] * (max_len - 2 - overlap)
entry += [False] * (num_tokens - 2
- (start + max_len - 2 - overlap))
entry = (
[False] + [False] * (start) + [True] * (max_len - 2 - overlap)
)
entry += [False] * (
num_tokens - 2 - (start + max_len - 2 - overlap)
)
entry += [False]
else:
entry = [False] + [False] * (start)
Loading