Commit db96e7cd authored by Alexander Fuchs

Fixed normalization using augment_fn

parent c09d16bf
@@ -18,48 +18,29 @@ BINS = 1025
 N_FRAMES = 216
 N_CHANNELS = 2
-def preprocess_input(sample,n_classes,is_training):
+def augment_input(sample,n_classes,training):
     """Preprocess a single image of layout [height, width, depth]."""
     input_features = sample['input_features']
     labels = sample['labels']
     false_sample = sample['false_sample']
-    batch_size = tf.shape(labels)[0]
-    if is_training:
-        rnd = tf.random.uniform([batch_size])
-        rnd_tiled_feat = tf.tile(tf.reshape(rnd,[batch_size,1,1,1]),[1,BINS,
+    if training:
+        rnd = tf.random.uniform([1])
+        rnd_tiled_feat = tf.tile(tf.reshape(rnd,[1,1,1]),[BINS,
                                                           N_FRAMES,
                                                           N_CHANNELS])
-        rnd_tiled_lbl = tf.tile(tf.reshape(rnd,[batch_size,1]),[1,n_classes+1])
-        false_lbl = tf.cast(tf.one_hot(n_classes+1,n_classes+1),tf.int32)
-        false_lbl = tf.tile(tf.reshape(false_lbl,[1,n_classes+1]),[batch_size,1])
-        input_features = tf.where(rnd_tiled_feat > 1/n_classes,input_features,false_sample)
-        labels = tf.where(rnd_tiled_lbl > 1/n_classes,labels,false_lbl)
-    return {'input_features':input_features,'labels':labels,'false_sample':false_sample}
+        rnd_tiled_lbl = tf.tile(tf.reshape(rnd,[1]),[n_classes+1])
-def augment_input(sample,n_classes):
-    """Preprocess a single image of layout [height, width, depth]."""
-    input_features = sample['input_features']
-    labels = sample['labels']
-    false_sample = sample['false_sample']
-    false_lbl = tf.cast(tf.one_hot(n_classes+1,n_classes+1),labels.dtype)
-    rnd = tf.random.uniform([1])
-    rnd_tiled_feat = tf.tile(tf.reshape(rnd,[1,1,1]),[BINS,
-                                                      N_FRAMES,
-                                                      N_CHANNELS])
-    rnd_tiled_lbl = tf.tile(tf.reshape(rnd,[1]),[n_classes+1])
+        false_lbl = tf.cast(tf.one_hot(n_classes+1,n_classes+1),labels.dtype)
-    input_features = tf.where(rnd_tiled_feat > 1/n_classes,input_features,false_sample)
-    labels = tf.where(rnd_tiled_lbl > 1/n_classes,labels,false_lbl)
+        input_features = tf.where(rnd_tiled_feat > 1/n_classes,input_features,false_sample)
+        labels = tf.where(rnd_tiled_lbl > 1/n_classes,labels,false_lbl)
     input_features = tf.image.per_image_standardization(input_features)
     false_sample = tf.image.per_image_standardization(false_sample)
     return {'input_features':input_features,'labels':labels,'false_sample':false_sample}
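
Note (not part of the commit): a minimal usage sketch of the merged augment_input above, assuming that function and its module constants are in scope; the dummy sample, class count, and dataset construction are made up for illustration. With training=True the random false-sample/false-label mixing is active; with training=False only the per-image standardization runs.

import tensorflow as tf

BINS, N_FRAMES, N_CHANNELS = 1025, 216, 2   # constants from this file
N_CLASSES = 50                              # assumed number of classes

# Dummy unbatched sample with the layout augment_input expects.
sample = {'input_features': tf.random.normal([BINS, N_FRAMES, N_CHANNELS]),
          'labels': tf.one_hot(3, N_CLASSES + 1),
          'false_sample': tf.random.normal([BINS, N_FRAMES, N_CHANNELS])}

ds = tf.data.Dataset.from_tensors(sample).repeat(4)
ds_train = ds.map(lambda s: augment_input(s, N_CLASSES, training=True))   # mixing + standardization
ds_val = ds.map(lambda s: augment_input(s, N_CLASSES, training=False))    # standardization only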
@@ -148,15 +129,15 @@ def main(argv):
     ds_train = Dataset(data_dir,is_training_set = True)
     n_total = ds_train.n_samples
-    def augment_fn(sample):
-        return augment_input(sample,ds_train.n_classes)
+    def augment_fn(sample,training):
+        return augment_input(sample,ds_train.n_classes,training)
     dg_train = DataGenerator(ds_train,augment_fn,
                              training_percentage = training_percentage,
                              preload_samples = preload_samples,
                              is_training=True)
-    dg_val = DataGenerator(ds_train,None,
+    dg_val = DataGenerator(ds_train,augment_fn,
                            training_percentage = training_percentage,
                            is_validation = True,
                            preload_samples = preload_samples,
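
The reason dg_val now receives augment_fn instead of None appears to be the per-image standardization at the end of augment_input: with None, validation spectra were never normalized, which matches the commit message. A small standalone check of what tf.image.per_image_standardization does (the shape is this project's spectrogram layout, the values are fake):

import tensorflow as tf

x = tf.random.normal([1025, 216, 2], mean=5.0, stddev=3.0)   # fake spectrogram
y = tf.image.per_image_standardization(x)                    # rescale to zero mean, unit variance

print(float(tf.reduce_mean(y)), float(tf.math.reduce_std(y)))  # approx. 0.0 and 1.0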
@@ -212,7 +193,7 @@ def main(argv):
                          discriminator = generator_model,
                          base_learning_rate_classifier = lr,
                          base_learning_rate_generator = lr*0.001,
-                         base_learning_rate_discriminator = lr*0.0001,
+                         base_learning_rate_discriminator = lr*0.001,
                          learning_rate_fn_classifier = learning_rate_fn,
                          learning_rate_fn_generator = learning_rate_fn,
                          learning_rate_fn_discriminator = learning_rate_fn,
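
This hunk only raises the discriminator's base rate from lr*0.0001 to lr*0.001, putting it on the same scale as the generator. How ModelTrainer turns these base rates and learning_rate_fn into optimizers is not shown in this diff; the sketch below is an assumption about that wiring (the optimizer choice and names are illustrative), kept only to make the relative scales concrete.

import tensorflow as tf

lr = 1e-3   # assumed base learning rate

optimizers = {
    'classifier':    tf.keras.optimizers.Adam(learning_rate=lr),
    'generator':     tf.keras.optimizers.Adam(learning_rate=lr * 0.001),
    'discriminator': tf.keras.optimizers.Adam(learning_rate=lr * 0.001),  # was lr * 0.0001
}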
@@ -405,7 +405,7 @@ class DataGenerator(object):
                       'labels':tf.one_hot(sample['bird_id'],self.dataset.n_classes+1),
                       'false_sample':false_spectra}
             if self.augmentation != None:
-                yield self.augmentation(sample)
+                yield self.augmentation(sample,self.is_training)
             else:
                 yield sample
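
A reduced, self-contained illustration of the callback contract after this change; the toy generator and callback below are made up for illustration and are not the repository's DataGenerator. The point is that the augmentation callable now receives the generator's own is_training flag, so the same augment_fn can serve both the training and validation generators.

import tensorflow as tf

def toy_generator(samples, augmentation=None, is_training=True):
    """Yield samples, routing them through augmentation(sample, is_training) when set."""
    for sample in samples:
        if augmentation is not None:
            yield augmentation(sample, is_training)
        else:
            yield sample

def tag_sample(sample, training):
    out = dict(sample)
    out['augmented_in_training_mode'] = tf.constant(training)
    return out

for item in toy_generator([{'x': tf.ones([2])}], augmentation=tag_sample, is_training=False):
    print(item['augmented_in_training_mode'].numpy())   # False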
@@ -186,8 +186,9 @@ class ModelTrainer():
                                 axis=-1,
                                 ))
         gen_loss = tf.reduce_mean((fake_features - x[1])**2)
         #Wasserstein losses
-        gen_loss = tf.reduce_mean(false)
+        gen_loss += tf.reduce_mean(false)
         discr_loss = tf.reduce_mean(true)-tf.reduce_mean(false)
         if len(self.classifier.losses) > 0:
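
For reference, the textbook Wasserstein (WGAN) objectives written as quantities to minimize; whether they line up with the signs above depends on what `true` and `false` hold (the usual reading is critic scores on real and generated batches) and on the direction the trainer steps, so this is a hedged sketch rather than a restatement of ModelTrainer. WGAN also assumes a Lipschitz-constrained critic (weight clipping or a gradient penalty), which is outside this hunk.

import tensorflow as tf

def wgan_losses(critic_real, critic_fake):
    """Standard WGAN objectives as minimization targets.

    critic_real: critic scores D(x) on real samples.
    critic_fake: critic scores D(G(z)) on generated samples.
    """
    critic_loss = tf.reduce_mean(critic_fake) - tf.reduce_mean(critic_real)
    generator_loss = -tf.reduce_mean(critic_fake)
    return critic_loss, generator_loss

# Toy check with made-up scores.
c_loss, g_loss = wgan_losses(tf.constant([1.0, 2.0]), tf.constant([0.5, 0.5]))
print(float(c_loss), float(g_loss))   # -1.0 and -0.5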