From b28ead01d245b810d0e4da605b2e7fc799abc5da Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Thu, 10 Aug 2023 14:02:58 +0200 Subject: [PATCH] Add PretrainedTransformerTokenizer and PretrainedTransformerIndexer --- combo/common/__init__.py | 2 + combo/common/cached_transformers.py | 28 +++ combo/common/logging.py | 55 +++++ combo/common/tqdm.py | 99 ++++++++ combo/common/util.py | 14 ++ combo/data/dataset_readers/conll.py | 4 +- .../universal_dependencies_dataset_reader.py | 2 +- combo/data/fields/adjacency_field.py | 10 +- combo/data/fields/field.py | 2 +- combo/data/fields/label_field.py | 12 +- combo/data/fields/sequence_label_field.py | 10 +- .../data/fields/sequence_multilabel_field.py | 8 +- combo/data/token_embedders/__init__.py | 0 combo/data/token_embedders/embedding.py | 216 ++++++++++++++++++ combo/data/token_indexers/__init__.py | 3 + ...ed_transformer_fixed_mismatched_indexer.py | 63 ----- .../pretrained_transformer_indexer.py | 52 ++++- ...etrained_transformer_mismatched_indexer.py | 6 +- .../token_indexers/single_id_token_indexer.py | 47 ++-- .../token_characters_indexer.py | 4 +- .../pretrained_transformer_tokenizer.py | 123 ++++------ combo/data/tokenizers/token.py | 11 +- combo/data/vocabulary.py | 59 ++++- combo/utils/__init__.py | 3 +- combo/utils/file_utils.py | 6 + combo/utils/typing.py | 2 + 26 files changed, 634 insertions(+), 207 deletions(-) create mode 100644 combo/common/cached_transformers.py create mode 100644 combo/common/logging.py create mode 100644 combo/common/tqdm.py create mode 100644 combo/data/token_embedders/__init__.py create mode 100644 combo/data/token_embedders/embedding.py create mode 100644 combo/utils/typing.py diff --git a/combo/common/__init__.py b/combo/common/__init__.py index e69de29..211d539 100644 --- a/combo/common/__init__.py +++ b/combo/common/__init__.py @@ -0,0 +1,2 @@ +from .logging import * +from .tqdm import * diff --git a/combo/common/cached_transformers.py b/combo/common/cached_transformers.py new file mode 100644 index 0000000..627879c --- /dev/null +++ b/combo/common/cached_transformers.py @@ -0,0 +1,28 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/common/cached_transformers.py +""" +import logging +import transformers +from typing import Dict, Tuple + +from combo.common.util import hash_object + + +logger = logging.getLogger(__name__) + +_tokenizer_cache: Dict[Tuple[str, str], transformers.PreTrainedTokenizer] = {} + +def get_tokenizer(model_name: str, **kwargs) -> transformers.PreTrainedTokenizer: + + cache_key = (model_name, hash_object(kwargs)) + + global _tokenizer_cache + tokenizer = _tokenizer_cache.get(cache_key, None) + if tokenizer is None: + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name, + **kwargs, + ) + _tokenizer_cache[cache_key] = tokenizer + return tokenizer diff --git a/combo/common/logging.py b/combo/common/logging.py new file mode 100644 index 0000000..3cb652e --- /dev/null +++ b/combo/common/logging.py @@ -0,0 +1,55 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/common/logging.py +""" + +import logging + + +class AllenNlpLogger(logging.Logger): + """ + A custom subclass of 'logging.Logger' that keeps a set of messages to + implement {debug,info,etc.}_once() methods. 
+ """ + + def __init__(self, name): + super().__init__(name) + self._seen_msgs = set() + + def debug_once(self, msg, *args, **kwargs): + if msg not in self._seen_msgs: + self.debug(msg, *args, **kwargs) + self._seen_msgs.add(msg) + + def info_once(self, msg, *args, **kwargs): + if msg not in self._seen_msgs: + self.info(msg, *args, **kwargs) + self._seen_msgs.add(msg) + + def warning_once(self, msg, *args, **kwargs): + if msg not in self._seen_msgs: + self.warning(msg, *args, **kwargs) + self._seen_msgs.add(msg) + + def error_once(self, msg, *args, **kwargs): + if msg not in self._seen_msgs: + self.error(msg, *args, **kwargs) + self._seen_msgs.add(msg) + + def critical_once(self, msg, *args, **kwargs): + if msg not in self._seen_msgs: + self.critical(msg, *args, **kwargs) + self._seen_msgs.add(msg) + + +logging.setLoggerClass(AllenNlpLogger) +logger = logging.getLogger(__name__) + + +FILE_FRIENDLY_LOGGING: bool = False +""" +If this flag is set to `True`, we add newlines to tqdm output, even on an interactive terminal, and we slow +down tqdm's output to only once every 10 seconds. + +By default, it is set to `False`. +""" diff --git a/combo/common/tqdm.py b/combo/common/tqdm.py new file mode 100644 index 0000000..bb17281 --- /dev/null +++ b/combo/common/tqdm.py @@ -0,0 +1,99 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/common/tqdm.py +""" + +import logging +import sys +from time import time +from typing import Optional +import combo.common.logging as common_logging + +try: + SHELL = str(type(get_ipython())) # type:ignore # noqa: F821 +except: # noqa: E722 + SHELL = "" + + +if "zmqshell.ZMQInteractiveShell" in SHELL: + from tqdm import tqdm_notebook as _tqdm +else: + from tqdm import tqdm as _tqdm + +# This is necessary to stop tqdm from hanging +# when exceptions are raised inside iterators. +# It should have been fixed in 4.2.1, but it still +# occurs. +# TODO(Mark): Remove this once tqdm cleans up after itself properly. +# https://github.com/tqdm/tqdm/issues/469 +_tqdm.monitor_interval = 0 + + +logger = logging.getLogger("tqdm") +logger.propagate = False + + +def replace_cr_with_newline(message: str) -> str: + """ + TQDM and requests use carriage returns to get the training line to update for each batch + without adding more lines to the terminal output. Displaying those in a file won't work + correctly, so we'll just make sure that each batch shows up on its one line. + """ + # In addition to carriage returns, nested progress-bars will contain extra new-line + # characters and this special control sequence which tells the terminal to move the + # cursor one line up. + message = message.replace("\r", "").replace("\n", "").replace("[A", "") + if message and message[-1] != "\n": + message += "\n" + return message + + +class TqdmToLogsWriter(object): + def __init__(self): + self.last_message_written_time = 0.0 + + def write(self, message): + file_friendly_message: Optional[str] = None + if common_logging.FILE_FRIENDLY_LOGGING: + file_friendly_message = replace_cr_with_newline(message) + if file_friendly_message.strip(): + sys.stderr.write(file_friendly_message) + else: + sys.stderr.write(message) + + # Every 10 seconds we also log the message. 
+ now = time() + if now - self.last_message_written_time >= 10 or "100%" in message: + if file_friendly_message is None: + file_friendly_message = replace_cr_with_newline(message) + for message in file_friendly_message.split("\n"): + message = message.strip() + if len(message) > 0: + logger.info(message) + self.last_message_written_time = now + + def flush(self): + sys.stderr.flush() + + +class Tqdm: + @staticmethod + def tqdm(*args, **kwargs): + # Use a slower interval when FILE_FRIENDLY_LOGGING is set. + default_mininterval = 2.0 if common_logging.FILE_FRIENDLY_LOGGING else 0.1 + + new_kwargs = { + "file": TqdmToLogsWriter(), + "mininterval": default_mininterval, + **kwargs, + } + + return _tqdm(*args, **new_kwargs) + + @staticmethod + def set_lock(lock): + _tqdm.set_lock(lock) + + @staticmethod + def get_lock(): + return _tqdm.get_lock() diff --git a/combo/common/util.py b/combo/common/util.py index 7cb5c88..64d440a 100644 --- a/combo/common/util.py +++ b/combo/common/util.py @@ -4,10 +4,24 @@ from itertools import islice import numpy import spacy import torch +import hashlib +import io +import base58 +import dill A = TypeVar("A") +def hash_object(o: Any) -> str: + """Adapted from AllenNLP""" + """Returns a character hash code of arbitrary Python objects.""" + m = hashlib.blake2b() + with io.BytesIO() as buffer: + dill.dump(o, buffer) + m.update(buffer.getbuffer()) + return base58.b58encode(m.digest()).decode() + + def int_to_device(device: Union[int, torch.device]) -> torch.device: if isinstance(device, torch.device): return device diff --git a/combo/data/dataset_readers/conll.py b/combo/data/dataset_readers/conll.py index da08e1e..d8c2c5d 100644 --- a/combo/data/dataset_readers/conll.py +++ b/combo/data/dataset_readers/conll.py @@ -59,7 +59,7 @@ class ConllDatasetReader(DatasetReader): feature_labels : `Sequence[str]`, optional (default=`()`) These labels will be loaded as features into the corresponding instance fields: `pos` -> `pos_tags`, `chunk` -> `chunk_tags`, `ner` -> `ner_tags` - Each will have its own namespace : `pos_tags`, `chunk_tags`, `ner_tags`. + Each will have its own _namespace : `pos_tags`, `chunk_tags`, `ner_tags`. If you want to use one of the tags as a `feature` in your model, it should be specified here. convert_to_coding_scheme : `Optional[str]`, optional (default=`None`) @@ -73,7 +73,7 @@ class ConllDatasetReader(DatasetReader): Specifies the coding scheme of the input file. Valid options are `IOB1` and `IOB2`. label_namespace : `str`, optional (default=`labels`) - Specifies the namespace for the chosen `tag_label`. + Specifies the _namespace for the chosen `tag_label`. """ _VALID_LABELS = {"ner", "pos", "chunk"} diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py index 3a503d4..33bad6d 100644 --- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py +++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py @@ -179,7 +179,7 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC): indices=enhanced_heads, sequence_field=text_field_deps, labels=enhanced_deprels, - # Label namespace matches regular tree parsing. + # Label _namespace matches regular tree parsing. 
label_namespace="enhanced_deprel_labels", padding_value=0, ) diff --git a/combo/data/fields/adjacency_field.py b/combo/data/fields/adjacency_field.py index e4bc4ab..e12a837 100644 --- a/combo/data/fields/adjacency_field.py +++ b/combo/data/fields/adjacency_field.py @@ -34,7 +34,7 @@ class AdjacencyField(Field[torch.Tensor]): labels : `List[str]`, optional, (default = `None`) Optional labels for the edges of the adjacency matrix. label_namespace : `str`, optional (default=`'labels'`) - The namespace to use for converting tag strings into integers. We convert tag strings to + The _namespace to use for converting tag strings into integers. We convert tag strings to integers for you, and this parameter tells the `Vocabulary` object which mapping from strings to integers to use (so that "O" as a tag doesn't get the same idx as "O" as a word). padding_value : `int`, optional (default = `-1`) @@ -50,10 +50,10 @@ class AdjacencyField(Field[torch.Tensor]): "_indexed_labels", ] - # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens. + # It is possible that users want to use this field with a _namespace which uses OOV/PAD tokens. # This warning will be repeated for every instantiation of this class (i.e for every data # instance), spewing a lot of warnings so this class variable is used to only log a single - # warning per namespace. + # warning per _namespace. _already_warned_namespaces: Set[str] = set() def __init__( @@ -95,7 +95,7 @@ class AdjacencyField(Field[torch.Tensor]): if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")): if label_namespace not in self._already_warned_namespaces: logger.warning( - "Your label namespace was '%s'. We recommend you use a namespace " + "Your label _namespace was '%s'. We recommend you use a _namespace " "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by " "default to your vocabulary. See documentation for " "`non_padded_namespaces` parameter in Vocabulary.", @@ -146,7 +146,7 @@ class AdjacencyField(Field[torch.Tensor]): return ( f"AdjacencyField of length {length}\n" f"\t\twith indices:\n {formatted_indices}\n" - f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'." + f"\t\tand labels:\n {formatted_labels} \t\tin _namespace: '{self._label_namespace}'." ) def __len__(self): diff --git a/combo/data/fields/field.py b/combo/data/fields/field.py index 952cf97..2e890f8 100644 --- a/combo/data/fields/field.py +++ b/combo/data/fields/field.py @@ -48,7 +48,7 @@ class Field(Generic[DataArray]): be represented as a combination of word ids and character ids, and you don't want words and characters to share the same vocabulary - "a" as a word should get a different idx from "a" as a character, and the vocabulary sizes of words and characters are very different. - Because of this, the first key in the `counter` object is a `namespace`, like "tokens", + Because of this, the first key in the `counter` object is a `_namespace`, like "tokens", "token_characters", "tags", or "labels", and the second key is the actual vocabulary item. """ pass diff --git a/combo/data/fields/label_field.py b/combo/data/fields/label_field.py index fabe242..68f6088 100644 --- a/combo/data/fields/label_field.py +++ b/combo/data/fields/label_field.py @@ -26,9 +26,9 @@ class LabelField(Field[torch.Tensor]): # Parameters label : `Union[str, int]` label_namespace : `str`, optional (default=`"labels"`) - The namespace to use for converting label strings into integers. 
We map label strings to + The _namespace to use for converting label strings into integers. We map label strings to integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...), - and this namespace tells the `Vocabulary` object which mapping from strings to integers + and this _namespace tells the `Vocabulary` object which mapping from strings to integers to use (so "entailment" as a label doesn't get the same integer idx as "entailment" as a word). If you have multiple different label fields in your data, you should make sure you use different namespaces for each one, always using the suffix "labels" (e.g., @@ -41,9 +41,9 @@ class LabelField(Field[torch.Tensor]): __slots__ = ["label", "_label_namespace", "_label_id", "_skip_indexing"] # Most often, you probably don't want to have OOV/PAD tokens with a LabelField, so we warn you - # about it when you pick a namespace that will getting these tokens by default. It is + # about it when you pick a _namespace that will getting these tokens by default. It is # possible, however, that you _do_ actually want OOV/PAD tokens with this Field. This class - # variable is used to make sure that we only log a single warning for this per namespace, and + # variable is used to make sure that we only log a single warning for this per _namespace, and # not every time you create one of these Field objects. _already_warned_namespaces: Set[str] = set() @@ -73,7 +73,7 @@ class LabelField(Field[torch.Tensor]): if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")): if label_namespace not in self._already_warned_namespaces: logger.warning( - "Your label namespace was '%s'. We recommend you use a namespace " + "Your label _namespace was '%s'. We recommend you use a _namespace " "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by " "default to your vocabulary. See documentation for " "`non_padded_namespaces` parameter in Vocabulary.", @@ -105,7 +105,7 @@ class LabelField(Field[torch.Tensor]): return self.label def __str__(self) -> str: - return f"LabelField with label: {self.label} in namespace: '{self._label_namespace}'." + return f"LabelField with label: {self.label} in _namespace: '{self._label_namespace}'." def __len__(self): return 1 diff --git a/combo/data/fields/sequence_label_field.py b/combo/data/fields/sequence_label_field.py index 99c032e..60d30c6 100644 --- a/combo/data/fields/sequence_label_field.py +++ b/combo/data/fields/sequence_label_field.py @@ -35,7 +35,7 @@ class SequenceLabelField(Field[torch.Tensor]): A field containing the sequence that this `SequenceLabelField` is labeling. Most often, this is a `TextField`, for tagging individual tokens in a sentence. label_namespace : `str`, optional (default=`'labels'`) - The namespace to use for converting tag strings into integers. We convert tag strings to + The _namespace to use for converting tag strings into integers. We convert tag strings to integers for you, and this parameter tells the `Vocabulary` object which mapping from strings to integers to use (so that "O" as a tag doesn't get the same idx as "O" as a word). """ @@ -48,10 +48,10 @@ class SequenceLabelField(Field[torch.Tensor]): "_skip_indexing", ] - # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens. + # It is possible that users want to use this field with a _namespace which uses OOV/PAD tokens. 
# This warning will be repeated for every instantiation of this class (i.e for every data # instance), spewing a lot of warnings so this class variable is used to only log a single - # warning per namespace. + # warning per _namespace. _already_warned_namespaces: Set[str] = set() def __init__( @@ -87,7 +87,7 @@ class SequenceLabelField(Field[torch.Tensor]): if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")): if label_namespace not in self._already_warned_namespaces: logger.warning( - "Your label namespace was '%s'. We recommend you use a namespace " + "Your label _namespace was '%s'. We recommend you use a _namespace " "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by " "default to your vocabulary. See documentation for " "`non_padded_namespaces` parameter in Vocabulary.", @@ -144,7 +144,7 @@ class SequenceLabelField(Field[torch.Tensor]): ) return ( f"SequenceLabelField of length {length} with " - f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'." + f"labels:\n {formatted_labels} \t\tin _namespace: '{self._label_namespace}'." ) def human_readable_repr(self) -> Union[List[str], List[int]]: diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py index d0c19d8..4938d43 100644 --- a/combo/data/fields/sequence_multilabel_field.py +++ b/combo/data/fields/sequence_multilabel_field.py @@ -44,9 +44,9 @@ class SequenceMultiLabelField(Field[torch.Tensor]): A field containing the sequence that this `SequenceMultiLabelField` is labeling. Most often, this is a `TextField`, for tagging individual tokens in a sentence. label_namespace : `str`, optional (default="labels") - The namespace to use for converting label strings into integers. We map label strings to + The _namespace to use for converting label strings into integers. We map label strings to integers for you (e.g., "entailment" and "contradiction" get converted to 0, 1, ...), - and this namespace tells the `Vocabulary` object which mapping from strings to integers + and this _namespace tells the `Vocabulary` object which mapping from strings to integers to use (so "entailment" as a label doesn't get the same integer idx as "entailment" as a word). If you have multiple different label fields in your data, you should make sure you use different namespaces for each one, always using the suffix "labels" (e.g., @@ -86,7 +86,7 @@ class SequenceMultiLabelField(Field[torch.Tensor]): if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")): if label_namespace not in self._already_warned_namespaces: logger.warning( - "Your label namespace was '%s'. We recommend you use a namespace " + "Your label _namespace was '%s'. We recommend you use a _namespace " "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by " "default to your vocabulary. See documentation for " "`non_padded_namespaces` parameter in Vocabulary.", @@ -144,5 +144,5 @@ class SequenceMultiLabelField(Field[torch.Tensor]): ) return ( f"SequenceMultiLabelField of length {length} with " - f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'." + f"labels:\n {formatted_labels} \t\tin _namespace: '{self._label_namespace}'." 
) diff --git a/combo/data/token_embedders/__init__.py b/combo/data/token_embedders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/combo/data/token_embedders/embedding.py b/combo/data/token_embedders/embedding.py new file mode 100644 index 0000000..ae17b56 --- /dev/null +++ b/combo/data/token_embedders/embedding.py @@ -0,0 +1,216 @@ +""" +Adapted from AllenNLP +https://github.com/allenai/allennlp/blob/main/allennlp/modules/token_embedders/embedding.py +""" +import io +import itertools +import logging +import re +import tarfile +import zipfile +from typing import Any, Iterator, Tuple, NamedTuple, Optional, BinaryIO, Sequence + +from combo.utils.typing import cast +from combo.utils.file_utils import cached_path, get_file_extension + +logger = logging.getLogger(__name__) + + +class EmbeddingsFileURI(NamedTuple): + main_file_uri: str + path_inside_archive: Optional[str] = None + + +def format_embeddings_file_uri( + main_file_path_or_url: str, path_inside_archive: Optional[str] = None +) -> str: + if path_inside_archive: + return "({})#{}".format(main_file_path_or_url, path_inside_archive) + return main_file_path_or_url + + +def parse_embeddings_file_uri(uri: str) -> "EmbeddingsFileURI": + match = re.fullmatch(r"\((.*)\)#(.*)", uri) + if match: + fields = cast(Tuple[str, str], match.groups()) + return EmbeddingsFileURI(*fields) + else: + return EmbeddingsFileURI(uri, None) + + +class EmbeddingsTextFile(Iterator[str]): + """ + Utility class for opening embeddings text files. Handles various compression formats, + as well as context management. + + # Parameters + + file_uri : `str` + It can be: + + * a file system path or a URL of an eventually compressed text file or a zip/tar archive + containing a single file. + * URI of the type `(archive_path_or_url)#file_path_inside_archive` if the text file + is contained in a multi-file archive. + + encoding : `str` + cache_dir : `str` + """ + + DEFAULT_ENCODING = "utf-8" + + def __init__( + self, file_uri: str, encoding: str = DEFAULT_ENCODING, cache_dir: str = None + ) -> None: + + self.uri = file_uri + self._encoding = encoding + self._cache_dir = cache_dir + self._archive_handle: Any = None # only if the file is inside an archive + + main_file_uri, path_inside_archive = parse_embeddings_file_uri(file_uri) + main_file_local_path = cached_path(main_file_uri, cache_dir=cache_dir) + + if zipfile.is_zipfile(main_file_local_path): # ZIP archive + self._open_inside_zip(main_file_uri, path_inside_archive) + + elif tarfile.is_tarfile(main_file_local_path): # TAR archive + self._open_inside_tar(main_file_uri, path_inside_archive) + + else: # all the other supported formats, including uncompressed files + if path_inside_archive: + raise ValueError("Unsupported archive format: %s" + main_file_uri) + + # All the python packages for compressed files share the same interface of io.open + extension = get_file_extension(main_file_uri) + + # Some systems don't have support for all of these libraries, so we import them only + # when necessary. + package = None + if extension in [".txt", ".vec"]: + package = io + elif extension == ".gz": + import gzip + + package = gzip + elif extension == ".bz2": + import bz2 + + package = bz2 + elif extension == ".xz": + import lzma + + package = lzma + + if package is None: + logger.warning( + 'The embeddings file has an unknown file extension "%s". 
' + "We will assume the file is an (uncompressed) text file", + extension, + ) + package = io + + self._handle = package.open( # type: ignore + main_file_local_path, "rt", encoding=encoding + ) + + # To use this with tqdm we'd like to know the number of tokens. It's possible that the + # first line of the embeddings file contains this: if it does, we want to start iteration + # from the 2nd line, otherwise we want to start from the 1st. + # Unfortunately, once we read the first line, we cannot move back the file iterator + # because the underlying file may be "not seekable"; we use itertools.chain instead. + first_line = next(self._handle) # this moves the iterator forward + self.num_tokens = EmbeddingsTextFile._get_num_tokens_from_first_line(first_line) + if self.num_tokens: + # the first line is a header line: start iterating from the 2nd line + self._iterator = self._handle + else: + # the first line is not a header line: start iterating from the 1st line + self._iterator = itertools.chain([first_line], self._handle) + + def _open_inside_zip(self, archive_path: str, member_path: Optional[str] = None) -> None: + cached_archive_path = cached_path(archive_path, cache_dir=self._cache_dir) + archive = zipfile.ZipFile(cached_archive_path, "r") + if member_path is None: + members_list = archive.namelist() + member_path = self._get_the_only_file_in_the_archive(members_list, archive_path) + member_path = cast(str, member_path) + member_file = cast(BinaryIO, archive.open(member_path, "r")) + self._handle = io.TextIOWrapper(member_file, encoding=self._encoding) + self._archive_handle = archive + + def _open_inside_tar(self, archive_path: str, member_path: Optional[str] = None) -> None: + cached_archive_path = cached_path(archive_path, cache_dir=self._cache_dir) + archive = tarfile.open(cached_archive_path, "r") + if member_path is None: + members_list = archive.getnames() + member_path = self._get_the_only_file_in_the_archive(members_list, archive_path) + member_path = cast(str, member_path) + member = archive.getmember(member_path) # raises exception if not present + member_file = cast(BinaryIO, archive.extractfile(member)) + self._handle = io.TextIOWrapper(member_file, encoding=self._encoding) + self._archive_handle = archive + + def read(self) -> str: + return "".join(self._iterator) + + def readline(self) -> str: + return next(self._iterator) + + def close(self) -> None: + self._handle.close() + if self._archive_handle: + self._archive_handle.close() + + def __enter__(self) -> "EmbeddingsTextFile": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def __iter__(self) -> "EmbeddingsTextFile": + return self + + def __next__(self) -> str: + return next(self._iterator) + + def __len__(self) -> Optional[int]: + if self.num_tokens: + return self.num_tokens + raise AttributeError( + "an object of type EmbeddingsTextFile implements `__len__` only if the underlying " + "text file declares the number of tokens (i.e. the number of lines following)" + "in the first line. That is not the case of this particular instance." + ) + + @staticmethod + def _get_the_only_file_in_the_archive(members_list: Sequence[str], archive_path: str) -> str: + if len(members_list) > 1: + raise ValueError( + "The archive %s contains multiple files, so you must select " + "one of the files inside providing a uri of the type: %s." 
+ % ( + archive_path, + format_embeddings_file_uri("path_or_url_to_archive", "path_inside_archive"), + ) + ) + return members_list[0] + + @staticmethod + def _get_num_tokens_from_first_line(line: str) -> Optional[int]: + """This function takes in input a string and if it contains 1 or 2 integers, it assumes the + largest one it the number of tokens. Returns None if the line doesn't match that pattern.""" + fields = line.split(" ") + if 1 <= len(fields) <= 2: + try: + int_fields = [int(x) for x in fields] + except ValueError: + return None + else: + num_tokens = max(int_fields) + logger.info( + "Recognized a header line in the embedding file with number of tokens: %d", + num_tokens, + ) + return num_tokens + return None diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py index 6e993b9..75df1b3 100644 --- a/combo/data/token_indexers/__init__.py +++ b/combo/data/token_indexers/__init__.py @@ -1,3 +1,6 @@ from .token_indexer import IndexedTokenList, TokenIndexer from .token_features_indexer import TokenFeatsIndexer from .single_id_token_indexer import SingleIdTokenIndexer +from .pretrained_transformer_indexer import PretrainedTransformerIndexer +from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer +from .pretrained_transformer_fixed_mismatched_indexer import PretrainedTransformerFixedMismatchedIndexer diff --git a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py index 6366a1e..af7bb0a 100644 --- a/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py +++ b/combo/data/token_indexers/pretrained_transformer_fixed_mismatched_indexer.py @@ -8,11 +8,9 @@ from typing import Optional, Dict, Any, List, Tuple from overrides import overrides from combo.data import Vocabulary -from combo.data.tokenizers import Token from combo.data.token_indexers import IndexedTokenList from combo.data.token_indexers.pretrained_transformer_indexer import PretrainedTransformerIndexer from combo.data.token_indexers.pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer -from combo.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatchedIndexer): @@ -63,64 +61,3 @@ class PretrainedTransformerFixedMismatchedIndexer(PretrainedTransformerMismatche } return self._matched_indexer._postprocess_output(output) - - -class PretrainedTransformerIndexer(PretrainedTransformerIndexer): - - def __init__( - self, - model_name: str, - namespace: str = "tags", - max_length: int = None, - tokenizer_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> None: - super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs) - self._namespace = namespace - self._allennlp_tokenizer = PretrainedTransformerTokenizer( - model_name, tokenizer_kwargs=tokenizer_kwargs - ) - self._tokenizer = self._allennlp_tokenizer.tokenizer - self._added_to_vocabulary = False - - self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens) - self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens) - - self._max_length = max_length - if self._max_length is not None: - num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1 - self._effective_max_length = ( # we need to take into account special tokens - self._max_length - 
num_added_tokens - ) - if self._effective_max_length <= 0: - raise ValueError( - "max_length needs to be greater than the number of special tokens inserted." - ) - - -class PretrainedTransformerTokenizer(PretrainedTransformerTokenizer): - - def _intra_word_tokenize( - self, string_tokens: List[str] - ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]: - tokens: List[Token] = [] - offsets: List[Optional[Tuple[int, int]]] = [] - for token_string in string_tokens: - wordpieces = self.tokenizer.encode_plus( - token_string, - add_special_tokens=False, - return_tensors=None, - return_offsets_mapping=False, - return_attention_mask=False, - ) - wp_ids = wordpieces["input_ids"] - - if len(wp_ids) > 0: - offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1)) - tokens.extend( - Token(text=wp_text, text_id=wp_id) - for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids)) - ) - else: - offsets.append(None) - return tokens, offsets diff --git a/combo/data/token_indexers/pretrained_transformer_indexer.py b/combo/data/token_indexers/pretrained_transformer_indexer.py index 1c5302c..6580f4b 100644 --- a/combo/data/token_indexers/pretrained_transformer_indexer.py +++ b/combo/data/token_indexers/pretrained_transformer_indexer.py @@ -27,10 +27,10 @@ class PretrainedTransformerIndexer(TokenIndexer): # Parameters model_name : `str` The name of the `transformers` model to use. - namespace : `str`, optional (default=`tags`) - We will add the tokens in the pytorch_transformer vocabulary to this vocabulary namespace. + _namespace : `str`, optional (default=`tags`) + We will add the tokens in the pytorch_transformer vocabulary to this vocabulary _namespace. We use a somewhat confusing default value of `tags` so that we do not add padding or UNK - tokens to this namespace, which would break on loading because we wouldn't find our default + tokens to this _namespace, which would break on loading because we wouldn't find our default OOV token. max_length : `int`, optional (default = `None`) If not None, split the document into segments of this many tokens (including special tokens) @@ -48,13 +48,17 @@ class PretrainedTransformerIndexer(TokenIndexer): model_name: str, namespace: str = "tags", max_length: int = None, + add_special_tokens: bool = False, tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: super().__init__(**kwargs) self._namespace = namespace + self._model_name = model_name + self._add_special_tokens = add_special_tokens + self._tokenizer_kwargs = tokenizer_kwargs self._allennlp_tokenizer = PretrainedTransformerTokenizer( - model_name, tokenizer_kwargs=tokenizer_kwargs + model_name, add_special_tokens=add_special_tokens, tokenizer_kwargs=tokenizer_kwargs ) self._tokenizer = self._allennlp_tokenizer.tokenizer self._added_to_vocabulary = False @@ -75,7 +79,7 @@ class PretrainedTransformerIndexer(TokenIndexer): def _add_encoding_to_vocabulary_if_needed(self, vocab: Vocabulary) -> None: """ - Copies tokens from ```transformers``` model's vocab to the specified namespace. + Copies tokens from ```transformers``` model's vocab to the specified _namespace. 
""" if self._added_to_vocabulary: return @@ -236,6 +240,35 @@ class PretrainedTransformerIndexer(TokenIndexer): tensor_dict[key] = tensor return tensor_dict + def _intra_word_tokenize( + self, string_tokens: List[str] + ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]: + """ + Adapted from COMBO + Authors: Mateusz Klimaszewski, Lukasz Pszenny + """ + tokens: List[Token] = [] + offsets: List[Optional[Tuple[int, int]]] = [] + for token_string in string_tokens: + wordpieces = self.tokenizer.encode_plus( + token_string, + add_special_tokens=False, + return_tensors=None, + return_offsets_mapping=False, + return_attention_mask=False, + ) + wp_ids = wordpieces["input_ids"] + + if len(wp_ids) > 0: + offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1)) + tokens.extend( + Token(text=wp_text, text_id=wp_id) + for wp_id, wp_text in zip(wp_ids, self._tokenizer.convert_ids_to_tokens(wp_ids)) + ) + else: + offsets.append(None) + return tokens, offsets + def __eq__(self, other): if isinstance(other, PretrainedTransformerIndexer): for key in self.__dict__: @@ -247,3 +280,12 @@ class PretrainedTransformerIndexer(TokenIndexer): return False return True return NotImplemented + + def _to_params(self) -> Dict[str, Any]: + return { + "namespace": self._namespace, + "model_name": self._model_name, + "add_special_tokens": self._add_special_tokens, + "max_length": self._max_length, + "tokenizer_kwargs": self._tokenizer_kwargs, + } diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py index f510517..6d660e7 100644 --- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py +++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py @@ -29,10 +29,10 @@ class PretrainedTransformerMismatchedIndexer(TokenIndexer): # Parameters model_name : `str` The name of the `transformers` model to use. - namespace : `str`, optional (default=`tags`) - We will add the tokens in the pytorch_transformer vocabulary to this vocabulary namespace. + _namespace : `str`, optional (default=`tags`) + We will add the tokens in the pytorch_transformer vocabulary to this vocabulary _namespace. We use a somewhat confusing default value of `tags` so that we do not add padding or UNK - tokens to this namespace, which would break on loading because we wouldn't find our default + tokens to this _namespace, which would break on loading because we wouldn't find our default OOV token. max_length : `int`, optional (default = `None`) If positive, split the document into segments of this many tokens (including special tokens) diff --git a/combo/data/token_indexers/single_id_token_indexer.py b/combo/data/token_indexers/single_id_token_indexer.py index 576ae8a..143c786 100644 --- a/combo/data/token_indexers/single_id_token_indexer.py +++ b/combo/data/token_indexers/single_id_token_indexer.py @@ -12,17 +12,18 @@ from combo.data.token_indexers import TokenIndexer, IndexedTokenList _DEFAULT_VALUE = "THIS IS A REALLY UNLIKELY VALUE THAT HAS TO BE A STRING" + class SingleIdTokenIndexer(TokenIndexer): """ This :class:`TokenIndexer` represents tokens as single integers. Registered as a `TokenIndexer` with name "single_id". # Parameters - namespace : `Optional[str]`, optional (default=`"tokens"`) - We will use this namespace in the :class:`Vocabulary` to map strings to indices. 
If you + _namespace : `Optional[str]`, optional (default=`"tokens"`) + We will use this _namespace in the :class:`Vocabulary` to map strings to indices. If you explicitly pass in `None` here, we will skip indexing and vocabulary lookups. This means that the `feature_name` you use must correspond to an integer value (like `text_id`, for instance, which gets set by some tokenizers, such as when using byte encoding). - lowercase_tokens : `bool`, optional (default=`False`) + _lowercase_tokens : `bool`, optional (default=`False`) If `True`, we will call `token.lower()` before getting an index for the token from the vocabulary. start_tokens : `List[str]`, optional (default=`None`) @@ -33,7 +34,7 @@ class SingleIdTokenIndexer(TokenIndexer): We will use the :class:`Token` attribute with this name as input. This is potentially useful, e.g., for using NER tags instead of (or in addition to) surface forms as your inputs (passing `ent_type_` here would do that). If you use a non-default value here, you almost - certainly want to also change the `namespace` parameter, and you might want to give a + certainly want to also change the `_namespace` parameter, and you might want to give a `default_value`. default_value : `str`, optional When you want to use a non-default `feature_name`, you sometimes want to have a default @@ -46,18 +47,18 @@ class SingleIdTokenIndexer(TokenIndexer): """ def __init__( - self, - namespace: Optional[str] = "tokens", - lowercase_tokens: bool = False, - start_tokens: List[str] = None, - end_tokens: List[str] = None, - feature_name: str = "text", - default_value: str = _DEFAULT_VALUE, - token_min_padding_length: int = 0, + self, + namespace: Optional[str] = "tokens", + lowercase_tokens: bool = False, + start_tokens: List[str] = None, + end_tokens: List[str] = None, + feature_name: str = "text", + default_value: str = _DEFAULT_VALUE, + token_min_padding_length: int = 0, ) -> None: super().__init__(token_min_padding_length) - self.namespace = namespace - self.lowercase_tokens = lowercase_tokens + self._namespace = namespace + self._lowercase_tokens = lowercase_tokens self._start_tokens = [Token(st) for st in (start_tokens or [])] self._end_tokens = [Token(et) for et in (end_tokens or [])] @@ -65,26 +66,26 @@ class SingleIdTokenIndexer(TokenIndexer): self._default_value = default_value def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]): - if self.namespace is not None: + if self._namespace is not None: text = self._get_feature_value(token) - if self.lowercase_tokens: + if self._lowercase_tokens: text = text.lower() - counter[self.namespace][text] += 1 + counter[self._namespace][text] += 1 def tokens_to_indices( - self, tokens: List[Token], vocabulary: Vocabulary + self, tokens: List[Token], vocabulary: Vocabulary ) -> Dict[str, List[int]]: indices: List[int] = [] for token in itertools.chain(self._start_tokens, tokens, self._end_tokens): text = self._get_feature_value(token) - if self.namespace is None: + if self._namespace is None: # We could have a check here that `text` is an int; not sure it's worth it. 
indices.append(text) # type: ignore else: - if self.lowercase_tokens: + if self._lowercase_tokens: text = text.lower() - indices.append(vocabulary.get_token_index(text, self.namespace)) + indices.append(vocabulary.get_token_index(text, self._namespace)) return {"tokens": indices} @@ -106,8 +107,8 @@ class SingleIdTokenIndexer(TokenIndexer): def _to_params(self) -> Dict[str, Any]: return { - "namespace": self.namespace, - "lowercase_tokens": self.lowercase_tokens, + "namespace": self._namespace, + "lowercase_tokens": self._lowercase_tokens, "start_tokens": [t.text for t in self._start_tokens], "end_tokens": [t.text for t in self._end_tokens], "feature_name": self._feature_name, diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py index 30197be..1227f2e 100644 --- a/combo/data/token_indexers/token_characters_indexer.py +++ b/combo/data/token_indexers/token_characters_indexer.py @@ -24,8 +24,8 @@ class TokenCharactersIndexer(TokenIndexer): # Parameters - namespace : `str`, optional (default=`token_characters`) - We will use this namespace in the :class:`Vocabulary` to map the characters in each token + _namespace : `str`, optional (default=`token_characters`) + We will use this _namespace in the :class:`Vocabulary` to map the characters in each token to indices. character_tokenizer : `CharacterTokenizer`, optional (default=`CharacterTokenizer()`) We use a :class:`CharacterTokenizer` to handle splitting tokens into characters, as it has diff --git a/combo/data/tokenizers/pretrained_transformer_tokenizer.py b/combo/data/tokenizers/pretrained_transformer_tokenizer.py index 093bfca..6672460 100644 --- a/combo/data/tokenizers/pretrained_transformer_tokenizer.py +++ b/combo/data/tokenizers/pretrained_transformer_tokenizer.py @@ -5,12 +5,8 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/tokenizers/pretraine import copy import dataclasses -import hashlib -import io import logging from typing import Any, Dict, List, Optional, Tuple, Iterable -import base58 -import dill from transformers import PreTrainedTokenizer, AutoTokenizer @@ -20,27 +16,6 @@ from combo.utils import sanitize_wordpiece logger = logging.getLogger(__name__) -def hash_object(o: Any) -> str: - """Returns a character hash code of arbitrary Python objects.""" - m = hashlib.blake2b() - with io.BytesIO() as buffer: - dill.dump(o, buffer) - m.update(buffer.getbuffer()) - return base58.b58encode(m.digest()).decode() - -def get_tokenizer(model_name: str, **kwargs) -> PreTrainedTokenizer: - cache_key = (model_name, hash_object(kwargs)) - - global _tokenizer_cache - tokenizer = _tokenizer_cache.get(cache_key, None) - if tokenizer is None: - tokenizer = AutoTokenizer.from_pretrained( - model_name, - **kwargs, - ) - _tokenizer_cache[cache_key] = tokenizer - return tokenizer - class PretrainedTransformerTokenizer(Tokenizer): """ @@ -72,13 +47,15 @@ class PretrainedTransformerTokenizer(Tokenizer): """ # noqa: E501 def __init__( - self, - model_name: str, - add_special_tokens: bool = True, - max_length: Optional[int] = None, - tokenizer_kwargs: Optional[Dict[str, Any]] = None, - verification_tokens: Optional[Tuple[str, str]] = None, + self, + model_name: str, + add_special_tokens: bool = True, + max_length: Optional[int] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + verification_tokens: Optional[Tuple[str, str]] = None, ) -> None: + from combo.common.cached_transformers import get_tokenizer + if tokenizer_kwargs is None: tokenizer_kwargs = {} else: 
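Note (not part of the patch; illustration only): the tokenizer's __init__ above now defers to the new combo.common.cached_transformers.get_tokenizer helper, which memoizes HuggingFace tokenizers in a module-level dict keyed by (model_name, hash_object(kwargs)). A minimal usage sketch, assuming a locally available HuggingFace model id ("bert-base-cased" is only an example):

    from combo.common.cached_transformers import get_tokenizer

    # First call loads the tokenizer via AutoTokenizer.from_pretrained and caches it.
    tok_a = get_tokenizer("bert-base-cased", use_fast=True)
    # Same model name and identical kwargs hash to the same cache key, so no reload happens.
    tok_b = get_tokenizer("bert-base-cased", use_fast=True)
    assert tok_a is tok_b
    # Different kwargs produce a different hash_object key and therefore a separate instance.
    tok_c = get_tokenizer("bert-base-cased", use_fast=False)
    assert tok_c is not tok_a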
@@ -110,12 +87,13 @@ class PretrainedTransformerTokenizer(Tokenizer): self._reverse_engineer_special_tokens(token_a, token_b, model_name, tokenizer_kwargs) def _reverse_engineer_special_tokens( - self, - token_a: str, - token_b: str, - model_name: str, - tokenizer_kwargs: Optional[Dict[str, Any]], + self, + token_a: str, + token_b: str, + model_name: str, + tokenizer_kwargs: Optional[Dict[str, Any]], ): + from combo.common.cached_transformers import get_tokenizer # storing the special tokens self.sequence_pair_start_tokens = [] self.sequence_pair_mid_tokens = [] @@ -157,15 +135,15 @@ class PretrainedTransformerTokenizer(Tokenizer): seen_dummy_a = False seen_dummy_b = False for token_id, token_type_id in zip( - dummy_output["input_ids"], dummy_output["token_type_ids"] + dummy_output["input_ids"], dummy_output["token_type_ids"] ): if token_id == dummy_a: if seen_dummy_a or seen_dummy_b: # seeing a twice or b before a raise ValueError("Cannot auto-determine the number of special tokens added.") seen_dummy_a = True assert ( - self.sequence_pair_first_token_type_id is None - or self.sequence_pair_first_token_type_id == token_type_id + self.sequence_pair_first_token_type_id is None + or self.sequence_pair_first_token_type_id == token_type_id ), "multiple different token type ids found for the first sequence" self.sequence_pair_first_token_type_id = token_type_id continue @@ -175,8 +153,8 @@ class PretrainedTransformerTokenizer(Tokenizer): raise ValueError("Cannot auto-determine the number of special tokens added.") seen_dummy_b = True assert ( - self.sequence_pair_second_token_type_id is None - or self.sequence_pair_second_token_type_id == token_type_id + self.sequence_pair_second_token_type_id is None + or self.sequence_pair_second_token_type_id == token_type_id ), "multiple different token type ids found for the second sequence" self.sequence_pair_second_token_type_id = token_type_id continue @@ -194,10 +172,10 @@ class PretrainedTransformerTokenizer(Tokenizer): self.sequence_pair_end_tokens.append(token) assert ( - len(self.sequence_pair_start_tokens) - + len(self.sequence_pair_mid_tokens) - + len(self.sequence_pair_end_tokens) - ) == self.tokenizer.num_special_tokens_to_add(pair=True) + len(self.sequence_pair_start_tokens) + + len(self.sequence_pair_mid_tokens) + + len(self.sequence_pair_end_tokens) + ) == self.tokenizer.num_special_tokens_to_add(pair=True) # Reverse-engineer the tokenizer for one sequence dummy_output = tokenizer_with_special_tokens.encode_plus( @@ -214,15 +192,15 @@ class PretrainedTransformerTokenizer(Tokenizer): seen_dummy_a = False for token_id, token_type_id in zip( - dummy_output["input_ids"], dummy_output["token_type_ids"] + dummy_output["input_ids"], dummy_output["token_type_ids"] ): if token_id == dummy_a: if seen_dummy_a: raise ValueError("Cannot auto-determine the number of special tokens added.") seen_dummy_a = True assert ( - self.single_sequence_token_type_id is None - or self.single_sequence_token_type_id == token_type_id + self.single_sequence_token_type_id is None + or self.single_sequence_token_type_id == token_type_id ), "multiple different token type ids found for the sequence" self.single_sequence_token_type_id = token_type_id continue @@ -238,8 +216,8 @@ class PretrainedTransformerTokenizer(Tokenizer): self.single_sequence_end_tokens.append(token) assert ( - len(self.single_sequence_start_tokens) + len(self.single_sequence_end_tokens) - ) == self.tokenizer.num_special_tokens_to_add(pair=False) + len(self.single_sequence_start_tokens) + 
len(self.single_sequence_end_tokens) + ) == self.tokenizer.num_special_tokens_to_add(pair=False) @staticmethod def tokenizer_lowercases(tokenizer: PreTrainedTokenizer) -> bool: @@ -284,7 +262,7 @@ class PretrainedTransformerTokenizer(Tokenizer): tokens = [] for token_id, token_type_id, special_token_mask, offsets in zip( - token_ids, token_type_ids, special_tokens_mask, token_offsets + token_ids, token_type_ids, special_tokens_mask, token_offsets ): # In `special_tokens_mask`, 1s indicate special tokens and 0s indicate regular tokens. # NOTE: in transformers v3.4.0 (and probably older versions) the docstring @@ -304,15 +282,14 @@ class PretrainedTransformerTokenizer(Tokenizer): text=self.tokenizer.convert_ids_to_tokens(token_id, skip_special_tokens=False), text_id=token_id, type_id=token_type_id, - idx=start, - idx_end=end, + idx=(start, end) ) ) return tokens def _estimate_character_indices( - self, text: str, token_ids: List[int] + self, text: str, token_ids: List[int] ) -> List[Optional[Tuple[int, int]]]: """ The huggingface tokenizers produce tokens that may or may not be slices from the @@ -373,7 +350,7 @@ class PretrainedTransformerTokenizer(Tokenizer): return token_offsets def _intra_word_tokenize( - self, string_tokens: List[str] + self, string_tokens: List[str] ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]: tokens: List[Token] = [] offsets: List[Optional[Tuple[int, int]]] = [] @@ -399,7 +376,7 @@ class PretrainedTransformerTokenizer(Tokenizer): @staticmethod def _increment_offsets( - offsets: Iterable[Optional[Tuple[int, int]]], increment: int + offsets: Iterable[Optional[Tuple[int, int]]], increment: int ) -> List[Optional[Tuple[int, int]]]: return [ None if offset is None else (offset[0] + increment, offset[1] + increment) @@ -407,7 +384,7 @@ class PretrainedTransformerTokenizer(Tokenizer): ] def intra_word_tokenize( - self, string_tokens: List[str] + self, string_tokens: List[str] ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]]]: """ Tokenizes each word into wordpieces separately and returns the wordpiece IDs. @@ -421,7 +398,7 @@ class PretrainedTransformerTokenizer(Tokenizer): return tokens, offsets def intra_word_tokenize_sentence_pair( - self, string_tokens_a: List[str], string_tokens_b: List[str] + self, string_tokens_a: List[str], string_tokens_b: List[str] ) -> Tuple[List[Token], List[Optional[Tuple[int, int]]], List[Optional[Tuple[int, int]]]]: """ Tokenizes each word into wordpieces separately and returns the wordpiece IDs. 
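Note (not part of the patch; illustration only): a sketch of how the intra_word_tokenize API documented above is meant to be used, for example by the mismatched indexer. The sentence and model id are assumptions, and the exact wordpiece split depends on the model's vocabulary. Each pre-tokenized word is run through the wordpiece tokenizer separately, special tokens are added around the whole sequence, and offsets[i] is the inclusive (start, end) span of word i inside the returned wordpiece list, or None if the word produced no wordpieces:

    from combo.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer

    tokenizer = PretrainedTransformerTokenizer("bert-base-cased")  # illustrative model id
    words = ["The", "quickest", "brown", "fox"]
    wordpieces, offsets = tokenizer.intra_word_tokenize(words)

    # wordpieces is a List[Token] such as [CLS] The quick ##est brown fox [SEP]
    # (Token objects carry text_id and type_id); offsets might then be
    # [(1, 1), (2, 3), (4, 4), (5, 5)], i.e. inclusive indices into `wordpieces`.
    for word, span in zip(words, offsets):
        if span is not None:
            start, end = span
            print(word, [t.text for t in wordpieces[start:end + 1]])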
@@ -434,9 +411,9 @@ class PretrainedTransformerTokenizer(Tokenizer): offsets_b = self._increment_offsets( offsets_b, ( - len(self.sequence_pair_start_tokens) - + len(tokens_a) - + len(self.sequence_pair_mid_tokens) + len(self.sequence_pair_start_tokens) + + len(tokens_a) + + len(self.sequence_pair_mid_tokens) ), ) tokens_a = self.add_special_tokens(tokens_a, tokens_b) @@ -445,7 +422,7 @@ class PretrainedTransformerTokenizer(Tokenizer): return tokens_a, offsets_a, offsets_b def add_special_tokens( - self, tokens1: List[Token], tokens2: Optional[List[Token]] = None + self, tokens1: List[Token], tokens2: Optional[List[Token]] = None ) -> List[Token]: def with_new_type_id(tokens: List[Token], type_id: int) -> List[Token]: return [dataclasses.replace(t, type_id=type_id) for t in tokens] @@ -455,17 +432,17 @@ class PretrainedTransformerTokenizer(Tokenizer): if tokens2 is None: return ( - self.single_sequence_start_tokens - + with_new_type_id(tokens1, self.single_sequence_token_type_id) # type: ignore - + self.single_sequence_end_tokens + self.single_sequence_start_tokens + + with_new_type_id(tokens1, self.single_sequence_token_type_id) # type: ignore + + self.single_sequence_end_tokens ) else: return ( - self.sequence_pair_start_tokens - + with_new_type_id(tokens1, self.sequence_pair_first_token_type_id) # type: ignore - + self.sequence_pair_mid_tokens - + with_new_type_id(tokens2, self.sequence_pair_second_token_type_id) # type: ignore - + self.sequence_pair_end_tokens + self.sequence_pair_start_tokens + + with_new_type_id(tokens1, self.sequence_pair_first_token_type_id) # type: ignore + + self.sequence_pair_mid_tokens + + with_new_type_id(tokens2, self.sequence_pair_second_token_type_id) # type: ignore + + self.sequence_pair_end_tokens ) def num_special_tokens_for_sequence(self) -> int: @@ -473,9 +450,9 @@ class PretrainedTransformerTokenizer(Tokenizer): def num_special_tokens_for_pair(self) -> int: return ( - len(self.sequence_pair_start_tokens) - + len(self.sequence_pair_mid_tokens) - + len(self.sequence_pair_end_tokens) + len(self.sequence_pair_start_tokens) + + len(self.sequence_pair_mid_tokens) + + len(self.sequence_pair_end_tokens) ) def _to_params(self) -> Dict[str, Any]: diff --git a/combo/data/tokenizers/token.py b/combo/data/tokenizers/token.py index d8f2590..b282a72 100644 --- a/combo/data/tokenizers/token.py +++ b/combo/data/tokenizers/token.py @@ -31,7 +31,8 @@ class Token: "subwords", "semrel", "embeddings", - "text_id" + "text_id", + "type_id" ] text: Optional[str] @@ -49,6 +50,7 @@ class Token: semrel: Optional[str] embeddings: Dict[str, List[float]] text_id: Optional[int] + type_id: Optional[int] def __init__(self, text: str = None, @@ -65,7 +67,8 @@ class Token: subwords: List[str] = None, semrel: str = None, embeddings: Dict[str, List[float]] = None, - text_id: int = None) -> None: + text_id: int = None, + type_id: int = None,) -> None: _assert_none_or_type(text, str) self.text = text @@ -89,6 +92,7 @@ class Token: self.embeddings = embeddings self.text_id = text_id + self.type_id = type_id def __str__(self): return self.text @@ -112,5 +116,6 @@ class Token: f"(subwords: {','.join(self.subwords)})" f"(semrel: {self.semrel}) " f"(embeddings: {self.embeddings}) " - f"(text_id: {self.text_id})" + f"(text_id: {self.text_id}) " + f"(type_id: {self.type_id}) " ) diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py index a2e53b8..345c666 100644 --- a/combo/data/vocabulary.py +++ b/combo/data/vocabulary.py @@ -9,6 +9,10 @@ from torchtext.vocab import vocab as 
torchtext_vocab import logging from filelock import FileLock +from transformers import PreTrainedTokenizer + +from combo.common import Tqdm +from combo.data.token_embedders.embedding import EmbeddingsTextFile logger = logging.Logger(__name__) @@ -21,7 +25,7 @@ DEFAULT_NAMESPACE = "tokens" def match_namespace(pattern: str, namespace: str): if not isinstance(pattern, str): - raise ValueError("Pattern and namespace must be string types, got %s and %s." % + raise ValueError("Pattern and _namespace must be string types, got %s and %s." % (type(pattern), type(namespace))) if pattern == namespace: return True @@ -30,6 +34,23 @@ def match_namespace(pattern: str, namespace: str): return False +def _read_pretrained_tokens(embeddings_file_uri: str) -> List[str]: + # Moving this import to the top breaks everything (cycling import, I guess) + + logger.info("Reading pretrained tokens from: %s", embeddings_file_uri) + tokens: List[str] = [] + with EmbeddingsTextFile(embeddings_file_uri) as embeddings_file: + for line_number, line in enumerate(Tqdm.tqdm(embeddings_file), start=1): + token_end = line.find(" ") + if token_end >= 0: + token = line[:token_end] + tokens.append(token) + else: + line_begin = line[:20] + "..." if len(line) > 20 else line + logger.warning("Skipping line number %d: %s", line_number, line_begin) + return tokens + + class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]): def __init__(self, non_padded_namespaces: Iterable[str], @@ -41,7 +62,7 @@ class _NamespaceDependentDefaultDict(defaultdict[str, TorchtextVocab]): super().__init__() def __missing__(self, namespace: str): - # Non-padded namespace + # Non-padded _namespace if any([match_namespace(npn, namespace) for npn in self._non_padded_namespaces]): value = torchtext_vocab(OrderedDict([])) else: @@ -128,7 +149,7 @@ class Vocabulary: def save_to_files(self, directory: str) -> None: """ Persist this Vocabulary to files, so it can be reloaded later. - Each namespace corresponds to one file. + Each _namespace corresponds to one file. Adapred from https://github.com/allenai/allennlp/blob/main/allennlp/data/vocabulary.py # Parameters @@ -149,7 +170,7 @@ class Vocabulary: print(namespace_str, file=namespace_file) for namespace, vocab in self._vocab.items(): - # Each namespace gets written to its own file, in index order. + # Each _namespace gets written to its own file, in index order. with codecs.open( os.path.join(directory, namespace + ".txt"), "w", "utf-8" ) as token_file: @@ -165,11 +186,11 @@ class Vocabulary: def add_token_to_namespace(self, token: str, namespace: str = DEFAULT_NAMESPACE): """ - Add the token if not present and return the index even if token was already in the namespace. + Add the token if not present and return the index even if token was already in the _namespace. :param token: token to be added - :param namespace: namespace to add the token to - :return: index of the token in the namespace + :param namespace: _namespace to add the token to + :return: index of the token in the _namespace """ if not isinstance(token, str): @@ -179,11 +200,11 @@ class Vocabulary: def add_tokens_to_namespace(self, tokens: List[str], namespace: str = DEFAULT_NAMESPACE): """ - Add the token if not present and return the index even if token was already in the namespace. + Add the token if not present and return the index even if token was already in the _namespace. 
:param tokens: tokens to be added - :param namespace: namespace to add the token to - :return: index of the token in the namespace + :param namespace: _namespace to add the token to + :return: index of the token in the _namespace """ if not isinstance(tokens, List): @@ -193,6 +214,24 @@ class Vocabulary: for token in tokens: self._vocab[namespace].append_token(token) + def add_transformer_vocab( + self, tokenizer: PreTrainedTokenizer, namespace: str = "tokens" + ) -> None: + """ + Copies tokens from a transformer tokenizer's vocab into the given namespace. + """ + try: + vocab_items = tokenizer.get_vocab().items() + except NotImplementedError: + vocab_items = ( + (tokenizer.convert_ids_to_tokens(idx), idx) for idx in range(tokenizer.vocab_size) + ) + + for word, idx in vocab_items: + self._vocab[namespace].insert_token(token=word, index=idx) + + self._non_padded_namespaces.add(namespace) + def get_index_to_token_vocabulary(self, namespace: str = DEFAULT_NAMESPACE) -> Dict[int, str]: if not isinstance(namespace, str): raise ValueError( diff --git a/combo/utils/__init__.py b/combo/utils/__init__.py index 791484e..b36a4d9 100644 --- a/combo/utils/__init__.py +++ b/combo/utils/__init__.py @@ -1,3 +1,4 @@ from .checks import * from .sequence import * -from .exceptions import * \ No newline at end of file +from .exceptions import * +from .typing import * \ No newline at end of file diff --git a/combo/utils/file_utils.py b/combo/utils/file_utils.py index 96fe63b..fc41d46 100644 --- a/combo/utils/file_utils.py +++ b/combo/utils/file_utils.py @@ -13,6 +13,12 @@ CACHE_ROOT = Path(os.getenv("COMBO_CACHE_ROOT", Path.home() / ".combo")) CACHE_DIRECTORY = str(CACHE_ROOT / "cache") +def get_file_extension(path: str, dot=True, lower: bool = True): + ext = os.path.splitext(path)[1] + ext = ext if dot else ext[1:] + return ext.lower() if lower else ext + + def cached_path( url_or_filename: Union[str, PathLike], cache_dir: Union[str, Path] = None, diff --git a/combo/utils/typing.py b/combo/utils/typing.py new file mode 100644 index 0000000..ba566b0 --- /dev/null +++ b/combo/utils/typing.py @@ -0,0 +1,2 @@ +def cast(typ, val): + return val -- GitLab
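Note (appended for illustration; not part of the patch): a minimal end-to-end sketch of how the pieces added here are expected to fit together. The model id is an example, Vocabulary() is assumed to be default-constructible, and the output of tokens_to_indices follows the AllenNLP-style indexer this code is adapted from:

    from combo.data.token_indexers import PretrainedTransformerIndexer
    from combo.data.tokenizers.pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
    from combo.data.vocabulary import Vocabulary

    model_name = "bert-base-cased"  # illustrative only

    tokenizer = PretrainedTransformerTokenizer(model_name)
    indexer = PretrainedTransformerIndexer(model_name, namespace="tags")

    vocab = Vocabulary()
    # Copy the transformer's wordpiece vocabulary into the "tags" namespace up front;
    # indexing would otherwise trigger this lazily through
    # _add_encoding_to_vocabulary_if_needed.
    vocab.add_transformer_vocab(tokenizer.tokenizer, namespace="tags")

    tokens = tokenizer.tokenize("COMBO parses sentences.")
    indexed = indexer.tokens_to_indices(tokens, vocab)
    print(indexed)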