Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
combo
Manage
Activity
Members
Labels
Plan
Issues
20
Issue boards
Milestones
Wiki
Redmine
Code
Merge requests
2
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Syntactic Tools
combo
Commits
557b09e2
Commit
557b09e2
authored
1 year ago
by
Maja Jablonska
Browse files
Options
Downloads
Patches
Plain Diff
Remove debug prints
parent
3c877323
1 merge request
!46
Merge COMBO 3.0 into master
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
combo/nn/utils.py
+1
-6
1 addition, 6 deletions
combo/nn/utils.py
with
1 addition
and
6 deletions
combo/nn/utils.py
+
1
−
6
View file @
557b09e2
...
@@ -14,12 +14,7 @@ StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"]
...
@@ -14,12 +14,7 @@ StateDictType = Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"]
def
masked_cross_entropy
(
pred
:
torch
.
Tensor
,
true
:
torch
.
Tensor
,
mask
:
torch
.
BoolTensor
)
->
torch
.
Tensor
:
def
masked_cross_entropy
(
pred
:
torch
.
Tensor
,
true
:
torch
.
Tensor
,
mask
:
torch
.
BoolTensor
)
->
torch
.
Tensor
:
pred
=
pred
+
(
mask
.
float
().
unsqueeze
(
-
1
)
+
1e-45
).
log
()
pred
=
pred
+
(
mask
.
float
().
unsqueeze
(
-
1
)
+
1e-45
).
log
()
try
:
return
F
.
cross_entropy
(
pred
,
true
,
reduction
=
"
none
"
)
*
mask
return
F
.
cross_entropy
(
pred
,
true
,
reduction
=
"
none
"
)
*
mask
except
Exception
as
e
:
print
(
"
pred shape
"
,
pred
.
shape
,
"
true shape
"
,
true
.
shape
,
"
mask shape
"
,
mask
.
shape
)
print
(
F
.
cross_entropy
(
pred
,
true
,
reduction
=
"
none
"
).
shape
)
raise
e
def
tiny_value_of_dtype
(
dtype
:
torch
.
dtype
):
def
tiny_value_of_dtype
(
dtype
:
torch
.
dtype
):
"""
"""
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment