Syntactic Tools / combo · Commits · 62ff9bbf

Commit 62ff9bbf, authored 2 years ago by Maja Jabłońska

    Add TextField from AllenNLP

parent ebd2bd32 · 1 merge request: !46 "Merge COMBO 3.0 into master"

Showing 1 changed file with 220 additions and 0 deletions:

combo/data/fields/text_field.py · new file (0 → 100644) · +220 −0
"""
Adapted from AllenNLP.
https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/text_field.py
A `TextField` represents a string of text, the kind that you might want to represent with
standard word vectors, or pass through an LSTM.
"""
from
collections
import
defaultdict
from
copy
import
deepcopy
from
typing
import
Dict
,
List
,
Optional
,
Iterator
import
textwrap
from
spacy.tokens
import
Token
as
SpacyToken
import
torch
# There are two levels of dictionaries here: the top level is for the *key*, which aligns
# TokenIndexers with their corresponding TokenEmbedders. The bottom level is for the *objects*
# produced by a given TokenIndexer, which will be input to a particular TokenEmbedder's forward()
# method. We label these as tensors, because that's what they typically are, though they could in
# reality have arbitrary type.
from
combo.data
import
Vocabulary
from
combo.data.fields.sequence_field
import
SequenceField
from
combo.data.token_indexers
import
TokenIndexer
,
IndexedTokenList
from
combo.data.tokenizers
import
TokenizerToken
from
combo.utils
import
ConfigurationError
TextFieldTensors
=
Dict
[
str
,
Dict
[
str
,
torch
.
Tensor
]]
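
# Illustrative sketch (not part of the original file): for a field indexed by a
# "token" SingleIdTokenIndexer and a "char" TokenCharactersIndexer, a
# TextFieldTensors value would look like
#
#     {
#         "token": {"tokens": tensor of shape (batch_size, num_tokens)},
#         "char": {"token_characters": tensor of shape (batch_size, num_tokens, num_chars)},
#     }
#
# where the outer keys name the indexers and the inner keys name each tensor an
# indexer produces for its embedder's forward() method. The names "token" and "char"
# are hypothetical; actual keys depend on the configured TokenIndexers.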


def batch_tensor_dicts(
    tensor_dicts: List[Dict[str, torch.Tensor]], remove_trailing_dimension: bool = False
) -> Dict[str, torch.Tensor]:
    """
    Takes a list of tensor dictionaries, where each dictionary is assumed to have matching keys,
    and returns a single dictionary with all tensors with the same key batched together.

    # Parameters

    tensor_dicts : `List[Dict[str, torch.Tensor]]`
        The list of tensor dictionaries to batch.
    remove_trailing_dimension : `bool`
        If `True`, we will check for a trailing dimension of size 1 on the tensors that are being
        batched, and remove it if we find it.
    """
    key_to_tensors: Dict[str, List[torch.Tensor]] = defaultdict(list)
    for tensor_dict in tensor_dicts:
        for key, tensor in tensor_dict.items():
            key_to_tensors[key].append(tensor)
    batched_tensors = {}
    for key, tensor_list in key_to_tensors.items():
        batched_tensor = torch.stack(tensor_list)
        if remove_trailing_dimension and all(tensor.size(-1) == 1 for tensor in tensor_list):
            batched_tensor = batched_tensor.squeeze(-1)
        batched_tensors[key] = batched_tensor
    return batched_tensors
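
# A minimal usage sketch (illustrative values, not from the original file):
#
#     batch_tensor_dicts([
#         {"tokens": torch.tensor([1, 2])},
#         {"tokens": torch.tensor([3, 4])},
#     ])
#     # -> {"tokens": tensor([[1, 2], [3, 4]])}
#
# Every dictionary must share the same keys, and the tensors under a given key must
# have matching shapes, since torch.stack() adds a new leading batch dimension.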


class TextField(SequenceField[TextFieldTensors]):
    """
    This `Field` represents a list of string tokens. Before constructing this object, you need
    to tokenize raw strings using a :class:`~allennlp.data.tokenizers.tokenizer.Tokenizer`.

    Because string tokens can be represented as indexed arrays in a number of ways, we also take a
    dictionary of :class:`~allennlp.data.token_indexers.token_indexer.TokenIndexer`
    objects that will be used to convert the tokens into indices.
    Each `TokenIndexer` could represent each token as a single ID, or a list of character IDs, or
    something else.

    This field will get converted into a dictionary of arrays, one for each `TokenIndexer`. A
    `SingleIdTokenIndexer` produces an array of shape (num_tokens,), while a
    `TokenCharactersIndexer` produces an array of shape (num_tokens, num_characters).
    """

    __slots__ = ["tokens", "_token_indexers", "_indexed_tokens"]

    def __init__(
        self,
        tokens: List[TokenizerToken],
        token_indexers: Optional[Dict[str, TokenIndexer]] = None,
    ) -> None:
        self.tokens = tokens
        self._token_indexers = token_indexers
        self._indexed_tokens: Optional[Dict[str, IndexedTokenList]] = None

        if not all(isinstance(x, (TokenizerToken, SpacyToken)) for x in tokens):
            raise ConfigurationError(
                "TextFields must be passed Tokens. "
                "Found: {} with types {}.".format(tokens, [type(x) for x in tokens])
            )
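
    # A hedged construction sketch (variable names are hypothetical; the actual
    # tokenizer and indexer setup depend on the surrounding COMBO pipeline):
    #
    #     tokens = [TokenizerToken("The"), TokenizerToken("dog"), TokenizerToken("barks")]
    #     field = TextField(tokens, {"tokens": single_id_indexer})
    #
    # Passing anything other than TokenizerToken/SpacyToken instances raises the
    # ConfigurationError above.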

    @property
    def token_indexers(self) -> Dict[str, TokenIndexer]:
        if self._token_indexers is None:
            raise ValueError(
                "TextField's token_indexers have not been set.\n"
                "Did you forget to call DatasetReader.apply_token_indexers(instance) "
                "on your instance?\n"
                "If apply_token_indexers() is being called but "
                "you're still seeing this error, it may not be implemented correctly."
            )
        return self._token_indexers

    @token_indexers.setter
    def token_indexers(self, token_indexers: Dict[str, TokenIndexer]) -> None:
        self._token_indexers = token_indexers

    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        for indexer in self.token_indexers.values():
            for token in self.tokens:
                indexer.count_vocab_items(token, counter)

    def index(self, vocab: Vocabulary):
        self._indexed_tokens = {}
        for indexer_name, indexer in self.token_indexers.items():
            self._indexed_tokens[indexer_name] = indexer.tokens_to_indices(self.tokens, vocab)

    def get_padding_lengths(self) -> Dict[str, int]:
        """
        The `TextField` has a list of `Tokens`, and each `Token` gets converted into arrays by
        (potentially) several `TokenIndexers`. This method gets the max length (over tokens)
        associated with each of these arrays.
        """
        if self._indexed_tokens is None:
            raise ConfigurationError(
                "You must call .index(vocabulary) on a field before determining padding lengths."
            )

        padding_lengths = {}
        for indexer_name, indexer in self.token_indexers.items():
            indexer_lengths = indexer.get_padding_lengths(self._indexed_tokens[indexer_name])
            for key, length in indexer_lengths.items():
                padding_lengths[f"{indexer_name}___{key}"] = length
        return padding_lengths
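
    # For example (illustrative keys, not from the original file): with a "token"
    # indexer whose padding lengths are {"tokens_length": 6}, this returns
    # {"token___tokens_length": 6}. The "___" separator is what as_tensor() later
    # splits on to route each length back to the indexer that produced it.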

    def sequence_length(self) -> int:
        return len(self.tokens)

    def as_tensor(self, padding_lengths: Dict[str, int]) -> TextFieldTensors:
        if self._indexed_tokens is None:
            raise ConfigurationError(
                "You must call .index(vocabulary) on a field before calling .as_tensor()"
            )

        tensors = {}

        indexer_lengths: Dict[str, Dict[str, int]] = defaultdict(dict)
        for key, value in padding_lengths.items():
            # We want this to crash if the split fails. Should never happen, so I'm not
            # putting in a check, but if you fail on this line, open a github issue.
            indexer_name, padding_key = key.split("___")
            indexer_lengths[indexer_name][padding_key] = value

        for indexer_name, indexer in self.token_indexers.items():
            tensors[indexer_name] = indexer.as_padded_tensor_dict(
                self._indexed_tokens[indexer_name], indexer_lengths[indexer_name]
            )
        return tensors

    def empty_field(self):
        text_field = TextField([], self._token_indexers)
        text_field._indexed_tokens = {}
        if self._token_indexers is not None:
            for indexer_name, indexer in self.token_indexers.items():
                text_field._indexed_tokens[indexer_name] = indexer.get_empty_token_list()
        return text_field

    def batch_tensors(self, tensor_list: List[TextFieldTensors]) -> TextFieldTensors:
        # This is creating a dict of {token_indexer_name: {token_indexer_outputs: batched_tensor}}
        # for each token indexer used to index this field.
        indexer_lists: Dict[str, List[Dict[str, torch.Tensor]]] = defaultdict(list)
        for tensor_dict in tensor_list:
            for indexer_name, indexer_output in tensor_dict.items():
                indexer_lists[indexer_name].append(indexer_output)
        batched_tensors = {
            # NOTE(mattg): if an indexer has its own nested structure, rather than one tensor per
            # argument, then this will break. If that ever happens, we should move this to an
            # `indexer.batch_tensors` method, with this logic as the default implementation in the
            # base class.
            indexer_name: batch_tensor_dicts(indexer_outputs)
            for indexer_name, indexer_outputs in indexer_lists.items()
        }
        return batched_tensors
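
    # Sketch of the regrouping (hypothetical names): given per-instance outputs
    #     [{"token": {"tokens": t1}}, {"token": {"tokens": t2}}]
    # this collects {"token": [{"tokens": t1}, {"tokens": t2}]} and delegates to
    # batch_tensor_dicts(), yielding {"token": {"tokens": torch.stack([t1, t2])}}.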

    def __str__(self) -> str:
        # Double tab to indent under the header.
        formatted_text = "".join(
            "\t\t" + text + "\n" for text in textwrap.wrap(repr(self.tokens), 100)
        )
        if self._token_indexers is not None:
            indexers = {
                name: indexer.__class__.__name__
                for name, indexer in self._token_indexers.items()
            }
            return (
                f"TextField of length {self.sequence_length()} with "
                f"text: \n{formatted_text}\t\tand TokenIndexers : {indexers}"
            )
        else:
            return f"TextField of length {self.sequence_length()} with text: \n{formatted_text}"

    # Sequence[Token] methods
    def __iter__(self) -> Iterator[TokenizerToken]:
        return iter(self.tokens)

    def __getitem__(self, idx: int) -> TokenizerToken:
        return self.tokens[idx]

    def __len__(self) -> int:
        return len(self.tokens)

    def duplicate(self):
        """
        Overrides the behavior of `duplicate` so that `self._token_indexers` won't
        actually be deep-copied.

        Not only would it be extremely inefficient to deep-copy the token indexers,
        but it also fails in many cases since some tokenizers (like those used in
        the 'transformers' lib) cannot actually be deep-copied.
        """
        if self._token_indexers is not None:
            new = TextField(deepcopy(self.tokens), {k: v for k, v in self._token_indexers.items()})
        else:
            new = TextField(deepcopy(self.tokens))
        new._indexed_tokens = deepcopy(self._indexed_tokens)
        return new

    def human_readable_repr(self) -> List[str]:
        return [str(t) for t in self.tokens]
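
# End-to-end usage sketch (hypothetical variable names; the actual indexer and
# vocabulary construction depend on the surrounding COMBO pipeline):
#
#     field = TextField(tokens, {"tokens": indexer})
#     counter: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
#     field.count_vocab_items(counter)        # 1. accumulate vocabulary counts
#     field.index(vocab)                      # 2. convert tokens to indices
#     lengths = field.get_padding_lengths()   # 3. ask how much padding is needed
#     tensors = field.as_tensor(lengths)      # 4. produce TextFieldTensors
#
# The steps must run in this order: index() must precede get_padding_lengths() and
# as_tensor(), both of which otherwise raise ConfigurationError.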