Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Alexander Fuchs
tensorflow2_trainer_template
Commits
03a7f25b
Commit
03a7f25b
authored
Aug 05, 2020
by
Alexander Fuchs
Browse files
Preparations for WGAN training
parent
8c3e0003
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/utils/trainer.py
View file @
03a7f25b
...
...
@@ -23,8 +23,8 @@ class ModelTrainer():
start_epoch
=
0
,
num_train_batches
=
None
,
load_model
=
False
,
input_key
=
"input_features"
,
label_key
=
"labels"
,
input_key
s
=
[
"input_features"
,
"false_sample"
],
label_key
s
=
[
"labels"
]
,
gradient_processesing_fn
=
clip_by_value_10
,
save_dir
=
'tmp'
):
...
...
@@ -69,8 +69,8 @@ class ModelTrainer():
self
.
summary_writer
[
"train"
]
=
tf
.
summary
.
create_file_writer
(
self
.
train_log_dir
)
self
.
summary_writer
[
"val"
]
=
tf
.
summary
.
create_file_writer
(
self
.
validation_log_dir
)
self
.
summary_writer
[
"test"
]
=
tf
.
summary
.
create_file_writer
(
self
.
test_log_dir
)
self
.
input_key
=
input_key
self
.
label_key
=
label_key
self
.
input_key
s
=
input_key
s
self
.
label_key
s
=
label_key
s
#Get number of training batches
if
num_train_batches
==
None
:
...
...
@@ -133,16 +133,23 @@ class ModelTrainer():
def
compute_loss
(
self
,
x
,
y
,
training
=
True
):
logits_classifier
=
self
.
classifier
(
x
,
training
=
training
)
logits_classifier
=
self
.
classifier
(
x
[
0
],
training
=
training
)
fake_input
=
self
.
generator
(
x
[
1
],
training
)
true
=
self
.
discriminator
(
x
[
0
],
training
)
false
=
self
.
discriminator
(
fake_input
,
training
)
# Cross entropy losses
class_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
softmax_cross_entropy_with_logits
(
y
,
y
[
0
]
,
logits_classifier
,
axis
=-
1
,
))
#Wasserstein losses
gen_loss
=
-
tf
.
reduce_mean
(
false
)
discr_loss
=
tf
.
reduce_mean
(
false
)
-
tf
.
reduce_mean
(
true
)
if
len
(
self
.
classifier
.
losses
)
>
0
:
weight_decay_loss
=
tf
.
add_n
(
self
.
classifier
.
losses
)
else
:
...
...
@@ -154,21 +161,21 @@ class ModelTrainer():
weight_decay_loss
=
0.0
predictions
=
tf
.
nn
.
softmax
(
logits_classifier
,
axis
=-
1
)
total_loss
=
class_loss
total_loss
+=
weight_decay_loss
return
class_loss
,
weight_decay_loss
,
total_loss
,
predictions
return
(
class_loss
,
gen_loss
,
discr_loss
,
weight_decay_loss
,
total_loss
)
,
predictions
def
compute_gradients
(
self
,
x
,
y
):
# Pass through network
with
tf
.
GradientTape
()
as
tape
:
class_loss
,
weight_decay_loss
,
total_
loss
,
predictions
=
self
.
compute_loss
(
los
se
s
,
predictions
=
self
.
compute_loss
(
x
,
y
,
training
=
True
)
gradients
=
tape
.
gradient
(
total_loss
,
self
.
classifier
.
trainable_variables
)
gradients
=
tape
.
gradient
(
losses
[
-
1
]
,
self
.
classifier
.
trainable_variables
)
if
self
.
gradient_processing_fn
!=
None
:
out_grads
=
[]
...
...
@@ -181,7 +188,7 @@ class ModelTrainer():
else
:
out_grads
=
gradients
return
out_grads
,
class_loss
,
weight_decay_loss
,
total_
loss
,
predictions
return
out_grads
,
los
se
s
,
predictions
def
apply_gradients
(
self
,
gradients
,
variables
):
...
...
@@ -193,11 +200,11 @@ class ModelTrainer():
@
tf
.
function
def
train_step
(
self
,
x
,
y
):
#Compute gradients
gradients
,
class_loss
,
weight_decay_loss
,
total_
loss
,
predictions
=
self
.
compute_gradients
(
x
,
y
)
gradients
,
los
se
s
,
predictions
=
self
.
compute_gradients
(
x
,
y
)
#Apply gradeints
self
.
apply_gradients
(
gradients
,
self
.
classifier
.
trainable_variables
)
return
(
class_loss
,
weight_decay_loss
,
total_
loss
,
predictions
)
return
los
se
s
,
predictions
def
update_summaries
(
self
,
class_loss
,
weight_decay_loss
,
total_loss
,
predictions
,
labels
,
mode
=
'train'
):
...
...
@@ -205,8 +212,8 @@ class ModelTrainer():
self
.
scalar_summaries
[
mode
+
"_"
+
"class_loss"
].
update_state
(
class_loss
)
self
.
scalar_summaries
[
mode
+
"_"
+
"weight_decay_loss"
].
update_state
(
weight_decay_loss
)
self
.
scalar_summaries
[
mode
+
"_"
+
"total_loss"
].
update_state
(
total_loss
)
self
.
scalar_summaries
[
mode
+
"_"
+
"accuracy"
].
update_state
(
tf
.
argmax
(
predictions
,
axis
=-
1
),
tf
.
argmax
(
labels
,
axis
=-
1
))
self
.
scalar_summaries
[
mode
+
"_"
+
"accuracy"
].
update_state
(
tf
.
argmax
(
tf
.
squeeze
(
predictions
)
,
axis
=-
1
),
tf
.
argmax
(
tf
.
squeeze
(
labels
)
,
axis
=-
1
))
def
write_summaries
(
self
,
epoch
,
mode
=
"train"
):
# Write summaries
...
...
@@ -224,6 +231,14 @@ class ModelTrainer():
for
key
in
self
.
scalar_summaries
.
keys
():
self
.
scalar_summaries
[
key
].
reset_states
()
def
get_data_for_keys
(
self
,
xy
,
keys
):
"""Returns list of input data"""
x
=
[]
for
key
in
keys
:
x
.
append
(
xy
[
key
])
return
x
def
train
(
self
):
if
self
.
start_epoch
==
0
:
print
(
"Saving..."
)
...
...
@@ -246,12 +261,12 @@ class ModelTrainer():
batch
=
0
for
train_xy
in
self
.
training_data_generator
:
train_x
=
train_xy
[
self
.
input_key
]
train_y
=
train_xy
[
self
.
label_key
]
train_x
=
self
.
get_data_for_keys
(
train_xy
,
self
.
input_key
s
)
train_y
=
self
.
get_data_for_keys
(
train_xy
,
self
.
label_key
s
)
start_time
=
time
()
output
s
=
self
.
train_step
(
train_x
,
train_y
)
class_loss
,
weight_decay_loss
,
total_loss
,
predictions
=
output
s
losses
,
prediction
s
=
self
.
train_step
(
train_x
,
train_y
)
class_loss
,
gen_loss
,
discr_loss
,
weight_decay_loss
,
total_loss
=
losse
s
# Update summaries
self
.
update_summaries
(
class_loss
,
weight_decay_loss
,
...
...
@@ -275,12 +290,13 @@ class ModelTrainer():
if
self
.
validation_data_generator
!=
None
:
for
validation_xy
in
self
.
validation_data_generator
:
validation_x
=
validation_xy
[
self
.
input_key
]
validation_y
=
validation_xy
[
self
.
label_key
]
validation_x
=
self
.
get_data_for_keys
(
validation_xy
,
self
.
input_key
s
)
validation_y
=
self
.
get_data_for_keys
(
validation_xy
,
self
.
label_key
s
)
class_loss
,
weight_decay_loss
,
total_loss
,
predictions
=
self
.
compute_loss
(
validation_x
,
validation_y
,
training
=
False
)
losses
,
predictions
=
self
.
compute_loss
(
validation_x
,
validation_y
,
training
=
False
)
class_loss
,
gen_loss
,
discr_loss
,
weight_decay_loss
,
total_loss
=
losses
# Update summaries
self
.
update_summaries
(
class_loss
,
...
...
@@ -296,13 +312,15 @@ class ModelTrainer():
if
self
.
test_data_generator
!=
None
:
for
test_xy
in
self
.
test_data_generator
:
test_x
=
test_xy
[
self
.
input_key
]
test_y
=
test_xy
[
self
.
label_key
]
test_x
=
self
.
get_data_for_keys
(
test_xy
,
self
.
input_keys
)
test_y
=
self
.
get_data_for_keys
(
test_xy
,
self
.
label_keys
)
losses
,
predictions
=
self
.
compute_loss
(
test_x
,
test_y
,
training
=
False
)
class_loss
,
gen_loss
,
discr_loss
,
weight_decay_loss
,
total_loss
=
losses
class_loss
,
weight_decay_loss
,
total_loss
,
predictions
=
self
.
compute_loss
(
test_x
,
test_y
,
training
=
False
)
# Update summaries
self
.
update_summaries
(
class_loss
,
weight_decay_loss
,
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment