Commit 1a47e306 authored by Alexander Fuchs's avatar Alexander Fuchs
Browse files

Trainer now allows for generic loss functions

parent 8ff8923d
import tensorflow as tf
class LossFunction(object):
    """Joint loss for a classifier trained alongside a WGAN generator/critic pair.

    Wraps three Keras models and computes, in a single call:
      * softmax cross-entropy classification loss,
      * generator loss (feature reconstruction MSE + Wasserstein term),
      * discriminator (critic) loss,
      * the models' built-in regularization (weight-decay) losses,
    plus softmax predictions and the generated features.
    """

    def __init__(self, classifier, generator, discriminator):
        """Store the three models used by `compute_loss`.

        Args:
            classifier: Keras model mapping input features `x[0]` to class logits.
            generator: Keras model producing fake features from `x[1]`.
            discriminator: Keras critic scoring real vs. generated features.
        """
        self.classifier = classifier
        self.generator = generator
        self.discriminator = discriminator

    @tf.function
    def compute_loss(self, x, y, training=True):
        """Compute all partial losses and the total training loss.

        Args:
            x: Pair of inputs; `x[0]` feeds the classifier and critic,
               `x[1]` feeds the generator (and is the reconstruction target).
            y: Targets; `y[0]` holds per-class label distributions
               (e.g. one-hot) for `x[0]`.
            training: Forwarded to every model so layers such as batch-norm /
                dropout behave correctly. Fix: previously passed positionally
                to the generator and discriminator but as a keyword to the
                classifier; now passed as a keyword consistently.

        Returns:
            A pair of tuples:
            ((class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss),
             (predictions, fake_features))
        """
        logits_classifier = self.classifier(x[0], training=training)
        fake_features = self.generator(x[1], training=training)
        critic_real = self.discriminator(x[0], training=training)
        critic_fake = self.discriminator(fake_features, training=training)

        # Cross-entropy classification loss (labels are the first argument).
        class_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                y[0],
                logits_classifier,
                axis=-1,
            ))

        # Generator: MSE feature reconstruction plus Wasserstein term.
        gen_loss = tf.reduce_mean((fake_features - x[1]) ** 2)
        gen_loss += tf.reduce_mean(critic_fake)
        # Critic loss with the sign convention matching the generator term
        # above (the generator minimizes the mean critic score on fakes).
        discr_loss = tf.reduce_mean(critic_real) - tf.reduce_mean(critic_fake)

        # Built-in regularization losses (e.g. kernel weight decay).
        # NOTE(review): discriminator.losses are not included here — confirm
        # the critic is intentionally unregularized.
        if len(self.classifier.losses) > 0:
            weight_decay_loss = tf.add_n(self.classifier.losses)
        else:
            weight_decay_loss = 0.0
        if len(self.generator.losses) > 0:
            weight_decay_loss += tf.add_n(self.generator.losses)

        predictions = tf.nn.softmax(logits_classifier, axis=-1)

        total_loss = class_loss + gen_loss + discr_loss + weight_decay_loss
        return ((class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss),
                (predictions, fake_features))
\ No newline at end of file
......@@ -12,6 +12,7 @@ from models.classifier import Classifier
from models.res_block import ResBlockBasicLayer
from models.discriminator import Discriminator
from models.generator import Generator
from models.loss_functions.separate_classifier_wgan_loss_fn import LossFunction
import tensorflow as tf # pylint: disable=g-bad-import-order
BINS = 1025
......@@ -188,6 +189,7 @@ def main(argv):
val_data_gen,
None,
epochs,
LossFunction,
classifier = classifier_model,
generator = generator_model,
discriminator = discriminator_model,
......
......@@ -16,6 +16,7 @@ class ModelTrainer():
validation_data_generator,
test_data_generator,
epochs,
loss_fn,
classifier,
generator = None,
discriminator = None,
......@@ -64,6 +65,8 @@ class ModelTrainer():
self.n_vars_classifier = len(self.classifier.trainable_variables)
self.n_vars_generator = len(self.generator.trainable_variables)
self.n_vars_discriminator = len(self.discriminator.trainable_variables)
#Set up loss function
self.loss_fn = loss_fn(self.classifier,self.generator,self.discriminator)
#Set up optimizers
if optimizer_classifier == tf.keras.optimizers.SGD:
self.optimizer_classifier = optimizer_classifier(self.learning_rate_classifier,0.9,True)
......@@ -170,44 +173,8 @@ class ModelTrainer():
logits_classifier = self.classifier(x, training=training)
return tf.nn.softmax(logits_classifier, axis=-1)
@tf.function
def compute_loss(self, x, y , training = True):
logits_classifier = self.classifier(x[0],training=training)
fake_features = self.generator(x[1],training)
true = self.discriminator(x[0],training)
false = self.discriminator(fake_features,training)
# Cross entropy losses
class_loss = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(
y[0],
logits_classifier,
axis=-1,
))
gen_loss = tf.reduce_mean((fake_features - x[1])**2)
#Wasserstein losses
gen_loss += tf.reduce_mean(false)
discr_loss = tf.reduce_mean(true)-tf.reduce_mean(false)
if len(self.classifier.losses) > 0:
weight_decay_loss = tf.add_n(self.classifier.losses)
else:
weight_decay_loss = 0.0
if len(self.generator.losses) > 0:
weight_decay_loss += tf.add_n(self.generator.losses)
predictions = tf.nn.softmax(logits_classifier, axis=-1)
total_loss = class_loss
total_loss += gen_loss
total_loss += discr_loss
total_loss += weight_decay_loss
return (class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss),(predictions,fake_features)
return self.loss_fn.compute_loss(x,y,training)
@tf.function
def compute_gradients(self, x, y):
......@@ -399,7 +366,7 @@ class ModelTrainer():
#Save generated audio STFT samples
with self.summary_writer["val"].as_default():
tf.summary.image("original_features",
validation_xy["input_features"],
validation_xy['false_sample'],
step = epoch,
max_outputs = 1)
tf.summary.image("fake_features",
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment