Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Alexander Fuchs
tensorflow2_trainer_template
Commits
5441bc86
Commit
5441bc86
authored
Aug 11, 2020
by
Alexander Fuchs
Browse files
Added support for image summaries
parent
59ce73cd
Changes
3
Hide whitespace changes
Inline
Side-by-side
src/scripts/birdsong_simple_main.py
View file @
5441bc86
...
...
@@ -8,7 +8,7 @@ from absl import flags
from
utils.trainer
import
ModelTrainer
from
utils.data_loader
import
Dataset
from
utils.data_loader
import
DataGenerator
from
utils.summary_utils
import
Scalar
Summaries
from
utils.summary_utils
import
Summaries
from
models.classifier
import
Classifier
from
models.res_block
import
ResBlockBasicLayer
from
models.discriminator
import
Discriminator
...
...
@@ -187,12 +187,13 @@ def main(argv):
n_classes
=
ds_train
.
n_classes
)
summaries
=
Scalar
Summaries
(
scalar_summary_names
=
[
"class_loss"
,
summaries
=
Summaries
(
scalar_summary_names
=
[
"class_loss"
,
"generator_loss"
,
"discriminator_loss"
,
"weight_decay_loss"
,
"total_loss"
,
"accuracy"
],
image_summary_settings
=
{
'train'
:[
'fake_features'
],
'n_images'
:
1
},
learning_rate_names
=
[
'learning_rate_'
+
str
(
classifier_model
.
model_name
),
'learning_rate_'
+
str
(
generator_model
.
model_name
),
'learning_rate_'
+
str
(
discriminator_model
.
model_name
)],
...
...
@@ -214,15 +215,15 @@ def main(argv):
'init_data'
:
tf
.
random
.
normal
([
batch_size
,
BINS
,
N_FRAMES
,
N_CHANNELS
])},
{
'model'
:
generator_model
,
'optimizer_type'
:
tf
.
keras
.
optimizers
.
Adam
,
'base_learning_rate'
:
lr
*
0.
00
01
,
'base_learning_rate'
:
lr
*
0.01
,
'learning_rate_fn'
:
learning_rate_fn
,
'init_data'
:
tf
.
random
.
normal
([
batch_size
,
BINS
,
N_FRAMES
,
N_CHANNELS
])},
{
'model'
:
discriminator_model
,
'optimizer_type'
:
tf
.
keras
.
optimizers
.
Adam
,
'base_learning_rate'
:
lr
*
0.
0
02
,
'base_learning_rate'
:
lr
*
0.02
,
'learning_rate_fn'
:
learning_rate_fn
,
'init_data'
:
tf
.
random
.
normal
([
batch_size
,
BINS
,
N_FRAMES
,
N_CHANNELS
])}],
scalar_
summaries
=
summaries
,
summaries
=
summaries
,
num_train_batches
=
int
(
n_train
/
batch_size
),
load_model
=
load_model
,
save_dir
=
model_save_dir
,
...
...
src/utils/summary_utils.py
View file @
5441bc86
import
os
import
tensorflow
as
tf
class
Scalar
Summaries
(
object
):
def
__init__
(
self
,
scalar_summary_names
,
learning_rate_names
,
save_dir
,
modes
=
[
"train"
,
"val"
,
"test"
],
summaries_to_print
=
{}):
class
Summaries
(
object
):
def
__init__
(
self
,
scalar_summary_names
,
learning_rate_names
,
image_summary_settings
=
{},
save_dir
=
"/tmp"
,
modes
=
[
"train"
,
"val"
,
"test"
],
summaries_to_print
=
{}):
self
.
scalar_summary_names
=
scalar_summary_names
self
.
image_summary_settings
=
image_summary_settings
self
.
save_dir
=
save_dir
self
.
modes
=
modes
self
.
summaries_to_print
=
summaries_to_print
self
.
scalar_summaries
=
{}
self
.
image_data
=
{}
self
.
lr_summaries
=
{}
self
.
learning_rate_names
=
learning_rate_names
self
.
summary_writers
=
{}
...
...
@@ -24,8 +26,9 @@ class ScalarSummaries(object):
for
key
in
self
.
summaries_to_print
[
mode
]:
summary_list
.
append
((
key
,
self
.
scalar_summaries
[
tmp_mode
+
'_'
+
key
].
result
()))
else
:
for
key
in
self
.
summaries_to_print
[
mode
]:
summary_list
.
append
((
key
,
self
.
scalar_summaries
[
mode
+
'_'
+
key
].
result
()))
if
mode
in
self
.
summaries_to_print
.
keys
():
for
key
in
self
.
summaries_to_print
[
mode
]:
summary_list
.
append
((
key
,
self
.
scalar_summaries
[
mode
+
'_'
+
key
].
result
()))
return
summary_list
def
create_summary_writers
(
self
):
...
...
@@ -68,6 +71,26 @@ class ScalarSummaries(object):
tf
.
summary
.
scalar
(
scalar_map_key
,
self
.
scalar_summaries
[
key
].
result
(),
step
=
epoch
)
def
update_image_data
(
self
,
outputs
,
mode
=
"train"
):
self
.
image_data
=
{}
for
key
in
outputs
.
keys
():
if
mode
in
self
.
image_summary_settings
.
keys
():
if
key
in
self
.
image_summary_settings
[
mode
]:
self
.
image_data
[
key
]
=
outputs
[
key
]
def
write_image_summaries
(
self
,
epoch
,
mode
=
"train"
):
with
self
.
summary_writers
[
mode
].
as_default
():
if
mode
in
self
.
image_summary_settings
.
keys
():
try
:
n_images
=
self
.
image_summary_settings
[
"n_images"
]
except
:
n_images
=
1
for
image_summary
in
self
.
image_summary_settings
[
mode
]:
tf
.
summary
.
image
(
image_summary
,
self
.
image_data
[
image_summary
],
step
=
epoch
,
max_outputs
=
n_images
)
def
write_lr
(
self
,
epoch
):
for
key
in
self
.
lr_summaries
.
keys
():
...
...
src/utils/trainer.py
View file @
5441bc86
...
...
@@ -22,7 +22,7 @@ class ModelTrainer():
base_learning_rates
=
[],
learning_rate_fns
=
[],
init_data
=
[],
scalar_
summaries
=
None
,
summaries
=
None
,
start_epoch
=
0
,
num_train_batches
=
None
,
load_model
=
False
,
...
...
@@ -53,7 +53,7 @@ class ModelTrainer():
self
.
initialize_models
(
init_data
)
self
.
n_vars_models
=
[
len
(
model
.
trainable_variables
)
for
model
in
self
.
models
]
#Set up summaries
self
.
scalar_
summaries
=
scalar_
summaries
self
.
summaries
=
summaries
#Set up loss function
self
.
eval_fns
=
eval_fns
(
self
.
models
)
#Set up optimizers
...
...
@@ -170,7 +170,7 @@ class ModelTrainer():
self
.
learning_rates
[
i
].
assign
(
self
.
base_learning_rates
[
i
]
*
self
.
learning_rate_fns
[
0
](
epoch
))
else
:
self
.
learning_rates
[
i
].
assign
(
self
.
base_learning_rates
[
i
]
*
self
.
learning_rate_fns
[
i
](
epoch
))
if
self
.
scalar_
summaries
!=
None
:
if
self
.
summaries
!=
None
:
self
.
update_learning_rate_summaries
()
self
.
write_learning_rate_summaries
(
epoch
)
...
...
@@ -230,15 +230,15 @@ class ModelTrainer():
def
update_learning_rate_summaries
(
self
):
lr_dict
=
{}
for
lr_name
,
lr
in
zip
(
self
.
scalar_
summaries
.
learning_rate_names
,
self
.
learning_rates
):
for
lr_name
,
lr
in
zip
(
self
.
summaries
.
learning_rate_names
,
self
.
learning_rates
):
lr_dict
[
lr_name
]
=
lr
self
.
scalar_
summaries
.
update_lr
(
lr_dict
)
self
.
summaries
.
update_lr
(
lr_dict
)
def
update_summaries
(
self
,
losses
,
outputs
,
y
=
None
,
mode
=
'train'
):
if
self
.
scalar_
summaries
!=
None
:
if
'accuracy'
in
self
.
scalar_
summaries
.
scalar_summary_names
:
if
self
.
summaries
!=
None
:
if
'accuracy'
in
self
.
summaries
.
scalar_summary_names
:
pred
=
outputs
[
'predictions'
]
accuracy
=
self
.
eval_fns
.
accuracy
(
pred
,
y
)
scalars
=
losses
...
...
@@ -246,25 +246,29 @@ class ModelTrainer():
else
:
scalars
=
losses
#Update summaries
self
.
scalar_summaries
.
update
(
scalars
,
mode
)
#Update scalar summaries
self
.
summaries
.
update
(
scalars
,
mode
)
#Update image summaries
self
.
summaries
.
update_image_data
(
outputs
,
mode
)
def
write_summaries
(
self
,
epoch
,
mode
=
"train"
):
if
self
.
scalar_summaries
!=
None
:
# Write summaries
self
.
scalar_summaries
.
write
(
epoch
,
mode
)
if
self
.
summaries
!=
None
:
# Write scalar summaries
self
.
summaries
.
write
(
epoch
,
mode
)
#Write image summaries
self
.
summaries
.
write_image_summaries
(
epoch
,
mode
)
def
write_learning_rate_summaries
(
self
,
epoch
):
if
self
.
scalar_
summaries
!=
None
:
if
self
.
summaries
!=
None
:
# Write summaries
self
.
scalar_
summaries
.
write_lr
(
epoch
)
self
.
summaries
.
write_lr
(
epoch
)
def
reset_summaries
(
self
):
if
self
.
scalar_
summaries
!=
None
:
if
self
.
summaries
!=
None
:
# Write summaries
self
.
scalar_
summaries
.
reset_summaries
()
self
.
summaries
.
reset_summaries
()
def
get_data_for_keys
(
self
,
xy
,
keys
):
"""Returns list of input data"""
...
...
@@ -317,8 +321,8 @@ class ModelTrainer():
self
.
update_summaries
(
losses
,
outputs
,
train_y
,
'train'
)
batch
+=
1
if
self
.
scalar_
summaries
!=
None
:
summary_list
=
self
.
scalar_
summaries
.
get_summary_list
(
'train'
)
if
self
.
summaries
!=
None
:
summary_list
=
self
.
summaries
.
get_summary_list
(
'train'
)
else
:
summary_list
=
[]
summary_list
+=
[(
"time / step"
,
np
.
round
(
time
()
-
start_time
,
2
))]
...
...
@@ -330,6 +334,7 @@ class ModelTrainer():
# Write train summaries
self
.
write_summaries
(
epoch
+
1
,
'train'
)
if
self
.
validation_data_generator
!=
None
:
for
validation_xy
in
self
.
validation_data_generator
:
validation_x
=
self
.
get_data_for_keys
(
validation_xy
,
self
.
input_keys
)
...
...
@@ -364,14 +369,14 @@ class ModelTrainer():
self
.
write_summaries
(
epoch
+
1
,
"test"
)
if
self
.
scalar_
summaries
!=
None
:
if
self
.
summaries
!=
None
:
template
=
'Epoch {}, Loss: {}, Accuracy: {},Val Loss: {}, Val Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
summary_list
=
self
.
scalar_
summaries
.
get_summary_list
(
'eval'
)
summary_list
=
self
.
summaries
.
get_summary_list
(
'eval'
)
scalars
=
[
x
[
1
]
for
x
in
summary_list
]
print
(
template
.
format
(
epoch
+
1
,
*
scalars
))
#Reset summaries after epoch
if
self
.
scalar_
summaries
!=
None
:
self
.
scalar_
summaries
.
reset_summaries
()
if
self
.
summaries
!=
None
:
self
.
summaries
.
reset_summaries
()
#Save the model weights
self
.
save_model
(
epoch
+
1
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment