Commit 03a7f25b authored by Alexander Fuchs

Preparations for WGAN training

parent 8c3e0003
@@ -23,8 +23,8 @@ class ModelTrainer():
                  start_epoch = 0,
                  num_train_batches = None,
                  load_model = False,
-                 input_key = "input_features",
-                 label_key = "labels",
+                 input_keys = ["input_features","false_sample"],
+                 label_keys = ["labels"],
                  gradient_processing_fn = clip_by_value_10,
                  save_dir = 'tmp'):
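Reviewer note: with `input_key`/`label_key` generalized to lists, each batch is expected to be a dict from which several tensors are pulled. A minimal sketch of the batch layout these defaults imply; the field names come from the defaults above, while the shapes are hypothetical:

```python
import tensorflow as tf

# Sketch of the assumed batch layout (shapes are made up for illustration).
batch = {
    "input_features": tf.random.normal([32, 28, 28, 1]),        # real samples   -> x[0]
    "false_sample":   tf.random.normal([32, 128]),               # generator input -> x[1]
    "labels":         tf.one_hot(tf.zeros([32], tf.int32), 10),  # class targets   -> y[0]
}
```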
@@ -69,8 +69,8 @@ class ModelTrainer():
         self.summary_writer["train"] = tf.summary.create_file_writer(self.train_log_dir)
         self.summary_writer["val"] = tf.summary.create_file_writer(self.validation_log_dir)
         self.summary_writer["test"] = tf.summary.create_file_writer(self.test_log_dir)
-        self.input_key = input_key
-        self.label_key = label_key
+        self.input_keys = input_keys
+        self.label_keys = label_keys
         #Get number of training batches
         if num_train_batches is None:
@@ -133,16 +133,23 @@ class ModelTrainer():
     def compute_loss(self, x, y, training=True):
-        logits_classifier = self.classifier(x,training=training)
+        logits_classifier = self.classifier(x[0],training=training)
+        fake_input = self.generator(x[1],training=training)
+        true = self.discriminator(x[0],training=training)
+        false = self.discriminator(fake_input,training=training)
         # Cross entropy losses
         class_loss = tf.reduce_mean(
             tf.nn.softmax_cross_entropy_with_logits(
-                y,
+                y[0],
                 logits_classifier,
                 axis=-1,
             ))
+        #Wasserstein losses
+        gen_loss = -tf.reduce_mean(false)
+        discr_loss = tf.reduce_mean(false)-tf.reduce_mean(true)
         if len(self.classifier.losses) > 0:
             weight_decay_loss = tf.add_n(self.classifier.losses)
         else:
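For reference, the two Wasserstein terms added above, isolated as a standalone sketch; `critic_real`/`critic_fake` stand in for the discriminator outputs `true` and `false`:

```python
import tensorflow as tf

def wasserstein_losses(critic_real, critic_fake):
    # The generator minimizes -E[D(G(z))]; the critic minimizes
    # E[D(G(z))] - E[D(x)], i.e. the negated Wasserstein estimate.
    gen_loss = -tf.reduce_mean(critic_fake)
    discr_loss = tf.reduce_mean(critic_fake) - tf.reduce_mean(critic_real)
    return gen_loss, discr_loss
```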
@@ -154,21 +161,21 @@ class ModelTrainer():
             weight_decay_loss = 0.0
         predictions = tf.nn.softmax(logits_classifier, axis=-1)
         total_loss = class_loss
         total_loss += weight_decay_loss
-        return class_loss,weight_decay_loss,total_loss,predictions
+        return (class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss),predictions

     def compute_gradients(self, x, y):
         # Pass through network
         with tf.GradientTape() as tape:
-            class_loss, weight_decay_loss,total_loss,predictions = self.compute_loss(
+            losses,predictions = self.compute_loss(
                 x, y, training=True)
-        gradients = tape.gradient(total_loss, self.classifier.trainable_variables)
+        gradients = tape.gradient(losses[-1], self.classifier.trainable_variables)
         if self.gradient_processing_fn is not None:
             out_grads = []
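`clip_by_value_10`, the default `gradient_processing_fn`, is not shown in this diff; a plausible definition, assuming it clips each gradient tensor element-wise to [-10, 10]:

```python
import tensorflow as tf

def clip_by_value_10(grad):
    # Assumed implementation (not part of this diff): element-wise
    # clipping of a single gradient tensor, applied per gradient in
    # the loop above.
    return tf.clip_by_value(grad, -10.0, 10.0)
```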
@@ -181,7 +188,7 @@ class ModelTrainer():
         else:
             out_grads = gradients
-        return out_grads, class_loss, weight_decay_loss,total_loss,predictions
+        return out_grads,losses,predictions

     def apply_gradients(self, gradients, variables):
@@ -193,11 +200,11 @@ class ModelTrainer():
     @tf.function
     def train_step(self,x,y):
         #Compute gradients
-        gradients,class_loss,weight_decay_loss,total_loss,predictions = self.compute_gradients(x,y)
+        gradients,losses,predictions = self.compute_gradients(x,y)
         #Apply gradients
         self.apply_gradients(gradients, self.classifier.trainable_variables)
-        return (class_loss,weight_decay_loss,total_loss,predictions)
+        return losses,predictions
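Note that `train_step` still only updates the classifier; `gen_loss` and `discr_loss` are computed but never applied, consistent with the commit message ("Preparations for WGAN training"). A sketch of what a follow-up critic step could look like as a `ModelTrainer` method; the `discr_optimizer` attribute and the 0.01 clip value (original WGAN weight clipping) are assumptions, not part of this diff:

```python
import tensorflow as tf

@tf.function
def discriminator_step(self, x):
    # Sketch only: update the critic on real x[0] vs. generated samples.
    with tf.GradientTape() as tape:
        fake = self.generator(x[1], training=True)
        discr_loss = (tf.reduce_mean(self.discriminator(fake, training=True))
                      - tf.reduce_mean(self.discriminator(x[0], training=True)))
    grads = tape.gradient(discr_loss, self.discriminator.trainable_variables)
    self.discr_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables))
    # Original WGAN enforces the Lipschitz constraint by clipping the
    # critic weights after each update.
    for w in self.discriminator.trainable_variables:
        w.assign(tf.clip_by_value(w, -0.01, 0.01))
    return discr_loss
```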
     def update_summaries(self,class_loss,weight_decay_loss,total_loss,predictions,labels,mode='train'):
@@ -205,8 +212,8 @@ class ModelTrainer():
         self.scalar_summaries[mode + "_" + "class_loss"].update_state(class_loss)
         self.scalar_summaries[mode + "_" + "weight_decay_loss"].update_state(weight_decay_loss)
         self.scalar_summaries[mode + "_" + "total_loss"].update_state(total_loss)
-        self.scalar_summaries[mode + "_" + "accuracy"].update_state(tf.argmax(predictions, axis=-1),
-                                                                    tf.argmax(labels, axis=-1))
+        self.scalar_summaries[mode + "_" + "accuracy"].update_state(tf.argmax(tf.squeeze(predictions), axis=-1),
+                                                                    tf.argmax(tf.squeeze(labels), axis=-1))
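The added `tf.squeeze` calls presumably compensate for the keys now being lists: with `label_keys = ["labels"]`, the labels arrive as a one-element list, which `tf.squeeze` first converts to a tensor with a leading size-1 dimension and then drops it. A toy example:

```python
import tensorflow as tf

# With label_keys = ["labels"], y is a one-element list; tf.squeeze
# turns it into a (batch, classes) tensor before the argmax.
y = [tf.constant([[0., 1.], [1., 0.]])]   # list -> tensor of shape (1, 2, 2)
print(tf.argmax(tf.squeeze(y), axis=-1))  # tf.Tensor([1 0], shape=(2,), dtype=int64)
```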
     def write_summaries(self,epoch,mode="train"):
         # Write summaries
@@ -224,6 +231,14 @@ class ModelTrainer():
         for key in self.scalar_summaries.keys():
             self.scalar_summaries[key].reset_states()
+
+    def get_data_for_keys(self,xy,keys):
+        """Returns the list of tensors in xy selected by keys"""
+        x = []
+        for key in keys:
+            x.append(xy[key])
+        return x

     def train(self):
         if self.start_epoch == 0:
             print("Saving...")
@@ -246,12 +261,12 @@ class ModelTrainer():
             batch = 0
             for train_xy in self.training_data_generator:
-                train_x = train_xy[self.input_key]
-                train_y = train_xy[self.label_key]
+                train_x = self.get_data_for_keys(train_xy,self.input_keys)
+                train_y = self.get_data_for_keys(train_xy,self.label_keys)
                 start_time = time()
-                outputs = self.train_step(train_x, train_y)
-                class_loss,weight_decay_loss,total_loss,predictions = outputs
+                losses,predictions = self.train_step(train_x, train_y)
+                class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss = losses
                 # Update summaries
                 self.update_summaries(class_loss,
                                       weight_decay_loss,
@@ -275,12 +290,13 @@ class ModelTrainer():
             if self.validation_data_generator is not None:
                 for validation_xy in self.validation_data_generator:
-                    validation_x = validation_xy[self.input_key]
-                    validation_y = validation_xy[self.label_key]
-                    class_loss, weight_decay_loss,total_loss,predictions = self.compute_loss(validation_x,
-                                                                                             validation_y,
-                                                                                             training=False)
+                    validation_x = self.get_data_for_keys(validation_xy,self.input_keys)
+                    validation_y = self.get_data_for_keys(validation_xy,self.label_keys)
+                    losses,predictions = self.compute_loss(validation_x,
+                                                           validation_y,
+                                                           training=False)
+                    class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss = losses
                     # Update summaries
                     self.update_summaries(class_loss,
@@ -296,13 +312,15 @@ class ModelTrainer():
             if self.test_data_generator is not None:
                 for test_xy in self.test_data_generator:
-                    test_x = test_xy[self.input_key]
-                    test_y = test_xy[self.label_key]
-                    class_loss, weight_decay_loss, total_loss, predictions = self.compute_loss(test_x,
-                                                                                               test_y,
-                                                                                               training=False)
+                    test_x = self.get_data_for_keys(test_xy,self.input_keys)
+                    test_y = self.get_data_for_keys(test_xy,self.label_keys)
+                    losses,predictions = self.compute_loss(test_x,
+                                                           test_y,
+                                                           training=False)
+                    class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss = losses
                     # Update summaries
                     self.update_summaries(class_loss,
                                           weight_decay_loss,
......