Alexander Fuchs / tensorflow2_trainer_template · Commits

Commit 1a47e306
Authored Aug 07, 2020 by Alexander Fuchs

Trainer now allows for generic loss functions

parent 8ff8923d
Changes: 3 files
src/models/loss_functions/separate_classifier_wgan_loss_fn.py  (new file, mode 100644)
import tensorflow as tf


class LossFunction(object):
    """Generic loss function combining a classifier with a WGAN
    generator/discriminator pair (separate-classifier setup)."""

    def __init__(self, classifier, generator, discriminator):
        self.classifier = classifier
        self.generator = generator
        self.discriminator = discriminator

    @tf.function
    def compute_loss(self, x, y, training=True):
        # x[0]: real input features, x[1]: generator input
        logits_classifier = self.classifier(x[0], training=training)
        fake_features = self.generator(x[1], training)
        true = self.discriminator(x[0], training)
        false = self.discriminator(fake_features, training)
        # Cross entropy losses
        class_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                y[0], logits_classifier, axis=-1))
        # Reconstruction (MSE) term for the generator
        gen_loss = tf.reduce_mean((fake_features - x[1])**2)
        # Wasserstein losses
        gen_loss += tf.reduce_mean(false)
        discr_loss = tf.reduce_mean(true) - tf.reduce_mean(false)
        # Weight decay: sum any regularization losses the models collected
        if len(self.classifier.losses) > 0:
            weight_decay_loss = tf.add_n(self.classifier.losses)
        else:
            weight_decay_loss = 0.0
        if len(self.generator.losses) > 0:
            weight_decay_loss += tf.add_n(self.generator.losses)
        predictions = tf.nn.softmax(logits_classifier, axis=-1)
        total_loss = class_loss
        total_loss += gen_loss
        total_loss += discr_loss
        total_loss += weight_decay_loss
        return (class_loss, gen_loss, discr_loss, weight_decay_loss,
                total_loss), (predictions, fake_features)
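Taken together, the new class bundles the classifier cross-entropy, the generator reconstruction and Wasserstein terms, the discriminator loss, and any collected weight-decay losses into a single call. A minimal usage sketch, assuming three toy Keras models (the Sequential stand-ins, shapes, and data below are illustrative, not from this repository):

    import tensorflow as tf

    # Hypothetical stand-in models; the real project uses its own
    # Classifier, Generator, and Discriminator classes.
    classifier = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    generator = tf.keras.Sequential([tf.keras.layers.Dense(64)])
    discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1)])

    loss_fn = LossFunction(classifier, generator, discriminator)

    x_real = tf.random.normal([4, 64])       # real feature batch -> x[0]
    x_gen = tf.random.normal([4, 64])        # generator input -> x[1]
    y_labels = tf.one_hot([0, 1, 2, 3], 10)  # one-hot targets -> y[0]

    losses, outputs = loss_fn.compute_loss((x_real, x_gen), (y_labels,),
                                           training=True)
    class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss = losses
    predictions, fake_features = outputs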
src/scripts/birdsong_simple_main.py
@@ -12,6 +12,7 @@ from models.classifier import Classifier
 from models.res_block import ResBlockBasicLayer
 from models.discriminator import Discriminator
 from models.generator import Generator
+from models.loss_functions.separate_classifier_wgan_loss_fn import LossFunction
 import tensorflow as tf
 # pylint: disable=g-bad-import-order
 BINS = 1025
@@ -188,6 +189,7 @@ def main(argv):
                            val_data_gen,
                            None,
                            epochs,
+                           LossFunction,
                            classifier=classifier_model,
                            generator=generator_model,
                            discriminator=discriminator_model,
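Note that the script passes the LossFunction class itself, not an instance: ModelTrainer builds the instance once it owns the three sub-models, as the trainer.py hunk below shows. A hedged sketch of the resulting wiring (the leading training-data argument and the exact positional signature are assumptions reconstructed from the diff context):

    # Hypothetical call shape, reconstructed from the hunk above:
    trainer = ModelTrainer(train_data_gen,   # training data (assumed)
                           val_data_gen,     # validation data
                           None,             # test data generator
                           epochs,
                           LossFunction,     # class, not an instance
                           classifier=classifier_model,
                           generator=generator_model,
                           discriminator=discriminator_model)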
src/utils/trainer.py
@@ -16,6 +16,7 @@ class ModelTrainer():
                  validation_data_generator,
                  test_data_generator,
                  epochs,
+                 loss_fn,
                  classifier,
                  generator=None,
                  discriminator=None,
@@ -64,6 +65,8 @@ class ModelTrainer():
         self.n_vars_classifier = len(self.classifier.trainable_variables)
         self.n_vars_generator = len(self.generator.trainable_variables)
         self.n_vars_discriminator = len(self.discriminator.trainable_variables)
+        # Set up loss function
+        self.loss_fn = loss_fn(self.classifier, self.generator, self.discriminator)
         # Set up optimizers
         if optimizer_classifier == tf.keras.optimizers.SGD:
             self.optimizer_classifier = optimizer_classifier(self.learning_rate_classifier, 0.9, True)
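One detail in the surrounding context: the trainer compares the optimizer argument against the class tf.keras.optimizers.SGD and, on a match, constructs it with positional arguments. In the Keras signature those positions are learning_rate, momentum, and nesterov, so the call enables Nesterov momentum of 0.9. A standalone illustration (the 0.01 learning rate is a placeholder; the trainer uses self.learning_rate_classifier):

    import tensorflow as tf

    # Equivalent to the positional call in the hunk above:
    opt = tf.keras.optimizers.SGD(0.01, 0.9, True)
    # i.e. tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)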
@@ -170,44 +173,8 @@ class ModelTrainer():
         logits_classifier = self.classifier(x, training=training)
         return tf.nn.softmax(logits_classifier, axis=-1)

     @tf.function
     def compute_loss(self, x, y, training=True):
-        logits_classifier = self.classifier(x[0], training=training)
-        fake_features = self.generator(x[1], training)
-        true = self.discriminator(x[0], training)
-        false = self.discriminator(fake_features, training)
-        # Cross entropy losses
-        class_loss = tf.reduce_mean(
-            tf.nn.softmax_cross_entropy_with_logits(
-                y[0], logits_classifier, axis=-1))
-        gen_loss = tf.reduce_mean((fake_features - x[1])**2)
-        # Wasserstein losses
-        gen_loss += tf.reduce_mean(false)
-        discr_loss = tf.reduce_mean(true) - tf.reduce_mean(false)
-        if len(self.classifier.losses) > 0:
-            weight_decay_loss = tf.add_n(self.classifier.losses)
-        else:
-            weight_decay_loss = 0.0
-        if len(self.generator.losses) > 0:
-            weight_decay_loss += tf.add_n(self.generator.losses)
-        predictions = tf.nn.softmax(logits_classifier, axis=-1)
-        total_loss = class_loss
-        total_loss += gen_loss
-        total_loss += discr_loss
-        total_loss += weight_decay_loss
-        return (class_loss, gen_loss, discr_loss, weight_decay_loss,
-                total_loss), (predictions, fake_features)
+        return self.loss_fn.compute_loss(x, y, training)

     @tf.function
     def compute_gradients(self, x, y):
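This delegation is what makes the loss generic: any object whose constructor takes the three sub-models and which exposes compute_loss(x, y, training) returning the same (losses, outputs) pair can be dropped in. A hedged sketch of an alternative, classifier-only loss class honoring that implicit interface (this class is illustrative and not part of the repository):

    import tensorflow as tf

    class ClassifierOnlyLossFunction(object):
        """Illustrative drop-in: same constructor and return structure."""

        def __init__(self, classifier, generator, discriminator):
            self.classifier = classifier
            self.generator = generator          # unused in this variant
            self.discriminator = discriminator  # unused in this variant

        @tf.function
        def compute_loss(self, x, y, training=True):
            logits = self.classifier(x[0], training=training)
            class_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(y[0], logits, axis=-1))
            predictions = tf.nn.softmax(logits, axis=-1)
            # Keep the tuple shapes the trainer expects: zero out unused terms
            # and return x[1] as a placeholder for fake_features.
            zero = tf.constant(0.0)
            return (class_loss, zero, zero, zero, class_loss), (predictions, x[1])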
@@ -399,7 +366,7 @@ class ModelTrainer():
             # Save generated audio STFT samples
             with self.summary_writer["val"].as_default():
                 tf.summary.image("original_features",
-                                 validation_xy["input_features"],
+                                 validation_xy['false_sample'],
                                  step=epoch,
                                  max_outputs=1)
                 tf.summary.image("fake_features",
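For context, tf.summary.image writes at most max_outputs images per step under the active writer. A minimal self-contained sketch of the pattern used here (the log path and dummy tensor are illustrative):

    import tensorflow as tf

    writer = tf.summary.create_file_writer("/tmp/logs/val")
    dummy_stft = tf.random.uniform([1, 64, 64, 1])  # batch, height, width, channels
    with writer.as_default():
        tf.summary.image("fake_features", dummy_stft, step=0, max_outputs=1)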