Commit f46cbe7c authored by Alexander Fuchs

Data loader optimizations

parent 8f319eed
@@ -23,7 +23,7 @@ class Discriminator(tf.keras.Model):
for i in range(self.n_layers):
self.model_layers.append(SpectralNormalization(tf.keras.layers.Conv2D(self.n_channels[i],
kernel_size=(3, 3),
kernel_size=strides[i]+1,
strides=strides[i],
padding="same",
use_bias=False,
@@ -37,7 +37,8 @@ class Discriminator(tf.keras.Model):
use_bias = False,
name = self.name_op+"_dense",
activation = None))
def call(self,input,training=False):
x = input
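Note on the kernel-size change above: tying kernel_size to strides[i]+1 keeps each strided convolution's receptive field one step wider than its stride, so adjacent output positions still see overlapping input windows (a stride-4 layer gets a 5x5 kernel). A minimal sketch of the pattern, assuming the SpectralNormalization wrapper behaves like the one in TensorFlow Addons:

import tensorflow as tf
import tensorflow_addons as tfa  # assumption: stand-in for the repo's SpectralNormalization

n_channels = [32, 64, 128]  # values taken from the training script below
strides = [4, 4, 4]
model_layers = []
for i in range(len(n_channels)):
    model_layers.append(tfa.layers.SpectralNormalization(
        tf.keras.layers.Conv2D(n_channels[i],
                               kernel_size=strides[i] + 1,  # 5x5 kernel for stride 4
                               strides=strides[i],
                               padding="same",
                               use_bias=False)))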
......
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
from absl import logging
from absl import app
@@ -13,18 +14,54 @@ from models.discriminator import Discriminator
from models.generator import Generator
#logging.set_verbosity(logging.WARNING)
BINS = 1025
N_FRAMES = 216
N_CHANNELS = 2
def preprocess_input(sample,n_classes,is_training):
"""Preprocess a single image of layout [height, width, depth]."""
if tf.random.uniform([1])[0] > 1/n_classes or not(is_training):
input_features = tf.image.per_image_standardization(sample['input_features'])
labels = sample['labels']
else:
input_features = tf.image.per_image_standardization(sample['false_sample'])
labels = tf.cast(tf.one_hot(n_classes,n_classes+1),tf.int32)
return {'input_features':input_features,'labels':labels,'false_sample':sample['false_sample']}
input_features = sample['input_features']
labels = sample['labels']
false_sample = sample['false_sample']
batch_size = tf.shape(labels)[0]
if is_training:
rnd = tf.random.uniform([batch_size])
rnd_tiled_feat = tf.tile(tf.reshape(rnd,[batch_size,1,1,1]),[1,BINS,
N_FRAMES,
N_CHANNELS])
rnd_tiled_lbl = tf.tile(tf.reshape(rnd,[batch_size,1]),[1,n_classes+1])
false_lbl = tf.cast(tf.one_hot(n_classes,n_classes+1),tf.int32) #index n_classes marks the false/no-bird class (n_classes+1 would be out of range and yield all zeros)
false_lbl = tf.tile(tf.reshape(false_lbl,[1,n_classes+1]),[batch_size,1])
input_features = tf.where(rnd_tiled_feat > 1/n_classes,input_features,false_sample)
labels = tf.where(rnd_tiled_lbl > 1/n_classes,labels,false_lbl)
return {'input_features':input_features,'labels':labels,'false_sample':false_sample}
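The rewritten preprocess_input is the core of the loader optimization: instead of branching per example with a Python conditional (which forces the map to run one sample at a time), it draws one uniform random number per batch element and blends real and false samples with tiled masks and tf.where, so the whole replacement is a single vectorized op over an already-batched tensor. A self-contained sketch of the masking trick with toy shapes (stand-in data, not the real spectrograms):

import tensorflow as tf

batch, n_classes = 4, 10
features = tf.ones([batch, 8, 8, 2])         # stand-in for real spectrograms
false_samples = tf.zeros([batch, 8, 8, 2])   # stand-in for false samples

rnd = tf.random.uniform([batch])             # one draw per batch element
keep = tf.reshape(rnd > 1.0 / n_classes, [batch, 1, 1, 1])
mask = tf.broadcast_to(keep, features.shape)
mixed = tf.where(mask, features, false_samples)  # rows with rnd <= 1/n_classes are replaced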
def augment_input(sample,n_classes):
"""Preprocess a single image of layout [height, width, depth]."""
input_features = sample['input_features']
labels = sample['labels']
false_sample = sample['false_sample']
rnd = tf.random.uniform([1])
rnd_tiled_feat = tf.tile(tf.reshape(rnd,[1,1,1]),[BINS,
N_FRAMES,
N_CHANNELS])
rnd_tiled_lbl = tf.tile(tf.reshape(rnd,[1]),[n_classes+1])
false_lbl = tf.cast(tf.one_hot(n_classes,n_classes+1),labels.dtype) #index n_classes marks the false/no-bird class
input_features = tf.where(rnd_tiled_feat > 1/n_classes,input_features,false_sample)
labels = tf.where(rnd_tiled_lbl > 1/n_classes,labels,false_lbl)
return {'input_features':input_features,'labels':labels,'false_sample':false_sample}
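augment_input mirrors the batched logic for a single unbatched sample: the random draw has shape [1] and the tiles omit the batch dimension, so it can run inside the DataGenerator as the augmentation callback (see augment_fn in main() below), while preprocess_input handles full batches. A hedged usage sketch with the module's real feature shape and a hypothetical class count:

import tensorflow as tf

n_classes = 264  # hypothetical; the real count comes from the Dataset
sample = {
    'input_features': tf.random.normal([BINS, N_FRAMES, N_CHANNELS]),
    'false_sample': tf.random.normal([BINS, N_FRAMES, N_CHANNELS]),
    'labels': tf.one_hot(3, n_classes + 1),
}
out = augment_input(sample, n_classes)  # same keys and shapes, possibly swapped content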
def data_generator(data_generator,batch_size,is_training,
shuffle_buffer = 128,
@@ -44,13 +81,10 @@ def data_generator(data_generator,batch_size,is_training,
dataset = dataset.take(take_n)
if is_training:
dataset = dataset.shuffle(shuffle_buffer)
dataset = dataset.map(lambda sample : preprocess_input(sample,n_classes,is_training))
dataset = dataset.batch(batch_size,drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
else:
dataset = dataset.map(lambda sample : preprocess_input(sample,n_classes,is_training))
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
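The head of data_generator is collapsed in this hunk; presumably it wraps the passed-in Python generator with tf.data. A sketch of what such a head could look like (an assumption, not shown in the commit), with the output signature matching the dicts yielded by DataGenerator.generate:

dataset = tf.data.Dataset.from_generator(
    data_generator,  # note: the parameter shadows the enclosing function's name
    output_signature={
        'input_features': tf.TensorSpec([BINS, N_FRAMES, N_CHANNELS], tf.float32),
        'labels': tf.TensorSpec([None], tf.float32),
        'false_sample': tf.TensorSpec([BINS, N_FRAMES, N_CHANNELS], tf.float32),
    })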
@@ -75,7 +109,8 @@ flags.DEFINE_integer('epochs', 300, 'number of epochs')
flags.DEFINE_integer('batch_size', 32, 'Mini-batch size')
flags.DEFINE_float('dropout_rate', 0.0, 'dropout rate for the dense blocks')
flags.DEFINE_float('weight_decay', 1e-4, 'weight decay parameter')
flags.DEFINE_float('learning_rate', 1e-3, 'learning rate')
flags.DEFINE_float('learning_rate', 1e-1, 'learning rate')
flags.DEFINE_boolean('preload_samples',False,'Preload samples (requires >140 GB RAM)')
flags.DEFINE_float('training_percentage', 90, 'Percentage of the training data used for training. (100-training_percentage is used as validation data.)')
flags.DEFINE_boolean('load_model', False, 'Bool indicating if the model should be loaded')
@@ -106,12 +141,26 @@ def main(argv):
lr = FLAGS.learning_rate
load_model = FLAGS.load_model
training_percentage = FLAGS.training_percentage
preload_samples = FLAGS.preload_samples
model_save_dir+="_batch_size_"+str(batch_size)+"_dropout_rate_"+str(dropout_rate)+"_learning_rate_"+str(lr)+"_weight_decay_"+str(weight_decay)
ds_train = Dataset(data_dir,is_training_set = True)
n_total = ds_train.n_samples
dg_train = DataGenerator(ds_train,None)
def augment_fn(sample):
return augment_input(sample,ds_train.n_classes)
dg_train = DataGenerator(ds_train,augment_fn,
training_percentage = training_percentage,
preload_samples = preload_samples,
is_training=True)
dg_val = DataGenerator(ds_train,None,
training_percentage = training_percentage,
is_validation = True,
preload_samples = preload_samples,
is_training=False)
n_train = int(n_total*training_percentage/100)
n_val = n_total-n_train
@@ -139,21 +188,19 @@ def main(argv):
#Discriminator for estimating the Wasserstein distance
discriminator_model = Discriminator(3,
[32,64,128],
[2,2,2],
[4,4,4],
name = "discriminator")
train_data_gen = data_generator(dg_train.generate,batch_size,
is_training=True,
shuffle_buffer = 256,
shuffle_buffer = 16*batch_size,
n_classes = ds_train.n_classes,
take_n=n_train)
val_data_gen = data_generator(dg_train.generate,10,
val_data_gen = data_generator(dg_val.generate,batch_size,
is_training=False,
is_validation = True,
n_classes = ds_train.n_classes,
skip_n=n_train,
take_n=n_val)
n_classes = ds_train.n_classes)
trainer = ModelTrainer(train_data_gen,
@@ -163,15 +210,19 @@ def main(argv):
classifier = classifier_model,
generator = generator_model,
discriminator = discriminator_model,
learning_rate_fn = learning_rate_fn,
optimizer_classifier = tf.keras.optimizers.Adam,
base_learning_rate_classifier = lr,
base_learning_rate_generator = lr*0.001,
base_learning_rate_discriminator = lr*0.0001,
learning_rate_fn_classifier = learning_rate_fn,
learning_rate_fn_generator = learning_rate_fn,
learning_rate_fn_discriminator = learning_rate_fn,
optimizer_classifier = tf.keras.optimizers.SGD,
optimizer_generator = tf.keras.optimizers.Adam,
optimizer_discriminator = tf.keras.optimizers.Adam,
num_train_batches = int(n_train/batch_size),
base_learning_rate = lr,
load_model = load_model,
save_dir = model_save_dir,
init_data = tf.random.normal([batch_size,1025,216,2]),
init_data = tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS]),
start_epoch = 0)
trainer.train()
......
@@ -9,6 +9,7 @@ import multiprocessing
import sys
import random
import copy
import time
import tensorflow as tf
warnings.filterwarnings('ignore')
@@ -135,8 +136,11 @@ class DataGenerator(object):
def __init__(self,dataset,augmentation,
shuffle = True,
is_training = True,
is_validation = False,
force_feature_recalc = False,
preload_false_samples = True,
preload_samples = False,
training_percentage = 90,
max_time = 5,
max_samples_per_audio = 6,
n_fft = 2048,
@@ -148,16 +152,25 @@ class DataGenerator(object):
random.seed(4)
random.shuffle(self.dataset.train_samples)
self.n_training_samples = int(dataset.n_samples*training_percentage/100)
self.n_validation_samples = dataset.n_samples-self.n_training_samples
self.augmentation = augmentation
self.is_training = is_training
self.is_validation = is_validation
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.preload_samples = preload_samples
self.preload_false_samples = preload_false_samples
self.hop_length = hop_length
self.max_time = max_time
self.max_samples_per_audio = max_samples_per_audio
self.force_feature_recalc = force_feature_recalc
if self.is_training:
self.first_sample = 0
self.last_sample = self.n_training_samples
elif self.is_validation:
self.first_sample = self.n_training_samples
self.last_sample = self.dataset.n_samples
#Get paths of false samples
false_samples_mono = glob.glob(self.dataset.false_audio_path+ "/mono/*.npz",
recursive = True)
@@ -171,8 +184,25 @@ class DataGenerator(object):
self.preloaded_false_samples = {}
for path in self.false_sample_paths:
with np.load(path,allow_pickle=True) as sample_file:
self.preloaded_false_samples[path] = copy.deepcopy(sample_file.f.arr_0)
print("Finished pre-loading")
self.preloaded_false_samples[path] = sample_file.f.arr_0
print("Finished pre-loading false samples!")
if self.is_training or self.is_validation:
self.samples = self.dataset.train_samples[self.first_sample:self.last_sample]
else:
self.samples = self.dataset.test_samples
#Pre-load samples into RAM (takes a lot of memory, ~130 GB)
try:
if self.preload_samples:
self.preloaded_samples = {}
for sample in self.samples:
path = sample["filename"].replace("mp3","npz")
with np.load(path,allow_pickle=True) as sample_file:
self.preloaded_samples[path] = sample_file.f.arr_0
print("Finished pre-loading samples")
except Exception:
#Fall back to loading samples from disk on demand (e.g. pre-loading ran out of memory)
self.preload_samples = False
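Pre-loading the full training set is only worthwhile when the machine really has the ~130-140 GB of free RAM the flag help text warns about; a defensive check along these lines (a sketch using psutil, which is not a dependency of this repo) could guard the preload branch:

import psutil  # assumption: extra dependency, not used in the repo

def can_preload(required_gb=140):
    """Return True if enough RAM is free to pre-load all samples."""
    return psutil.virtual_memory().available > required_gb * 1024**3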
def do_stft(self,y,channels):
spectra = []
@@ -318,20 +348,19 @@ class DataGenerator(object):
def generate(self):
if self.is_training:
samples = self.dataset.train_samples
else:
samples = self.dataset.test_samples
stft_len = int(np.ceil(self.max_time*self.sampling_rate/self.hop_length))
for sample in samples:
for sample in self.samples:
filename = sample['filename']
#If the features were already created, load them from file
if os.path.isfile(filename.replace("mp3","npz")) and not(self.force_feature_recalc):
with np.load(filename.replace("mp3","npz"),allow_pickle=True) as sample_file:
spectra_npz = sample_file.f.arr_0
if self.preload_samples:
spectra_npz = self.preloaded_samples[filename.replace("mp3","npz")]
else:
with np.load(filename.replace("mp3","npz"),allow_pickle=True) as sample_file:
spectra_npz = sample_file.f.arr_0
spec_keys = spectra_npz.item().keys()
spec_keys = list(spec_keys)
@@ -340,7 +369,6 @@ class DataGenerator(object):
else:
#Create features via STFT if no file exists
spectra = self.create_feature(sample)
#Check for None type and shape
if spectra is None or spectra.shape[-1] != stft_len:
@@ -368,13 +396,18 @@ class DataGenerator(object):
#If the false sample is mono --> duplicate the channel
if false_spectra.shape[0] == 1:
false_spectra = np.tile(false_spectra,[2,1,1])
#Transpose spectrograms to "channels_last"
spectra = tf.transpose(spectra,perm=[1,2,0])
false_spectra = tf.transpose(false_spectra,perm=[1,2,0])
yield {'input_features':spectra,
sample = {'input_features':spectra,
'labels':tf.one_hot(sample['bird_id'],self.dataset.n_classes+1),
'false_sample':false_spectra}
if self.augmentation is not None:
yield self.augmentation(sample)
else:
yield sample
if __name__ == "__main__":
ds = Dataset("/srv/TUG/datasets/cornell_birdcall_recognition")
......
@@ -7,6 +7,9 @@ def clip_by_value_10(grad):
grad = tf.where(tf.math.is_finite(grad), grad, tf.zeros_like(grad))
return tf.clip_by_value(grad, -10, 10)
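clip_by_value_10 first zeroes non-finite entries, then clips to [-10, 10], so a single NaN or Inf batch cannot poison the optimizer state. A hedged usage sketch (hypothetical variable names):

# e.g. inside a training step, before apply_gradients:
grads = [clip_by_value_10(g) for g in tape.gradient(loss, variables)]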
def constant_lr(epoch):
return 1.0
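constant_lr is the default for the three new learning_rate_fn_* hooks below; any callable mapping an epoch index to a multiplier of the base learning rate fits the interface. For example, a hypothetical step decay (not part of this commit):

def step_decay_lr(epoch):
    """Scale the base learning rate by 0.1 every 100 epochs."""
    return 0.1 ** (epoch // 100)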
class ModelTrainer():
def __init__(self,
training_data_generator,
@@ -19,8 +22,12 @@ class ModelTrainer():
optimizer_classifier = tf.keras.optimizers.Adam,
optimizer_generator = tf.keras.optimizers.Adam,
optimizer_discriminator = tf.keras.optimizers.Adam,
learning_rate_fn = None,
base_learning_rate=1e-3,
base_learning_rate_classifier=1e-3,
base_learning_rate_generator=1e-4,
base_learning_rate_discriminator=5e-5,
learning_rate_fn_classifier = constant_lr,
learning_rate_fn_generator = constant_lr,
learning_rate_fn_discriminator = constant_lr,
init_data = None,
start_epoch = 0,
num_train_batches = None,
@@ -33,12 +40,20 @@ class ModelTrainer():
self.classifier = classifier
self.generator = generator
self.discriminator = discriminator
self.learning_rate_fn = learning_rate_fn
self.save_dir = save_dir
self.base_learning_rate = base_learning_rate
self.learning_rate_classifier = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
self.learning_rate_generator = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
self.learning_rate_discriminator = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))
#Set up learning rates
self.learning_rate_fn_classifier = learning_rate_fn_classifier
self.learning_rate_fn_generator = learning_rate_fn_generator
self.learning_rate_fn_discriminator = learning_rate_fn_discriminator
self.base_learning_rate_classifier = base_learning_rate_classifier
self.base_learning_rate_generator = base_learning_rate_generator
self.base_learning_rate_discriminator = base_learning_rate_discriminator
self.learning_rate_classifier = tf.Variable(self.base_learning_rate_classifier*self.learning_rate_fn_classifier(start_epoch))
self.learning_rate_generator = tf.Variable(self.base_learning_rate_generator*self.learning_rate_fn_generator(start_epoch))
self.learning_rate_discriminator = tf.Variable(self.base_learning_rate_discriminator*self.learning_rate_fn_discriminator(start_epoch))
#Initialize models
self.classifier(init_data,training = True)
self.classifier.summary()
@@ -155,12 +170,13 @@ class ModelTrainer():
logits_classifier = self.classifier(x, training=training)
return tf.nn.softmax(logits_classifier, axis=-1)
@tf.function
def compute_loss(self, x, y , training = True):
logits_classifier = self.classifier(x[0],training=training)
fake_input = self.generator(x[1],training)
fake_features = self.generator(x[1],training)
true = self.discriminator(x[0],training)
false = self.discriminator(fake_input,training)
false = self.discriminator(fake_features,training)
# Cross entropy losses
class_loss = tf.reduce_mean(
@@ -171,13 +187,13 @@ class ModelTrainer():
))
#Wasserstein losses
gen_loss = -tf.reduce_mean(false)
discr_loss = tf.reduce_mean(false)-tf.reduce_mean(true)
gen_loss = tf.reduce_mean(false)
discr_loss = tf.reduce_mean(true)-tf.reduce_mean(false)
if len(self.classifier.losses) > 0:
weight_decay_loss = tf.add_n(self.classifier.losses)
else:
weight_decay_loss = 0.0
weight_decay_loss = 0.0
if len(self.generator.losses) > 0:
weight_decay_loss += tf.add_n(self.generator.losses)
@@ -190,9 +206,9 @@ class ModelTrainer():
total_loss += discr_loss
total_loss += weight_decay_loss
return (class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss),predictions
return (class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss),(predictions,fake_features)
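A note on the sign flip in the Wasserstein losses above: the standard WGAN convention, assuming plain gradient descent on each loss, is

L_discriminator = E[D(fake)] - E[D(real)]
L_generator = -E[D(fake)]

which matches the replaced lines. The new gen_loss = E[D(fake)] and discr_loss = E[D(real)] - E[D(fake)] descend in the same direction only if the gradient handling in the collapsed compute_gradients/train_step negates them, so the effective convention depends on code outside this hunk.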
@tf.function
def compute_gradients(self, x, y):
# Pass through network
with tf.GradientTape() as tape:
@@ -216,25 +232,28 @@ class ModelTrainer():
return out_grads,losses,predictions
@tf.function
def apply_gradients_classifier(self, gradients, variables):
self.optimizer_classifier.apply_gradients(
zip(gradients, variables)
)
@tf.function
def apply_gradients_generator(self, gradients, variables):
self.optimizer_generator.apply_gradients(
zip(gradients, variables)
)
@tf.function
def apply_gradients_discriminator(self, gradients, variables):
self.optimizer_discriminator.apply_gradients(
zip(gradients, variables)
)
@tf.function
def split_grads(self,grads):
classifier_grads = grads[:self.n_vars_classifier]
generator_grads = grads[self.n_vars_classifier:self.n_vars_classifier+self.n_vars_generator]
@@ -306,9 +325,9 @@ class ModelTrainer():
prog = tf.keras.utils.Progbar(self.max_number_batches)
for epoch in range(self.start_epoch, self.epochs):
# Update learning rates
self.learning_rate_classifier.assign(self.base_learning_rate * self.learning_rate_fn(epoch))
self.learning_rate_generator.assign(self.base_learning_rate * self.learning_rate_fn(epoch))
self.learning_rate_discriminator.assign(self.base_learning_rate * self.learning_rate_fn(epoch))
self.learning_rate_classifier.assign(self.base_learning_rate_classifier * self.learning_rate_fn_classifier(epoch))
self.learning_rate_generator.assign(self.base_learning_rate_generator * self.learning_rate_fn_generator(epoch))
self.learning_rate_discriminator.assign(self.base_learning_rate_discriminator * self.learning_rate_fn_discriminator(epoch))
#Update epoch variable
self.epoch_variable.assign(epoch)
with self.summary_writer["train"].as_default():
@@ -324,21 +343,27 @@ class ModelTrainer():
train_y = self.get_data_for_keys(train_xy,self.label_keys)
start_time = time()
losses,predictions = self.train_step(train_x, train_y)
losses,outputs = self.train_step(train_x, train_y)
predictions,fake_features = outputs
class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss = losses
# Update summaries
self.update_summaries(class_loss,
gen_loss,
discr_loss,
weight_decay_loss,
total_loss,
predictions,
train_y,
train_y[0],
"train")
batch += 1
prog.update(batch, [("class_loss", self.scalar_summaries["train_" + "class_loss"].result()),
("g_loss", self.scalar_summaries["train_" + "generator_loss"].result()),
("d_loss", self.scalar_summaries["train_" + "discriminator_loss"].result()),
("wd_loss", self.scalar_summaries["train_" + "weight_decay_loss"].result()),
("accuracy", 100 * self.scalar_summaries["train_" + "accuracy"].result()),
("time / step", np.round(time() - start_time, 2))])
@@ -353,12 +378,13 @@ class ModelTrainer():
for validation_xy in self.validation_data_generator:
validation_x = self.get_data_for_keys(validation_xy,self.input_keys)
validation_y = self.get_data_for_keys(validation_xy,self.label_keys)
losses,predictions = self.compute_loss(validation_x,
losses,outputs = self.compute_loss(validation_x,
validation_y,
training=False)
predictions,fake_features = outputs
class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss = losses
# Update summaries
self.update_summaries(class_loss,
gen_loss,
@@ -366,9 +392,20 @@ class ModelTrainer():
weight_decay_loss,
total_loss,
predictions,
validation_y,
validation_y[0],
"val")
#Save generated audio STFT samples
with self.summary_writer["val"].as_default():
tf.summary.image("original_features",
validation_xy["input_features"],
step = epoch,
max_outputs = 1)
tf.summary.image("fake_features",
fake_features,
step = epoch,
max_outputs = 1)
# Write validation summaries
self.write_summaries(epoch,"val")
@@ -378,10 +415,10 @@ class ModelTrainer():
test_x = self.get_data_for_keys(test_xy,self.input_keys)
test_y = self.get_data_for_keys(test_xy,self.label_keys)
losses,predictions = self.compute_loss(test_x,
losses,outputs = self.compute_loss(test_x,
test_y,
training=False)
predictions,fake_features = outputs
class_loss,gen_loss,discr_loss,weight_decay_loss,total_loss = losses
# Update summaries
@@ -391,7 +428,7 @@ class ModelTrainer():
weight_decay_loss,
total_loss,
predictions,
test_y,
test_y[0],
"test")
# Write test summaries
......