Commit 8f319eed authored by Alexander Fuchs

Implemented Wasserstein GAN for augmentation

parent 03a7f25b
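
For context: a Wasserstein GAN trains the discriminator as a critic whose scores are used to estimate the Wasserstein-1 distance between real and generated samples, while the generator is trained to minimize that estimate. Below is a minimal illustrative sketch of the two objectives, not code from this commit; the helper name wgan_losses and its arguments are hypothetical.

import tensorflow as tf

def wgan_losses(critic_real, critic_fake):
    # Critic scores for real and generated (fake) samples are unbounded reals.
    # The critic maximizes E[D(real)] - E[D(fake)], so its minimized loss is the negation;
    # the generator minimizes -E[D(fake)] to push fake samples toward higher scores.
    discriminator_loss = tf.reduce_mean(critic_fake) - tf.reduce_mean(critic_real)
    generator_loss = -tf.reduce_mean(critic_fake)
    return generator_loss, discriminator_loss

In the full WGAN formulation the critic must also be kept approximately 1-Lipschitz, e.g. via weight clipping or a gradient penalty; that detail is omitted here.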
@@ -52,4 +52,4 @@ class Generator(tf.keras.Model):
elif layer[1]==1:
x = layer[0](x,training)
return x
return x+input
@@ -7,9 +7,10 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from utils.trainer import ModelTrainer
from utils.data_loader import Dataset
from utils.data_loader import DataGenerator
from models.network import Network
from models.classifier import Classifier
from models.res_block import ResBlockBasicLayer
from models.discriminator import Discriminator
from models.generator import Generator
#logging.set_verbosity(logging.WARNING)
@@ -115,7 +116,7 @@ def main(argv):
n_val = n_total-n_train
#ResNet 18
model = Network(ResBlockBasicLayer,
classifier_model = Classifier(ResBlockBasicLayer,
n_blocks = 4,
n_layers = [2,2,2,2],
strides = [2,2,2,2],
@@ -128,6 +129,18 @@
kernel_regularizer = tf.keras.regularizers.l2(2e-4),
kernel_initializer = tf.keras.initializers.he_normal(),
dropout=dropout_rate)
#Generator model used to augment the data with generated (fake) samples
generator_model = Generator(4,
[16,32,16,1],
kernel_regularizer = tf.keras.regularizers.l2(2e-4),
kernel_initializer = tf.keras.initializers.he_normal(),
name = "generator")
#Discriminator for estimating the Wasserstein distance
discriminator_model = Discriminator(3,
[32,64,128],
[2,2,2],
name = "discriminator")
train_data_gen = data_generator(dg_train.generate,batch_size,
is_training=True,
@@ -143,13 +156,17 @@
take_n=n_val)
trainer = ModelTrainer(model,
train_data_gen,
trainer = ModelTrainer(train_data_gen,
val_data_gen,
None,
epochs,
classifier = classifier_model,
generator = generator_model,
discriminator = discriminator_model,
learning_rate_fn = learning_rate_fn,
optimizer = tf.keras.optimizers.Adam,
optimizer_classifier = tf.keras.optimizers.Adam,
optimizer_generator = tf.keras.optimizers.Adam,
optimizer_discriminator = tf.keras.optimizers.Adam,
num_train_batches = int(n_train/batch_size),
base_learning_rate = lr,
load_model = load_model,
......
@@ -145,6 +145,7 @@ class DataGenerator(object):
self.dataset = dataset
#Shuffle files before loading since dataset is ordered by class
if shuffle:
random.seed(4)
random.shuffle(self.dataset.train_samples)
self.augmentation = augmentation
@@ -207,44 +208,44 @@ class DataGenerator(object):
def create_feature(self,sample):
"""Creates the features by doing a STFT"""
try:
filename = sample['filename']
channels_str = sample['channels']
channels = int(channels_str.split(" ")[0])
if channels == 1:
mono = True
else:
mono = False
y, sr = librosa.core.load(filename,mono=mono,sr=self.sampling_rate)
y,_ = librosa.effects.trim(y)
if mono == True:
y = np.expand_dims(y,0)
#try:
filename = sample['filename']
channels_str = sample['channels']
channels = int(channels_str.split(" ")[0])
if channels == 1:
mono = True
else:
mono = False
y, sr = librosa.core.load(filename,mono=mono,sr=self.sampling_rate)
y,_ = librosa.effects.trim(y)
if mono == True:
y = np.expand_dims(y,0)
duration = y.shape[-1]/self.sampling_rate
n_samples = int(np.ceil(duration/self.max_time))
n_samples = min(n_samples,self.max_samples_per_audio)
spectra = {}
for i_sample in range(n_samples):
start = i_sample*int(self.sampling_rate*self.max_time)
end = (i_sample+1)*int(self.sampling_rate*self.max_time)
end = min(end,y.shape[-1])
y_sample = y[:,start:end]
if y_sample.shape[-1] == 1:
break
#Transform audio
spectrum = self.do_stft(y_sample,channels)
#Pad spectrum
spectrum = self.pad_sample(spectrum,
x_size=np.ceil(self.max_time*self.sampling_rate/self.hop_length))
spectra[str(i_sample)] = spectrum
if "mp3" in filename:
np.savez(filename.replace("mp3","npz"),spectra)
else:
np.savez(filename.replace("wav","npz"),spectra)
except:
spectra = None
print(sample['filename']+" failed at feature extraction!")
duration = y.shape[-1]/self.sampling_rate
n_samples = int(np.ceil(duration/self.max_time))
n_samples = min(n_samples,self.max_samples_per_audio)
spectra = {}
for i_sample in range(n_samples):
start = i_sample*int(self.sampling_rate*self.max_time)
end = (i_sample+1)*int(self.sampling_rate*self.max_time)
end = min(end,y.shape[-1])
y_sample = y[:,start:end]
if y_sample.shape[-1] == 1:
break
#Transform audio
spectrum = self.do_stft(y_sample,channels)
#Pad spectrum
spectrum = self.pad_sample(spectrum,
x_size=np.ceil(self.max_time*self.sampling_rate/self.hop_length))
spectra[str(i_sample)] = spectrum
if "mp3" in filename:
np.savez(filename.replace("mp3","npz"),spectra)
else:
np.savez(filename.replace("wav","npz"),spectra)
#except:
# spectra = None
# print(sample['filename']+" failed at feature extraction!")
return spectra
......
@@ -16,7 +16,9 @@ class ModelTrainer():
classifier,
generator = None,
discriminator = None,
optimizer = tf.keras.optimizers.Adam,
optimizer_classifier = tf.keras.optimizers.Adam,
optimizer_generator = tf.keras.optimizers.Adam,
optimizer_discriminator = tf.keras.optimizers.Adam,
learning_rate_fn = None,
base_learning_rate=1e-3,
init_data = None,
@@ -34,7 +36,9 @@ class ModelTrainer():
self.learning_rate_fn = learning_rate_fn
self.save_dir = save_dir
self.base_learning_rate = base_learning_rate
self.learning_rate = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
self.learning_rate_classifier = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
self.learning_rate_generator = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
self.learning_rate_discriminator = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
#Initialize models
self.classifier(init_data,training = True)
self.classifier.summary()
@@ -42,17 +46,32 @@ class ModelTrainer():
self.generator.summary()
self.discriminator(init_data,training = True)
self.discriminator.summary()
#Set up optimizer
if optimizer == tf.keras.optimizers.SGD:
self.optimizer = optimizer(self.learning_rate,0.9,True)
self.n_vars_classifier = len(self.classifier.trainable_variables)
self.n_vars_generator = len(self.generator.trainable_variables)
self.n_vars_discriminator = len(self.discriminator.trainable_variables)
#Set up optimizers
if optimizer_classifier == tf.keras.optimizers.SGD:
self.optimizer_classifier = optimizer_classifier(self.learning_rate_classifier,0.9,True)
else:
self.optimizer = optimizer(self.learning_rate)
self.optimizer_classifier = optimizer_classifier(self.learning_rate_classifier)
if optimizer_generator == tf.keras.optimizers.SGD:
self.optimizer_generator = optimizer_generator(self.learning_rate_generator,0.9,True)
else:
self.optimizer_generator = optimizer_generator(self.learning_rate_generator)
if optimizer_discriminator == tf.keras.optimizers.SGD:
self.optimizer_discriminator = optimizer_discriminator(self.learning_rate_discriminator,0.9,True)
else:
self.optimizer_discriminator = optimizer_discriminator(self.learning_rate_discriminator)
#Set up function to process the gradients
self.gradient_processing_fn = gradient_processesing_fn
self.epochs = epochs
self.epoch_variable = tf.Variable(start_epoch)
self.start_epoch = start_epoch
#Load model
if load_model:
self.load_model()
@@ -78,7 +97,12 @@ class ModelTrainer():
else:
self.max_number_batches = num_train_batches
self.define_scalar_summaries(["accuracy","class_loss","weight_decay_loss","total_loss"])
self.define_scalar_summaries(["accuracy",
"class_loss",
"generator_loss",
"discriminator_loss",
"weight_decay_loss",
"total_loss"])
def define_scalar_summaries(self,names_list):
@@ -110,7 +134,7 @@ class ModelTrainer():
self.start_epoch = int(str_epoch)
model_to_load = latest
else:
split[-2] = str(start_epoch)
split[-2] = str(self.start_epoch)
model_to_load = "_".join(split)
if print_model_name:
@@ -156,13 +180,14 @@ class ModelTrainer():
weight_decay_loss = 0.0
if len(self.generator.losses) > 0:
weight_decay_loss = tf.add_n(self.generator.losses)
else:
weight_decay_loss = 0.0
weight_decay_loss += tf.add_n(self.generator.losses)
predictions = tf.nn.softmax(logits_classifier, axis=-1)
total_loss = class_loss
total_loss += gen_loss
total_loss += discr_loss
total_loss += weight_decay_loss
return (class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss),predictions
@@ -174,8 +199,9 @@ class ModelTrainer():
losses,predictions = self.compute_loss(
x, y, training=True)
gradients = tape.gradient(losses[-1], self.classifier.trainable_variables)
all_vars = self.classifier.trainable_variables +self.generator.trainable_variables +self.discriminator.trainable_variables
gradients = tape.gradient(losses[-1],all_vars)
if self.gradient_processing_fn != None:
out_grads = []
@@ -191,25 +217,49 @@ class ModelTrainer():
return out_grads,losses,predictions
def apply_gradients(self, gradients, variables):
def apply_gradients_classifier(self, gradients, variables):
self.optimizer_classifier.apply_gradients(
zip(gradients, variables)
)
def apply_gradients_generator(self, gradients, variables):
self.optimizer_generator.apply_gradients(
zip(gradients, variables)
)
def apply_gradients_discriminator(self, gradients, variables):
self.optimizer.apply_gradients(
self.optimizer_discriminator.apply_gradients(
zip(gradients, variables)
)
def split_grads(self,grads):
classifier_grads = grads[:self.n_vars_classifier]
generator_grads = grads[self.n_vars_classifier:self.n_vars_classifier+self.n_vars_generator]
discriminator_grads = grads[-self.n_vars_discriminator:]
return classifier_grads,generator_grads,discriminator_grads
@tf.function
def train_step(self,x,y):
#Compute gradients
gradients,losses,predictions = self.compute_gradients(x,y)
#Apply gradeints
self.apply_gradients(gradients, self.classifier.trainable_variables)
classifier_grads,generator_grads,discriminator_grads = self.split_grads(gradients)
#Apply gradients
self.apply_gradients_classifier(classifier_grads, self.classifier.trainable_variables)
self.apply_gradients_generator(generator_grads, self.generator.trainable_variables)
self.apply_gradients_discriminator(discriminator_grads, self.discriminator.trainable_variables)
return losses,predictions
def update_summaries(self,class_loss,weight_decay_loss,total_loss,predictions,labels,mode='train'):
def update_summaries(self,class_loss,gen_loss,discr_loss,weight_decay_loss,
total_loss,predictions,labels,mode='train'):
#Update summaries
self.scalar_summaries[mode + "_" + "class_loss"].update_state(class_loss)
self.scalar_summaries[mode + "_" + "generator_loss"].update_state(gen_loss)
self.scalar_summaries[mode + "_" + "discriminator_loss"].update_state(discr_loss)
self.scalar_summaries[mode + "_" + "weight_decay_loss"].update_state(weight_decay_loss)
self.scalar_summaries[mode + "_" + "total_loss"].update_state(total_loss)
self.scalar_summaries[mode + "_" + "accuracy"].update_state(tf.argmax(tf.squeeze(predictions), axis=-1),
@@ -220,6 +270,10 @@ class ModelTrainer():
with self.summary_writer[mode].as_default():
tf.summary.scalar('class_loss',
self.scalar_summaries[mode + "_" + "class_loss"].result(), step=epoch)
tf.summary.scalar('generator_loss',
self.scalar_summaries[mode + "_" + "generator_loss"].result(), step=epoch)
tf.summary.scalar('discriminator_loss',
self.scalar_summaries[mode + "_" + "discriminator_loss"].result(), step=epoch)
tf.summary.scalar('weight_decay_loss',
self.scalar_summaries[mode + "_" + "weight_decay_loss"].result(), step=epoch)
tf.summary.scalar('total_loss',
@@ -251,11 +305,16 @@ class ModelTrainer():
# Progress bar
prog = tf.keras.utils.Progbar(self.max_number_batches)
for epoch in range(self.start_epoch, self.epochs):
# Update learning rate
self.learning_rate.assign(self.base_learning_rate * self.learning_rate_fn(epoch))
# Update learning rates
self.learning_rate_classifier.assign(self.base_learning_rate * self.learning_rate_fn(epoch))
self.learning_rate_generator.assign(self.base_learning_rate * self.learning_rate_fn(epoch))
self.learning_rate_discriminator.assign(self.base_learning_rate * self.learning_rate_fn(epoch))
#Update epoch variable
self.epoch_variable.assign(epoch)
with self.summary_writer["train"].as_default():
tf.summary.scalar('learning_rate', self.learning_rate, step=epoch)
tf.summary.scalar('learning_rate_classifier', self.learning_rate_classifier, step=epoch)
tf.summary.scalar('learning_rate_generator', self.learning_rate_generator, step=epoch)
tf.summary.scalar('learning_rate_discriminator', self.learning_rate_discriminator, step=epoch)
# Start epoch
print("Starting epoch " + str(epoch + 1))
batch = 0
@@ -269,6 +328,8 @@ class ModelTrainer():
class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss = losses
# Update summaries
self.update_summaries(class_loss,
gen_loss,
discr_loss,
weight_decay_loss,
total_loss,
predictions,
@@ -300,6 +361,8 @@ class ModelTrainer():
# Update summaries
self.update_summaries(class_loss,
gen_loss,
discr_loss,
weight_decay_loss,
total_loss,
predictions,
@@ -323,6 +386,8 @@ class ModelTrainer():
# Update summaries
self.update_summaries(class_loss,
gen_loss,
discr_loss,
weight_decay_loss,
total_loss,
predictions,
......
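
The training-step changes above take a single gradient over the concatenated classifier, generator and discriminator variables and then slice the result back into per-model chunks, each applied with its own optimizer. Below is a minimal standalone sketch of that splitting pattern, with hypothetical toy models and a placeholder loss (not the repository's models or its WGAN objective):

import tensorflow as tf

# Hypothetical stand-ins for the classifier, generator and discriminator.
classifier = tf.keras.Sequential([tf.keras.layers.Dense(10)])
generator = tf.keras.Sequential([tf.keras.layers.Dense(8)])
discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1)])
for m in (classifier, generator, discriminator):
    m.build((None, 8))

optimizer_classifier = tf.keras.optimizers.Adam(1e-3)
optimizer_generator = tf.keras.optimizers.Adam(1e-3)
optimizer_discriminator = tf.keras.optimizers.Adam(1e-3)

# Variable counts used to split the flat gradient list, mirroring ModelTrainer.split_grads.
n_vars_classifier = len(classifier.trainable_variables)
n_vars_generator = len(generator.trainable_variables)
n_vars_discriminator = len(discriminator.trainable_variables)

def train_step(x):
    all_vars = (classifier.trainable_variables
                + generator.trainable_variables
                + discriminator.trainable_variables)
    with tf.GradientTape() as tape:
        # Placeholder scalar loss that touches all three models.
        total_loss = (tf.reduce_mean(classifier(x, training=True))
                      + tf.reduce_mean(discriminator(generator(x, training=True), training=True)))
    gradients = tape.gradient(total_loss, all_vars)
    # Slice the gradients back apart per model.
    classifier_grads = gradients[:n_vars_classifier]
    generator_grads = gradients[n_vars_classifier:n_vars_classifier + n_vars_generator]
    discriminator_grads = gradients[-n_vars_discriminator:]
    optimizer_classifier.apply_gradients(zip(classifier_grads, classifier.trainable_variables))
    optimizer_generator.apply_gradients(zip(generator_grads, generator.trainable_variables))
    optimizer_discriminator.apply_gradients(zip(discriminator_grads, discriminator.trainable_variables))
    return total_loss

loss = train_step(tf.random.normal([4, 8]))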