Syntactic Tools / combo · Commits

Commit 1380bf60
Authored 2 years ago by Maja Jabłońska

Add LemmatizerModel from COMBO

Parent: a9e564c1
Merge request: !46 "Merge COMBO 3.0 into master"

1 changed file: combo/models/lemma.py (+104, −2)
from typing import Optional, Dict, List, Union

import torch
import torch.nn as nn

from combo import data
from combo.models import dilated_cnn, base, utils
from combo.models.base import Predictor, TimeDistributed
from combo.models.combo_nn import Activation
from combo.utils import ConfigurationError


class LemmatizerModel(Predictor):
    """Lemmatizer model."""

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 dilated_cnn_encoder: dilated_cnn.DilatedCnnEncoder,
                 input_projection_layer: base.Linear):
        super().__init__()
        self.char_embed = nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        self.dilated_cnn_encoder = TimeDistributed(dilated_cnn_encoder)
        self.input_projection_layer = input_projection_layer

    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                ) -> Dict[str, torch.Tensor]:
        # x is a pair: per-token encoder states and padded character ids.
        encoder_emb, chars = x

        encoder_emb = self.input_projection_layer(encoder_emb)
        char_embeddings = self.char_embed(chars)

        BATCH_SIZE, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size()
        # Broadcast each token-level encoder state to every character position
        # of that token, then concatenate with the character embeddings.
        encoder_emb = encoder_emb.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH, 1)

        x = torch.cat((char_embeddings, encoder_emb), dim=-1).transpose(2, 3)
        x = self.dilated_cnn_encoder(x).transpose(2, 3)
        output = {
            "prediction": x.argmax(-1),
            # Unnormalized scores over the lemma character vocabulary
            # (no softmax is applied here).
            "probability": x,
        }

        if labels is not None:
            if mask is None:
                mask = encoder_emb.new_ones(encoder_emb.size()[:-2])
            if sample_weights is None:
                sample_weights = labels.new_ones(BATCH_SIZE)
            mask = mask.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH).bool()
            output["loss"] = self._loss(x, labels, mask, sample_weights)

        return output

    @staticmethod
    def _loss(pred: torch.Tensor,
              true: torch.Tensor,
              mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_CLASSES = pred.size()
        pred = pred.reshape(-1, CHAR_CLASSES)
        true = true.reshape(-1)
        # Note: the mask argument is shadowed here; positions whose gold
        # character id is 0 (padding) are excluded instead.
        mask = true.gt(0)
        loss = utils.masked_cross_entropy(pred, true, mask)
        loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
        valid_positions = mask.sum()
        return loss.sum() / valid_positions

    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   char_vocab_namespace: str,
                   lemma_vocab_namespace: str,
                   embedding_dim: int,
                   input_projection_layer: base.Linear,
                   filters: List[int],
                   kernel_size: List[int],
                   stride: List[int],
                   padding: List[int],
                   dilation: List[int],
                   activations: List[Activation],
                   ):
        assert char_vocab_namespace in vocab.get_namespaces()
        assert lemma_vocab_namespace in vocab.get_namespaces()

        if len(filters) + 1 != len(kernel_size):
            raise ConfigurationError(
                f"len(filters) ({len(filters):d}) + 1 != "
                f"len(kernel_size) ({len(kernel_size):d})"
            )

        # The final convolution projects onto the lemma character vocabulary.
        filters = filters + [vocab.get_vocab_size(lemma_vocab_namespace)]

        dilated_cnn_encoder = dilated_cnn.DilatedCnnEncoder(
            input_dim=embedding_dim + input_projection_layer.get_output_dim(),
            filters=filters,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            activations=activations,
        )
        return cls(num_embeddings=vocab.get_vocab_size(char_vocab_namespace),
                   embedding_dim=embedding_dim,
                   dilated_cnn_encoder=dilated_cnn_encoder,
                   input_projection_layer=input_projection_layer)
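For context, here is a minimal sketch of how this class might be wired up via from_vocab(). It is a sketch with placeholders, not directly runnable: the vocabulary namespaces, the filter/kernel settings, and the construction of the vocab and base.Linear instances are illustrative assumptions, not values taken from this commit.

# Hypothetical usage sketch; vocab and input_projection_layer are built
# elsewhere in the pipeline and their constructors are not shown in this diff.
from combo import data
from combo.models import base
from combo.models.lemma import LemmatizerModel

vocab: data.Vocabulary = ...                 # must contain both namespaces below
input_projection_layer: base.Linear = ...    # projects encoder states; exposes get_output_dim()

lemmatizer = LemmatizerModel.from_vocab(
    vocab=vocab,
    char_vocab_namespace="token_characters",   # assumed namespace name
    lemma_vocab_namespace="lemma_characters",  # assumed namespace name
    embedding_dim=64,
    input_projection_layer=input_projection_layer,
    filters=[256, 256],        # from_vocab appends the lemma-vocab size itself
    kernel_size=[3, 3, 3],     # len(filters) + 1 must equal len(kernel_size)
    stride=[1, 1, 1],
    padding=[1, 1, 1],
    dilation=[1, 1, 1],
    activations=...,           # List[Activation], one per convolution
)

# forward() expects a pair of tensors:
#   encoder_emb: (BATCH, SENTENCE_LEN, ENCODER_DIM)  per-token encoder states
#   chars:       (BATCH, SENTENCE_LEN, MAX_WORD_LEN) padded character ids
# It returns "prediction" of shape (BATCH, SENTENCE_LEN, MAX_WORD_LEN),
# plus a scalar "loss" entry when labels are supplied.
output = lemmatizer((encoder_emb, chars), labels=lemma_char_ids)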
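The normalization performed by _loss can also be seen in isolation with plain torch. This runnable sketch assumes utils.masked_cross_entropy behaves like per-position cross entropy with masked positions zeroed out, which matches how its output is reshaped and averaged above; that assumption is not confirmed by this diff.

import torch
import torch.nn.functional as F

def masked_cross_entropy_sketch(pred: torch.Tensor,
                                true: torch.Tensor,
                                mask: torch.Tensor) -> torch.Tensor:
    # Assumed stand-in for combo.models.utils.masked_cross_entropy:
    # per-position cross entropy, zeroed where mask is False.
    return F.cross_entropy(pred, true, reduction="none") * mask

BATCH, SENT, WORD, CLASSES = 2, 4, 8, 30
pred = torch.randn(BATCH, SENT, WORD, CLASSES)          # raw scores from the CNN
true = torch.randint(0, CLASSES, (BATCH, SENT, WORD))   # gold character ids, 0 = padding
sample_weights = torch.ones(BATCH)                      # per-sentence weights

pred = pred.reshape(-1, CLASSES)
true = true.reshape(-1)
mask = true.gt(0)                                       # drop padding positions
loss = masked_cross_entropy_sketch(pred, true, mask)
loss = loss.reshape(BATCH, -1) * sample_weights.unsqueeze(-1)
print(loss.sum() / mask.sum())  # mean loss over valid character positions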