Commit 2eb852cc authored by Alexander Fuchs

Added methods to evaluate the whole dataset

parent 1a47e306
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import librosa
from absl import logging
from absl import app
from absl import flags
from utils.trainer import ModelTrainer
from utils.data_loader import Dataset
from utils.data_loader import DataGenerator
from models.classifier import Classifier
from models.res_block import ResBlockBasicLayer
from models.discriminator import Discriminator
from models.generator import Generator
from models.loss_functions.separate_classifier_wgan_loss_fn import LossFunction
import tensorflow as tf # pylint: disable=g-bad-import-order
BINS = 1025
N_FRAMES = 216
N_CHANNELS = 2
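# Input feature layout per example: [BINS, N_FRAMES, N_CHANNELS] = frequency bins x STFT frames x audio channels.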
def augment_input(sample,n_classes,training):
    """Standardizes the input spectrogram and the false sample of a single example."""
    input_features = sample['input_features']
    labels = sample['labels']
    false_sample = sample['false_sample']
    input_features = tf.image.per_image_standardization(input_features)
    false_sample = tf.image.per_image_standardization(false_sample)
    return {'input_features':input_features,'labels':labels,'false_sample':false_sample}
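# Wraps a Python generator in a tf.data.Dataset, with optional skip/take, batching and prefetching.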
def data_generator(data_generator,batch_size,is_training,
                   shuffle_buffer = 128,
                   is_validation=False,
                   n_classes = 10,
                   take_n=None,
                   skip_n=None):
    dataset = tf.data.Dataset.from_generator(data_generator,
                                             output_types = {'input_features':tf.float32,
                                                             'labels':tf.int32,
                                                             'false_sample':tf.float32})
    if skip_n is not None:
        dataset = dataset.skip(skip_n)
    if take_n is not None:
        dataset = dataset.take(take_n)
    if is_training:
        dataset = dataset.batch(batch_size,drop_remainder=False)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    else:
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
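# Piecewise-constant multiplier for the base learning rate: 1.0 until epoch 150, then 0.1, 0.01 and 0.001.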
def learning_rate_fn(epoch):
    if epoch >= 150 and epoch < 200:
        return 0.1
    elif epoch >= 200 and epoch < 250:
        return 0.01
    elif epoch >= 250:
        return 0.001
    else:
        return 1.0
FLAGS = flags.FLAGS
flags.DEFINE_string('model_dir', '/tmp', 'save directory name')
flags.DEFINE_string('data_dir', '/tmp', 'data directory name')
flags.DEFINE_integer('epochs', 300, 'number of epochs')
flags.DEFINE_integer('batch_size', 32, 'Mini-batch size')
flags.DEFINE_float('dropout_rate', 0.0, 'dropout rate for the dense blocks')
flags.DEFINE_float('weight_decay', 1e-4, 'weight decay parameter')
flags.DEFINE_float('learning_rate', 5e-2, 'learning rate')
flags.DEFINE_boolean('preload_samples',False,'Preload samples (requires >140 GB RAM)')
flags.DEFINE_float('training_percentage', 100, 'Percentage of the training data used for training. (100-training_percentage is used as validation data.)')
flags.DEFINE_boolean('load_model', True, 'Bool indicating if the model should be loaded')
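# Example invocation (script name and paths are placeholders):
#   python evaluate_dataset.py --data_dir=/path/to/dataset --model_dir=/path/to/checkpoints --batch_size=32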
def main(argv):
    try:
        task_id = int(os.environ['SLURM_ARRAY_TASK_ID'])
    except KeyError:
        task_id = 0
    model_save_dir = FLAGS.model_dir
    data_dir = FLAGS.data_dir
    print("Saving model to : " + str(model_save_dir))
    print("Loading data from : " + str(data_dir))
    test_data_dir = data_dir
    train_data_dir = data_dir
    epochs = FLAGS.epochs
    batch_size = FLAGS.batch_size
    dropout_rate = FLAGS.dropout_rate
    weight_decay = FLAGS.weight_decay
    lr = FLAGS.learning_rate
    load_model = FLAGS.load_model
    training_percentage = FLAGS.training_percentage
    preload_samples = FLAGS.preload_samples
    ds = Dataset(data_dir,is_training_set = True)
    n_total = ds.n_samples
    def augment_fn(sample,training):
        return augment_input(sample,ds.n_classes,training)
    dg = DataGenerator(ds,augment_fn,
                       training_percentage = training_percentage,
                       preload_samples = preload_samples,
                       save_created_features = False,
                       max_samples_per_audio = 99,
                       is_training=True)
    n_train = int(n_total*training_percentage/100)
    n_val = n_total-n_train
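    # The setup mirrors training: a classifier, a generator for false samples and a WGAN discriminator
    # are built, but only the classifier is used for inference below.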
    #ResNet 18
    classifier_model = Classifier(ResBlockBasicLayer,
                                  n_blocks = 4,
                                  n_layers = [2,2,2,2],
                                  strides = [2,2,2,2],
                                  channel_base = [64,128,256,512],
                                  n_classes = ds.n_classes+1,
                                  init_ch = 64,
                                  init_ksize = 7,
                                  init_stride = 2,
                                  use_max_pool = True,
                                  kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                                  kernel_initializer = tf.keras.initializers.he_normal(),
                                  dropout=dropout_rate)
    #Generator model used to augment the false samples
    generator_model = Generator(8,
                                [8,8,16,16,32,32,64,64],
                                kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                                kernel_initializer = tf.keras.initializers.he_normal(),
                                name = "generator")
    #Discriminator for estimating the Wasserstein distance
    discriminator_model = Discriminator(3,
                                        [32,64,128],
                                        [4,4,4],
                                        name = "discriminator")
    data_gen = data_generator(dg.generate_all_samples_from_scratch,batch_size,
                              is_training=True,
                              n_classes = ds.n_classes)
    trainer = ModelTrainer(data_gen,
                           None,
                           None,
                           epochs,
                           LossFunction,
                           classifier = classifier_model,
                           generator = generator_model,
                           discriminator = discriminator_model,
                           base_learning_rate_classifier = lr,
                           base_learning_rate_generator = lr*0.0001,
                           base_learning_rate_discriminator = lr*0.002,
                           learning_rate_fn_classifier = learning_rate_fn,
                           learning_rate_fn_generator = learning_rate_fn,
                           learning_rate_fn_discriminator = learning_rate_fn,
                           optimizer_classifier = tf.keras.optimizers.SGD,
                           optimizer_generator = tf.keras.optimizers.Adam,
                           optimizer_discriminator = tf.keras.optimizers.Adam,
                           num_train_batches = int(n_train/batch_size),
                           load_model = load_model,
                           save_dir = model_save_dir,
                           init_data = tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS]),
                           start_epoch = 0)
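    # Classify every batch of the dataset and store the raw softmax predictions next to the data.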
    all_predictions = trainer.classify_dataset(data_gen)
    np.save(os.path.join(data_dir,"train_set_predictions.npy"),all_predictions)
if __name__ == '__main__':
    app.run(main)
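# A minimal sketch of how the saved predictions could be inspected afterwards
# (assumes the file written above; allow_pickle covers the case where batches have different sizes):
#   preds = np.load(os.path.join(FLAGS.data_dir,"train_set_predictions.npy"),allow_pickle=True)
#   class_ids = np.argmax(np.concatenate(list(preds),axis=0),axis=-1)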
@@ -141,6 +141,7 @@ class DataGenerator(object):
                 preload_false_samples = True,
                 preload_samples = False,
                 training_percentage = 90,
                 save_created_features = True,
                 max_time = 5,
                 max_samples_per_audio = 6,
                 n_fft = 2048,
@@ -165,6 +166,7 @@ class DataGenerator(object):
        self.max_time = max_time
        self.max_samples_per_audio = max_samples_per_audio
        self.force_feature_recalc = force_feature_recalc
        self.save_created_features = save_created_features
        if self.is_training:
            self.first_sample = 0
            self.last_sample = self.n_training_samples
@@ -238,7 +240,7 @@ class DataGenerator(object):
    def create_feature(self,sample):
        """Creates the features by computing an STFT."""
        #try:
        filename = sample['filename']
        channels_str = sample['channels']
        channels = int(channels_str.split(" ")[0])
@@ -269,13 +271,12 @@ class DataGenerator(object):
            spectrum = self.pad_sample(spectrum,
                                       x_size=np.ceil(self.max_time*self.sampling_rate/self.hop_length))
            spectra[str(i_sample)] = spectrum
        if "mp3" in filename:
            np.savez(filename.replace("mp3","npz"),spectra)
        else:
            np.savez(filename.replace("wav","npz"),spectra)
        #except:
        #    spectra = None
        #    print(sample['filename']+" failed at feature extraction!")
        if self.save_created_features:
            if "mp3" in filename:
                np.savez(filename.replace("mp3","npz"),spectra)
            else:
                np.savez(filename.replace("wav","npz"),spectra)
        return spectra
@@ -345,6 +346,45 @@ class DataGenerator(object):
        pool = multiprocessing.Pool(os.cpu_count())
        for i, _ in enumerate(pool.imap_unordered(self.create_feature, samples), 1):
            sys.stderr.write('\rdone {0:%}'.format(max(0,i/n)))
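    # Generator over the complete dataset: recomputes the STFT features for each file and yields every valid sub-spectrogram.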
    def generate_all_samples_from_scratch(self):
        stft_len = int(np.ceil(self.max_time*self.sampling_rate/self.hop_length))
        for sample in self.samples:
            filename = sample['filename']
            #Create features via STFT if no file exists
            spectra = self.create_feature(sample)
            for spec_key in spectra.keys():
                #Check for None type
                spectrum = spectra[spec_key]
                if spectrum is None or spectrum.shape[-1] != stft_len:
                    continue
                #If only mono --> duplicate
                if spectrum.shape[0] == 1:
                    spectrum = np.tile(spectra[spec_key],[2,1,1])
                #Transpose spectrograms for "channels_last"
                spectrum = tf.transpose(spectrum,perm=[1,2,0])
                #Fill false spectra with zeros
                false_spectrum = tf.zeros_like(spectrum)
                if self.is_training or self.is_validation:
                    label = tf.one_hot(sample['bird_id'],self.dataset.n_classes+1)
                else:
                    label = None
                sub_sample = {'input_features':spectrum,
                              'labels':label,
                              'false_sample':false_spectrum}
                if self.augmentation is not None:
                    yield self.augmentation(sub_sample,self.is_training)
                else:
                    yield sub_sample

    def generate(self):
@@ -369,8 +409,12 @@ class DataGenerator(object):
            else:
                #Create features via STFT if no file exists
                spectra = self.create_feature(sample)
            spec_keys = spectra.keys()
            spec_keys = list(spec_keys)
            rnd_key = spec_keys[np.random.randint(0,len(spec_keys))]
            spectra = spectra[rnd_key]
            #Check for None type and shape
            #Check for None type
            if np.any(spectra) == None or spectra.shape[-1] != stft_len:
                continue
@@ -164,11 +164,20 @@ class ModelTrainer():
    def load_model(self):
        """Loads weights for all models"""
        self.classifier.load_weights(self.get_model("classifier"))
        self.generator.load_weights(self.get_model("generator"))
        self.discriminator.load_weights(self.get_model("discriminator"))
        try:
            self.classifier.load_weights(self.get_model("classifier"))
        except:
            print("Failed to load classifier")
        try:
            self.generator.load_weights(self.get_model("generator"))
        except:
            print("Failed to load generator")
        try:
            self.discriminator.load_weights(self.get_model("discriminator"))
        except:
            print("Failed to load discriminator")

    @tf.function
    def classify(self,x,training = False):
        logits_classifier = self.classifier(x, training=training)
        return tf.nn.softmax(logits_classifier, axis=-1)
@@ -279,6 +288,19 @@ class ModelTrainer():
            x.append(xy[key])
        return x
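    # Runs the classifier over all batches yielded by data_generator and collects the per-batch softmax outputs.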
    def classify_dataset(self,data_generator):
        prog = tf.keras.utils.Progbar(self.max_number_batches)
        all_predictions = []
        batch = 0
        for xy in data_generator:
            start_time = time()
            x = self.get_data_for_keys(xy,self.input_keys)
            predictions = self.classify(x[0],False)
            all_predictions.append(predictions)
            prog.update(batch, [("time / step", np.round(time() - start_time, 2))])
            batch += 1
        return all_predictions
    def train(self):
        if self.start_epoch == 0: