Alexander Fuchs / tensorflow2_trainer_template / Commits

Commit f46cbe7c
authored Aug 06, 2020 by Alexander Fuchs
Data loader optimizations

parent 8f319eed
Changes 4
src/models/discriminator.py
...
...
@@ -23,7 +23,7 @@ class Discriminator(tf.keras.Model):
        for i in range(self.n_layers):
            self.model_layers.append(SpectralNormalization(tf.keras.layers.Conv2D(self.n_channels[i],
-                                                          kernel_size=(3,3),
+                                                          kernel_size=strides[i]+1,
                                                           strides=strides[i],
                                                           padding="same",
                                                           use_bias=False,
...
...
@@ -37,7 +37,8 @@ class Discriminator(tf.keras.Model):
                                                           use_bias=False,
                                                           name=self.name_op+"_dense",
                                                           activation=None))

    def call(self, input, training=False):
        x = input
...
...
src/scripts/birdsong_simple_main.py
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
from absl import logging
from absl import app
...
...
@@ -13,18 +14,54 @@ from models.discriminator import Discriminator
from models.generator import Generator

#logging.set_verbosity(logging.WARNING)

+BINS = 1025
+N_FRAMES = 216
+N_CHANNELS = 2

def preprocess_input(sample, n_classes, is_training):
    """Preprocess a single image of layout [height, width, depth]."""
-   if tf.random.uniform([1])[0] > 1/n_classes or not(is_training):
-       input_features = tf.image.per_image_standardization(sample['input_features'])
-       labels = sample['labels']
-   else:
-       input_features = tf.image.per_image_standardization(sample['false_sample'])
-       labels = tf.cast(tf.one_hot(n_classes+1, n_classes+1), tf.int32)
-   return {'input_features': input_features, 'labels': labels, 'false_sample': sample['false_sample']}
+   input_features = sample['input_features']
+   labels = sample['labels']
+   false_sample = sample['false_sample']
+   batch_size = tf.shape(labels)[0]
+   if is_training:
+       rnd = tf.random.uniform([batch_size])
+       rnd_tiled_feat = tf.tile(tf.reshape(rnd, [batch_size, 1, 1, 1]), [1, BINS, N_FRAMES, N_CHANNELS])
+       rnd_tiled_lbl = tf.tile(tf.reshape(rnd, [batch_size, 1]), [1, n_classes+1])
+       false_lbl = tf.cast(tf.one_hot(n_classes+1, n_classes+1), tf.int32)
+       false_lbl = tf.tile(tf.reshape(false_lbl, [1, n_classes+1]), [batch_size, 1])
+       input_features = tf.where(rnd_tiled_feat > 1/n_classes, input_features, false_sample)
+       labels = tf.where(rnd_tiled_lbl > 1/n_classes, labels, false_lbl)
+   return {'input_features': input_features, 'labels': labels, 'false_sample': false_sample}
+
+def augment_input(sample, n_classes):
+   """Preprocess a single image of layout [height, width, depth]."""
+   input_features = sample['input_features']
+   labels = sample['labels']
+   false_sample = sample['false_sample']
+   rnd = tf.random.uniform([1])
+   rnd_tiled_feat = tf.tile(tf.reshape(rnd, [1, 1, 1]), [BINS, N_FRAMES, N_CHANNELS])
+   rnd_tiled_lbl = tf.tile(tf.reshape(rnd, [1]), [n_classes+1])
+   false_lbl = tf.cast(tf.one_hot(n_classes+1, n_classes+1), labels.dtype)
+   input_features = tf.where(rnd_tiled_feat > 1/n_classes, input_features, false_sample)
+   labels = tf.where(rnd_tiled_lbl > 1/n_classes, labels, false_lbl)
+   return {'input_features': input_features, 'labels': labels, 'false_sample': false_sample}

def data_generator(data_generator, batch_size, is_training,
                   shuffle_buffer=128,
...
...
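For context, the reworked preprocess_input mixes "false" spectrograms and a fake-class label into a training batch with one random draw per example and tf.where. Below is a minimal stand-alone sketch of that masking pattern; the toy shapes, class count and variable names are illustrative, not the repository's BINS/N_FRAMES/N_CHANNELS values.

import tensorflow as tf

# Toy stand-ins: 4 examples, 3 real classes, tiny "spectrograms".
batch_size, n_classes = 4, 3
features = tf.random.normal([batch_size, 8, 8, 2])      # real inputs
false_sample = tf.zeros_like(features)                   # "false" inputs
labels = tf.one_hot(tf.range(batch_size) % n_classes, n_classes + 1)
false_lbl = tf.one_hot([n_classes] * batch_size, n_classes + 1)  # extra fake class

# One uniform draw per example decides whether it keeps its real features/label.
keep = tf.random.uniform([batch_size]) > 1.0 / n_classes
feat_keep = tf.tile(tf.reshape(keep, [batch_size, 1, 1, 1]), [1, 8, 8, 2])
lbl_keep = tf.tile(tf.reshape(keep, [batch_size, 1]), [1, n_classes + 1])

mixed_features = tf.where(feat_keep, features, false_sample)
mixed_labels = tf.where(lbl_keep, labels, false_lbl)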
@@ -44,13 +81,10 @@ def data_generator(data_generator,batch_size,is_training,
    dataset = dataset.take(take_n)
    if is_training:
        dataset = dataset.shuffle(shuffle_buffer)
        dataset = dataset.map(lambda sample: preprocess_input(sample, n_classes, is_training))
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    else:
        dataset = dataset.map(lambda sample: preprocess_input(sample, n_classes, is_training))
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
...
...
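As a reference for the shuffle/map/batch/prefetch ordering above, here is a compact self-contained tf.data pipeline of the same shape; the toy generator, shapes and buffer sizes are placeholders rather than the repository's values, and output_signature assumes TF 2.4 or newer.

import tensorflow as tf

def toy_generator():
    for i in range(100):
        yield {"input_features": tf.random.normal([8, 8, 2]),
               "labels": tf.one_hot(i % 4, 5)}

def toy_preprocess(sample):
    # Stand-in for preprocess_input: per-example standardization only.
    sample["input_features"] = tf.image.per_image_standardization(sample["input_features"])
    return sample

dataset = tf.data.Dataset.from_generator(
    toy_generator,
    output_signature={"input_features": tf.TensorSpec([8, 8, 2], tf.float32),
                      "labels": tf.TensorSpec([5], tf.float32)})
dataset = (dataset.take(64)
                  .shuffle(16)
                  .map(toy_preprocess)
                  .batch(8, drop_remainder=True)
                  .prefetch(tf.data.experimental.AUTOTUNE))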
@@ -75,7 +109,8 @@ flags.DEFINE_integer('epochs', 300, 'number of epochs')
flags.DEFINE_integer('batch_size', 32, 'Mini-batch size')
flags.DEFINE_float('dropout_rate', 0.0, 'dropout rate for the dense blocks')
flags.DEFINE_float('weight_decay', 1e-4, 'weight decay parameter')
-flags.DEFINE_float('learning_rate', 1e-3, 'learning rate')
+flags.DEFINE_float('learning_rate', 1e-1, 'learning rate')
+flags.DEFINE_boolean('preload_samples', False, 'Preload samples (requires >140 GB RAM)')
flags.DEFINE_float('training_percentage', 90, 'Percentage of the training data used for training. (100-training_percentage is used as validation data.)')
flags.DEFINE_boolean('load_model', False, 'Bool indicating if the model should be loaded')
...
...
@@ -106,12 +141,26 @@ def main(argv):
    lr = FLAGS.learning_rate
    load_model = FLAGS.load_model
    training_percentage = FLAGS.training_percentage
+   preload_samples = FLAGS.preload_samples
    model_save_dir += "_batch_size_"+str(batch_size)+"_dropout_rate_"+str(dropout_rate)+"_learning_rate_"+str(lr)+"_weight_decay_"+str(weight_decay)
    ds_train = Dataset(data_dir, is_training_set=True)
    n_total = ds_train.n_samples
-   dg_train = DataGenerator(ds_train, None)
+   def augment_fn(sample):
+       return augment_input(sample, ds_train.n_classes)
+   dg_train = DataGenerator(ds_train, augment_fn,
+                            training_percentage=training_percentage,
+                            preload_samples=preload_samples,
+                            is_training=True)
+   dg_val = DataGenerator(ds_train, None,
+                          training_percentage=training_percentage,
+                          is_validation=True,
+                          preload_samples=preload_samples,
+                          is_training=False)
+   n_train = int(n_total*training_percentage/100)
+   n_val = n_total - n_train
...
...
@@ -139,21 +188,19 @@ def main(argv):
    #Discriminator for estimating the Wasserstein distance
    discriminator_model = Discriminator(3, [32, 64, 128], [2, 2, 2], [4, 4, 4], name="discriminator")
    train_data_gen = data_generator(dg_train.generate,
                                    batch_size,
                                    is_training=True,
-                                   shuffle_buffer=256,
+                                   shuffle_buffer=16*batch_size,
                                    n_classes=ds_train.n_classes,
                                    take_n=n_train)
-   val_data_gen = data_generator(dg_train.generate,
-                                 10,
+   val_data_gen = data_generator(dg_val.generate,
+                                 batch_size,
                                  is_training=False,
                                  is_validation=True,
-                                 n_classes=ds_train.n_classes,
-                                 skip_n=n_train,
-                                 take_n=n_val)
+                                 n_classes=ds_train.n_classes)
    trainer = ModelTrainer(train_data_gen,
...
...
@@ -163,15 +210,19 @@ def main(argv):
                           classifier=classifier_model,
                           generator=generator_model,
                           discriminator=generator_model,
-                          learning_rate_fn=learning_rate_fn,
-                          optimizer_classifier=tf.keras.optimizers.Adam,
+                          base_learning_rate_classifier=lr,
+                          base_learning_rate_generator=lr*0.001,
+                          base_learning_rate_discriminator=lr*0.0001,
+                          learning_rate_fn_classifier=learning_rate_fn,
+                          learning_rate_fn_generator=learning_rate_fn,
+                          learning_rate_fn_discriminator=learning_rate_fn,
+                          optimizer_classifier=tf.keras.optimizers.SGD,
                           optimizer_generator=tf.keras.optimizers.Adam,
                           optimizer_discriminator=tf.keras.optimizers.Adam,
                           num_train_batches=int(n_train/batch_size),
                           base_learning_rate=lr,
                           load_model=load_model,
                           save_dir=model_save_dir,
-                          init_data=tf.random.normal([batch_size, 1025, 216, 2]),
+                          init_data=tf.random.normal([batch_size, BINS, N_FRAMES, N_CHANNELS]),
                           start_epoch=0)
    trainer.train()
...
...
src/utils/data_loader.py
...
...
@@ -9,6 +9,7 @@ import multiprocessing
import sys
import random
import copy
import time
import tensorflow as tf
warnings.filterwarnings('ignore')
...
...
@@ -135,8 +136,11 @@ class DataGenerator(object):
    def __init__(self, dataset,
                 augmentation,
                 shuffle=True,
                 is_training=True,
+                is_validation=False,
                 force_feature_recalc=False,
                 preload_false_samples=True,
+                preload_samples=False,
+                training_percentage=90,
                 max_time=5,
                 max_samples_per_audio=6,
                 n_fft=2048,
...
...
@@ -148,16 +152,25 @@ class DataGenerator(object):
        random.seed(4)
        random.shuffle(self.dataset.train_samples)
        self.n_training_samples = int(dataset.n_samples*training_percentage/100)
        self.n_validation_samples = dataset.n_samples - self.n_training_samples
        self.augmentation = augmentation
        self.is_training = is_training
        self.is_validation = is_validation
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.preload_samples = preload_samples
        self.preload_false_samples = preload_false_samples
        self.hop_length = hop_length
        self.max_time = max_time
        self.max_samples_per_audio = max_samples_per_audio
        self.force_feature_recalc = force_feature_recalc
        if self.is_training:
            self.first_sample = 0
            self.last_sample = self.n_training_samples
        elif self.is_validation:
            self.first_sample = self.n_training_samples
            self.last_sample = self.dataset.n_samples
        #Get paths of false samples
        false_samples_mono = glob.glob(self.dataset.false_audio_path+"/mono/*.npz", recursive=True)
...
...
@@ -171,8 +184,25 @@ class DataGenerator(object):
            self.preloaded_false_samples = {}
            for path in self.false_sample_paths:
                with np.load(path, allow_pickle=True) as sample_file:
-                   self.preloaded_false_samples[path] = copy.deepcopy(sample_file.f.arr_0)
-           print("Finished pre-loading")
+                   self.preloaded_false_samples[path] = sample_file.f.arr_0
+           print("Finished pre-loading false samples!")
+       if self.is_training or self.is_validation:
+           self.samples = self.dataset.train_samples[self.first_sample:self.last_sample]
+       else:
+           self.samples = self.dataset.test_samples
+       #Pre load samples (takes a lot of RAM ~130 GB)
+       try:
+           if self.preload_samples:
+               self.preloaded_samples = {}
+               for sample in self.samples:
+                   path = sample["filename"].replace("mp3", "npz")
+                   with np.load(path, allow_pickle=True) as sample_file:
+                       self.preloaded_samples[path] = sample_file.f.arr_0
+               print("Finished pre-loading samples")
+       except:
+           self.preload_samples = False

    def do_stft(self, y, channels):
        spectra = []
...
...
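The preloading added above amounts to caching each decompressed .npz payload in a dict keyed by file path so that later epochs are served from RAM instead of disk. A minimal sketch of that caching pattern follows; the paths are hypothetical, and arr_0 is assumed to be the stored array name as in the repository's feature files.

import numpy as np

def preload(paths):
    cache = {}
    for path in paths:
        # Materialize the array while the archive is open; afterwards only RAM is hit.
        with np.load(path, allow_pickle=True) as archive:
            cache[path] = archive["arr_0"]
    return cache

# e.g. cache = preload(["/data/false_samples/mono/example_0001.npz"])  # hypothetical path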
@@ -318,20 +348,19 @@ class DataGenerator(object):
    def generate(self):
-       if self.is_training:
-           samples = self.dataset.train_samples
-       else:
-           samples = self.dataset.test_samples
        stft_len = int(np.ceil(self.max_time*self.sampling_rate/self.hop_length))
-       for sample in samples:
+       for sample in self.samples:
            filename = sample['filename']
            #If feature was already created load from file
            if os.path.isfile(filename.replace("mp3","npz")) and not(self.force_feature_recalc):
-               with np.load(filename.replace("mp3","npz"), allow_pickle=True) as sample_file:
-                   spectra_npz = sample_file.f.arr_0
+               if self.preload_samples:
+                   spectra_npz = self.preloaded_samples[filename.replace("mp3","npz")]
+               else:
+                   with np.load(filename.replace("mp3","npz"), allow_pickle=True) as sample_file:
+                       spectra_npz = sample_file.f.arr_0
                spec_keys = spectra_npz.item().keys()
                spec_keys = list(spec_keys)
...
...
@@ -340,7 +369,6 @@ class DataGenerator(object):
            else:
                #Create features via STFT if no file exists
                spectra = self.create_feature(sample)
            #Check for None type and shape
            if np.any(spectra) == None or spectra.shape[-1] != stft_len:
...
...
@@ -368,13 +396,18 @@ class DataGenerator(object):
            #If false only mono --> duplicate
            if false_spectra.shape[0] == 1:
                false_spectra = np.tile(false_spectra, [2, 1, 1])
            #Transpose spectrogramms for "channels_last"
            spectra = tf.transpose(spectra, perm=[1, 2, 0])
            false_spectra = tf.transpose(false_spectra, perm=[1, 2, 0])
-           yield {'input_features': spectra,
+           sample = {'input_features': spectra,
                    'labels': tf.one_hot(sample['bird_id'], self.dataset.n_classes+1),
                    'false_sample': false_spectra}
+           if self.augmentation != None:
+               yield self.augmentation(sample)
+           else:
+               yield sample

if __name__ == "__main__":
    ds = Dataset("/srv/TUG/datasets/cornell_birdcall_recognition")
...
...
src/utils/trainer.py
...
...
@@ -7,6 +7,9 @@ def clip_by_value_10(grad):
    grad = tf.where(tf.math.is_finite(grad), grad, tf.zeros_like(grad))
    return tf.clip_by_value(grad, -10, 10)

+def constant_lr(epoch):
+   return 1.0

class ModelTrainer():
    def __init__(self,
                 training_data_generator,
...
...
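clip_by_value_10 above zeroes non-finite entries and clips each gradient to [-10, 10]. A minimal sketch of how such a per-gradient helper is typically applied before apply_gradients; the model, data and optimizer here are toy placeholders, not the trainer's actual wiring.

import tensorflow as tf

def clip_by_value_10(grad):
    grad = tf.where(tf.math.is_finite(grad), grad, tf.zeros_like(grad))
    return tf.clip_by_value(grad, -10, 10)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(1e-2)
x, y = tf.random.normal([8, 4]), tf.random.normal([8, 1])

with tf.GradientTape() as tape:
    loss = tf.reduce_mean((model(x) - y) ** 2)
grads = tape.gradient(loss, model.trainable_variables)
grads = [clip_by_value_10(g) for g in grads]      # sanitize and clip every gradient
optimizer.apply_gradients(zip(grads, model.trainable_variables))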
@@ -19,8 +22,12 @@ class ModelTrainer():
                 optimizer_classifier=tf.keras.optimizers.Adam,
                 optimizer_generator=tf.keras.optimizers.Adam,
                 optimizer_discriminator=tf.keras.optimizers.Adam,
                 learning_rate_fn=None,
                 base_learning_rate=1e-3,
+                base_learning_rate_classifier=1e-3,
+                base_learning_rate_generator=1e-4,
+                base_learning_rate_discriminator=5e-5,
+                learning_rate_fn_classifier=constant_lr,
+                learning_rate_fn_generator=constant_lr,
+                learning_rate_fn_discriminator=constant_lr,
                 init_data=None,
                 start_epoch=0,
                 num_train_batches=None,
...
...
@@ -33,12 +40,20 @@ class ModelTrainer():
        self.classifier = classifier
        self.generator = generator
        self.discriminator = discriminator
        self.learning_rate_fn = learning_rate_fn
        self.save_dir = save_dir
        self.base_learning_rate = base_learning_rate
-       self.learning_rate_classifier = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
-       self.learning_rate_generator = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
-       self.learning_rate_discriminator = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
+       #Set up learning rates
+       self.learning_rate_fn_classifier = learning_rate_fn_classifier
+       self.learning_rate_fn_generator = learning_rate_fn_generator
+       self.learning_rate_fn_discriminator = learning_rate_fn_discriminator
+       self.base_learning_rate_classifier = base_learning_rate_classifier
+       self.base_learning_rate_generator = base_learning_rate_generator
+       self.base_learning_rate_discriminator = base_learning_rate_discriminator
+       self.learning_rate_classifier = tf.Variable(self.base_learning_rate_classifier*self.learning_rate_fn_classifier(start_epoch))
+       self.learning_rate_generator = tf.Variable(self.base_learning_rate_generator*self.learning_rate_fn_generator(start_epoch))
+       self.learning_rate_discriminator = tf.Variable(self.base_learning_rate_discriminator*self.learning_rate_fn_discriminator(start_epoch))
        #Initialize models
        self.classifier(init_data, training=True)
        self.classifier.summary()
...
...
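The per-model learning rates above follow the common pattern of keeping each rate in a tf.Variable that the optimizer reads, then assigning base_rate * schedule(epoch) once per epoch. A minimal stand-alone sketch, with an illustrative optimizer, schedule and values:

import tensorflow as tf

def constant_lr(epoch):
    return 1.0

base_lr = 1e-3
lr_var = tf.Variable(base_lr * constant_lr(0))
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_var)

for epoch in range(3):
    # The optimizer picks up the new value on its next apply_gradients call.
    lr_var.assign(base_lr * constant_lr(epoch))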
@@ -155,12 +170,13 @@ class ModelTrainer():
        logits_classifier = self.classifier(x, training=training)
        return tf.nn.softmax(logits_classifier, axis=-1)

    @tf.function
    def compute_loss(self, x, y, training=True):
        logits_classifier = self.classifier(x[0], training=training)
-       fake_input = self.generator(x[1], training)
+       fake_features = self.generator(x[1], training)
        true = self.discriminator(x[0], training)
-       false = self.discriminator(fake_input, training)
+       false = self.discriminator(fake_features, training)
        # Cross entropy losses
        class_loss = tf.reduce_mean(
...
...
@@ -171,13 +187,13 @@ class ModelTrainer():
            ))
        #Wasserstein losses
-       gen_loss = -tf.reduce_mean(false)
-       discr_loss = tf.reduce_mean(false) - tf.reduce_mean(true)
+       gen_loss = tf.reduce_mean(false)
+       discr_loss = tf.reduce_mean(true) - tf.reduce_mean(false)
        if len(self.classifier.losses) > 0:
            weight_decay_loss = tf.add_n(self.classifier.losses)
-       else:
-           weight_decay_loss = 0.0
+       weight_decay_loss = 0.0
        if len(self.generator.losses) > 0:
            weight_decay_loss += tf.add_n(self.generator.losses)
...
...
@@ -190,9 +206,9 @@ class ModelTrainer():
        total_loss += discr_loss
        total_loss += weight_decay_loss
-       return (class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss), predictions
+       return (class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss), (predictions, fake_features)

    @tf.function
    def compute_gradients(self, x, y):
        # Pass through network
        with tf.GradientTape() as tape:
...
...
@@ -216,25 +232,28 @@ class ModelTrainer():
        return out_grads, losses, predictions

    @tf.function
    def apply_gradients_classifier(self, gradients, variables):
        self.optimizer_classifier.apply_gradients(zip(gradients, variables))

    @tf.function
    def apply_gradients_generator(self, gradients, variables):
        self.optimizer_generator.apply_gradients(zip(gradients, variables))

    @tf.function
    def apply_gradients_discriminator(self, gradients, variables):
        self.optimizer_discriminator.apply_gradients(zip(gradients, variables))

    @tf.function
    def split_grads(self, grads):
        classifier_grads = grads[:self.n_vars_classifier]
        generator_grads = grads[self.n_vars_classifier:self.n_vars_classifier+self.n_vars_generator]
...
...
@@ -306,9 +325,9 @@ class ModelTrainer():
        prog = tf.keras.utils.Progbar(self.max_number_batches)
        for epoch in range(self.start_epoch, self.epochs):
            # Update learning rates
-           self.learning_rate_classifier.assign(self.base_learning_rate*self.learning_rate_fn(epoch))
-           self.learning_rate_generator.assign(self.base_learning_rate*self.learning_rate_fn(epoch))
-           self.learning_rate_discriminator.assign(self.base_learning_rate*self.learning_rate_fn(epoch))
+           self.learning_rate_classifier.assign(self.base_learning_rate_classifier*self.learning_rate_fn_classifier(epoch))
+           self.learning_rate_generator.assign(self.base_learning_rate_generator*self.learning_rate_fn_generator(epoch))
+           self.learning_rate_discriminator.assign(self.base_learning_rate_discriminator*self.learning_rate_fn_discriminator(epoch))
            #Update epoch variable
            self.epoch_variable.assign(epoch)
            with self.summary_writer["train"].as_default():
...
...
@@ -324,21 +343,27 @@ class ModelTrainer():
                train_y = self.get_data_for_keys(train_xy, self.label_keys)
                start_time = time()
-               losses, predictions = self.train_step(train_x, train_y)
+               losses, outputs = self.train_step(train_x, train_y)
+               predictions, fake_features = outputs
                class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss = losses
                # Update summaries
                self.update_summaries(class_loss, gen_loss, discr_loss,
                                      weight_decay_loss, total_loss, predictions,
-                                     train_y,
+                                     train_y[0],
                                      "train")
                batch += 1
                prog.update(batch, [("class_loss", self.scalar_summaries["train_"+"class_loss"].result()),
                                    ("g_loss", self.scalar_summaries["train_"+"generator_loss"].result()),
                                    ("d_loss", self.scalar_summaries["train_"+"discriminator_loss"].result()),
                                    ("wd_loss", self.scalar_summaries["train_"+"weight_decay_loss"].result()),
                                    ("accuracy", 100*self.scalar_summaries["train_"+"accuracy"].result()),
                                    ("time / step", np.round(time()-start_time, 2))])
...
...
@@ -353,12 +378,13 @@ class ModelTrainer():
            for validation_xy in self.validation_data_generator:
                validation_x = self.get_data_for_keys(validation_xy, self.input_keys)
                validation_y = self.get_data_for_keys(validation_xy, self.label_keys)
-               losses, predictions = self.compute_loss(validation_x,
+               losses, outputs = self.compute_loss(validation_x,
                                                        validation_y,
                                                        training=False)
+               predictions, fake_features = outputs
                class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss = losses
                # Update summaries
                self.update_summaries(class_loss, gen_loss,
...
...
@@ -366,9 +392,20 @@ class ModelTrainer():
                                      weight_decay_loss, total_loss, predictions,
-                                     validation_y,
+                                     validation_y[0],
                                      "val")
+           #Save generated audio STFT samples
+           with self.summary_writer["val"].as_default():
+               tf.summary.image("original_features", validation_xy["input_features"], step=epoch, max_outputs=1)
+               tf.summary.image("fake_features", fake_features, step=epoch, max_outputs=1)
            # Write validation summaries
            self.write_summaries(epoch, "val")
...
...
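The tf.summary.image calls above log the original and generated spectrograms to TensorBoard once per epoch. A minimal stand-alone sketch of that call; the log directory, tensor shape and step are placeholders, not the trainer's values.

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/tb_demo")   # hypothetical log dir
images = tf.random.uniform([4, 64, 32, 1])                # [batch, height, width, channels]

with writer.as_default():
    tf.summary.image("fake_features", images, step=0, max_outputs=1)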
@@ -378,10 +415,10 @@ class ModelTrainer():
            test_x = self.get_data_for_keys(test_xy, self.input_keys)
            test_y = self.get_data_for_keys(test_xy, self.label_keys)
-           losses, predictions = self.compute_loss(test_x,
+           losses, outputs = self.compute_loss(test_x,
                                                    test_y,
                                                    training=False)
+           predictions, fake_features = outputs
            class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss = losses
            # Update summaries
...
...
@@ -391,7 +428,7 @@ class ModelTrainer():
                                  weight_decay_loss, total_loss, predictions,
-                                 test_y,
+                                 test_y[0],
                                  "test")
            # Write test summaries
...
...