Commit 04c52d23 authored by Alexander Fuchs

Implemented generic summaries and evaluation functions; added methods to evaluate the whole dataset

parent 2eb852cc
@@ -16,6 +16,7 @@ class Classifier(tf.keras.Model):
                  use_max_pool = True,
                  kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                  kernel_initializer = tf.keras.initializers.he_normal(),
+                 name = "classifier",
                  dropout=0.2):
         super(Classifier, self).__init__()
         self.network_block = network_block
@@ -31,6 +32,7 @@ class Classifier(tf.keras.Model):
         self.use_max_pool = use_max_pool
         self.kernel_regularizer = kernel_regularizer
         self.kernel_initializer = kernel_initializer
+        self.model_name = name

     def build(self,input_shape):
@@ -67,7 +69,8 @@ class Classifier(tf.keras.Model):
                                            name = 'dense_layer',
                                            kernel_regularizer = self.kernel_regularizer,
                                            kernel_initializer = self.kernel_initializer)

     def call(self,input,training=False):
+        """Returns logits"""
         x = self.init_conv(input)
......
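A note on the new name handling above: tf.keras.Model exposes `name` as a read-only property, which is presumably why the constructor argument is stored as `self.model_name` here and in the discriminator and generator below, rather than being assigned to `self.name`. A minimal illustration (not from the repo):

import tensorflow as tf

class Demo(tf.keras.Model):
    # `name` is Keras-managed and read-only, so the argument is kept
    # under a separate attribute.
    def __init__(self, name="classifier"):
        super(Demo, self).__init__()
        self.model_name = name    # works
        # self.name = name        # raises AttributeError: can't set attribute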
@@ -6,14 +6,14 @@ class Discriminator(tf.keras.Model):
                  n_layers,
                  n_channels,
                  strides,
-                 name = "",
+                 name = "discriminator",
                  kernel_regularizer = None,
                  kernel_initializer = tf.keras.initializers.he_normal()):
         super(Discriminator, self).__init__()
         self.n_layers = n_layers
         self.strides = strides
         self.n_channels = n_channels
-        self.name_op = name
+        self.model_name = name
         self.kernel_regularizer = None
         self.kernel_initializer = kernel_initializer
@@ -27,7 +27,7 @@ class Discriminator(tf.keras.Model):
                                            strides=self.strides[i],
                                            padding="same",
                                            use_bias=False,
-                                           name=self.name_op+'_conv_' + str(i),
+                                           name=self.model_name+'_conv_' + str(i),
                                            kernel_regularizer = self.kernel_regularizer,
                                            kernel_initializer = self.kernel_initializer,
                                            activation='relu')))
@@ -35,8 +35,8 @@ class Discriminator(tf.keras.Model):
         self.dense = SpectralNormalization(tf.keras.layers.Dense(1,
                                            kernel_initializer = self.kernel_initializer,
                                            use_bias = False,
-                                           name = self.name_op+"_dense",
-                                           activation = tf.nn.sigmoid))
+                                           name = self.model_name+"_dense",
+                                           activation = tf.nn.relu))

     def call(self,input,training=False):
......
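The conv stack and the dense head above are wrapped in SpectralNormalization, whose implementation is not part of this diff. For orientation, a minimal sketch of such a wrapper, assuming the usual one-step power iteration (the compounding in-place rescale mirrors tfa.layers.SpectralNormalization); names and details are illustrative, not the repo's actual code:

import tensorflow as tf

class SpectralNormalization(tf.keras.layers.Wrapper):
    """Sketch: rescales the wrapped layer's kernel by an estimate of its
    largest singular value, keeping the discriminator roughly 1-Lipschitz."""

    def build(self, input_shape):
        if not self.layer.built:
            self.layer.build(input_shape)
        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()
        # Running estimate of the leading singular vector.
        self.u = self.add_weight(name="sn_u", shape=(1, self.w_shape[-1]),
                                 initializer="truncated_normal", trainable=False)
        super().build(input_shape)

    def call(self, inputs, training=False):
        if training:
            w = tf.reshape(self.w, [-1, self.w_shape[-1]])
            # One power-iteration step for the top singular pair (u, v).
            v = tf.math.l2_normalize(tf.matmul(self.u, w, transpose_b=True))
            u = tf.math.l2_normalize(tf.matmul(v, w))
            sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)
            self.u.assign(u)
            # Dividing by the current sigma drives the spectral norm to ~1.
            self.w.assign(self.w / sigma)
        return self.layer(inputs)

Layers would then be built as SpectralNormalization(tf.keras.layers.Conv2D(...)), exactly as in the hunk above. TensorFlow Addons ships a comparable tfa.layers.SpectralNormalization as a drop-in.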
 import tensorflow as tf
-class LossFunction(object):
-    def __init__(self,classifier,generator,discriminator):
-        self.classifier = classifier
-        self.generator = generator
-        self.discriminator = discriminator
+class EvalFunctions(object):
+    """This class implements specialized operations used in the training framework"""
+    def __init__(self,models):
+        self.classifier = models[0]
+        self.generator = models[1]
+        self.discriminator = models[2]
     @tf.function
-    def compute_loss(self, x, y, training=True):
+    def predict(self, x, training=True):
+        """Returns a dict containing predictions, e.g. {'predictions':predictions}"""
+        logits_classifier = self.classifier(x[0], training=training)
+        return {'predictions':tf.nn.softmax(logits_classifier,axis=-1)}
+    @tf.function
+    def generate(self, x, training=True):
+        """Returns a dict containing fake samples, e.g. {'fake_features':fake_features}"""
+        fake_features = self.generator(x[1], training)
+        return {'fake_features':fake_features}
+    @tf.function
+    def discriminate(self, x, training=True):
+        """Returns a dict containing scores, e.g. {'scores':scores}"""
+        scores = self.discriminator(x[0], training)
+        return {'scores':scores}
+    def accuracy(self, pred, y):
+        correct_predictions = tf.cast(tf.equal(tf.argmax(pred,axis=-1),
+                                               tf.argmax(y[0],axis=-1)),tf.float32)
+        return tf.reduce_mean(correct_predictions)
+    @tf.function
+    def compute_loss(self, x, y, training=True):
+        """Has to at least return a dict containing the total loss and a prediction dict, e.g. {'total_loss':total_loss}, {'predictions':predictions}"""
         logits_classifier = self.classifier(x[0], training=training)
         fake_features = self.generator(x[1], training)
         true = self.discriminator(x[0], training)
@@ -42,4 +66,4 @@ class LossFunction(object):
         total_loss += discr_loss
         total_loss += weight_decay_loss
-        return (class_loss, gen_loss, discr_loss, weight_decay_loss, total_loss), (predictions, fake_features)
\ No newline at end of file
+        return {'class_loss':class_loss, 'generator_loss':gen_loss, 'discriminator_loss':discr_loss, 'weight_decay_loss':weight_decay_loss, 'total_loss':total_loss}, {'predictions':predictions, 'fake_features':fake_features}
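Compared with the old tuple-returning LossFunction, every EvalFunctions method now returns dictionaries keyed by name, so a generic trainer can consume losses and predictions by key instead of by position. A hypothetical sketch of such a step (ModelTrainer's real internals are not in this diff; a single optimizer is used for brevity, whereas the scripts below configure one per model):

import tensorflow as tf

def train_step(eval_fns, optimizer, variables, x, y, summaries=None):
    with tf.GradientTape() as tape:
        losses, outputs = eval_fns.compute_loss(x, y, training=True)
    grads = tape.gradient(losses['total_loss'], variables)
    optimizer.apply_gradients(zip(grads, variables))
    if summaries is not None:
        # ScalarSummaries.update (defined later in this commit) routes each
        # value to the metric whose name matches the dict key.
        scalars = dict(losses)
        scalars['accuracy'] = eval_fns.accuracy(outputs['predictions'], y)
        summaries.update(scalars, mode="train")
    return losses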
@@ -4,44 +4,58 @@ class Generator(tf.keras.Model):
     def __init__(self,
                  n_layers,
                  n_channels,
-                 name = "",
+                 name = "generator",
                  kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                  kernel_initializer = tf.keras.initializers.he_normal()):
         super(Generator, self).__init__()
         self.n_layers = n_layers
         self.n_channels = n_channels
-        self.name_op = name
+        self.model_name = name
         self.kernel_regularizer = None
         self.kernel_initializer = kernel_initializer

     def build(self,input_shape):
         self.model_layers = []
         #Compression layers
         for i in range(self.n_layers):
-            if i == self.n_layers-1:
-                self.model_layers.append((tf.keras.layers.Conv2D(input_shape[-1],
-                                          kernel_size=(1, 1),
-                                          strides = (1,1),
-                                          padding="same",
-                                          use_bias=False,
-                                          name=self.name_op+'_conv_' + str(i),
-                                          kernel_regularizer = None,
-                                          kernel_initializer = self.kernel_initializer,
-                                          activation=None),0))
-            else:
             self.model_layers.append((tf.keras.layers.Conv2D(self.n_channels[i],
                                       kernel_size=(3, 3),
-                                      strides=(1,1),
+                                      strides=(2,2),
                                       padding="same",
                                       use_bias=False,
-                                      name=self.name_op+'_conv_' + str(i),
+                                      name=self.model_name+'_conv_' + str(i),
                                       kernel_regularizer = self.kernel_regularizer,
                                       kernel_initializer = self.kernel_initializer,
                                       activation=None),0))
             self.model_layers.append((tf.keras.layers.BatchNormalization(axis=-1),1))
             self.model_layers.append((tf.keras.layers.Activation("relu"),0))
+        #Decompression layers
+        for i in range(self.n_layers+1):
+            if i == self.n_layers:
+                self.model_layers.append((tf.keras.layers.Conv2DTranspose(input_shape[-1],
+                                          kernel_size=(1, 1),
+                                          strides=(1, 1),
+                                          padding="same",
+                                          use_bias=False,
+                                          name=self.model_name + '_conv_' + str(i),
+                                          kernel_regularizer=None,
+                                          kernel_initializer=self.kernel_initializer,
+                                          activation=None), 0))
+            else:
+                self.model_layers.append((tf.keras.layers.Conv2DTranspose(self.n_channels[self.n_layers-1-i],
+                                          kernel_size=(3, 3),
+                                          strides=(2, 2),
+                                          padding="same",
+                                          use_bias=False,
+                                          name=self.model_name + '_conv_tp_' + str(i),
+                                          kernel_regularizer=self.kernel_regularizer,
+                                          kernel_initializer=self.kernel_initializer,
+                                          activation=None), 0))
+                self.model_layers.append((tf.keras.layers.BatchNormalization(axis=-1), 1))
+                self.model_layers.append((tf.keras.layers.Activation("relu"), 0))

     def call(self,input,training=False):
         x = input
@@ -51,5 +65,5 @@ class Generator(tf.keras.Model):
                 x = layer[0](x)
             elif layer[1]==1:
                 x = layer[0](x,training)
+        x = x[:,:input.shape[1],:input.shape[2],:]
         return x+input
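The new crop before the residual addition is needed because the generator is now a stride-2 encoder-decoder: with "same" padding, each downsampling conv produces ceil(n/2) rows while each Conv2DTranspose doubles them, so odd input sizes (such as BINS = 1025 in the scripts below) overshoot on the way back up. A one-layer shape check illustrating this (94 frames is an arbitrary stand-in, since N_FRAMES is not shown in the diff):

import tensorflow as tf

x = tf.zeros([1, 1025, 94, 1])
down = tf.keras.layers.Conv2D(8, 3, strides=2, padding="same")(x)
print(down.shape)      # (1, 513, 47, 8)   -> ceil(n/2) per dimension
up = tf.keras.layers.Conv2DTranspose(1, 3, strides=2, padding="same")(down)
print(up.shape)        # (1, 1026, 94, 1)  -> 2*ceil(n/2) overshoots odd sizes
cropped = up[:, :x.shape[1], :x.shape[2], :]
print(cropped.shape)   # (1, 1025, 94, 1)  -> matches x, so x + input is valid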
@@ -8,11 +8,12 @@ from absl import flags
 from utils.trainer import ModelTrainer
 from utils.data_loader import Dataset
 from utils.data_loader import DataGenerator
+from utils.summary_utils import ScalarSummaries
 from models.classifier import Classifier
 from models.res_block import ResBlockBasicLayer
 from models.discriminator import Discriminator
 from models.generator import Generator
-from models.loss_functions.separate_classifier_wgan_loss_fn import LossFunction
+from models.eval_functions.separate_classifier_wgan_eval_fns import EvalFunctions
 import tensorflow as tf  # pylint: disable=g-bad-import-order

 BINS = 1025
@@ -91,7 +92,7 @@ flags.DEFINE_integer('epochs', 300, 'number of epochs')
 flags.DEFINE_integer('batch_size', 32, 'Mini-batch size')
 flags.DEFINE_float('dropout_rate', 0.0, 'dropout rate for the dense blocks')
 flags.DEFINE_float('weight_decay', 1e-4, 'weight decay parameter')
-flags.DEFINE_float('learning_rate', 1e-1, 'learning rate')
+flags.DEFINE_float('learning_rate', 5e-2, 'learning rate')
 flags.DEFINE_boolean('preload_samples', False, 'Preload samples (requires >140 GB RAM)')
 flags.DEFINE_float('training_percentage', 90, 'Percentage of the training data used for training. (100-training_percentage is used as validation data.)')
@@ -159,10 +160,11 @@ def main(argv):
                                  use_max_pool = True,
                                  kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                                  kernel_initializer = tf.keras.initializers.he_normal(),
+                                 name = "classifier",
                                  dropout=dropout_rate)
     #Generator model used to augment to false samples
-    generator_model = Generator(4,
-                                [16,32,16,1],
+    generator_model = Generator(8,
+                                [8,8,16,16,32,32,64,64],
                                 kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                                 kernel_initializer = tf.keras.initializers.he_normal(),
                                 name = "generator")
@@ -184,27 +186,37 @@ def main(argv):
                              is_validation = True,
                              n_classes = ds_train.n_classes)

+    summaries = ScalarSummaries(scalar_summary_names=["class_loss",
+                                                      "generator_loss",
+                                                      "discriminator_loss",
+                                                      "weight_decay_loss",
+                                                      "total_loss",
+                                                      "accuracy"],
+                                learning_rate_names = ['learning_rate_'+str(classifier_model.model_name),
+                                                       'learning_rate_'+str(generator_model.model_name),
+                                                       'learning_rate_'+str(discriminator_model.model_name)],
+                                save_dir = model_save_dir,
+                                modes = ["train","val","test"],
+                                summaries_to_print={'train':['class_loss','generator_loss','discriminator_loss','accuracy'],
+                                                    'eval':['total_loss','accuracy']})
     trainer = ModelTrainer(train_data_gen,
                            val_data_gen,
                            None,
                            epochs,
-                           LossFunction,
-                           classifier = classifier_model,
-                           generator = generator_model,
-                           discriminator = discriminator_model,
-                           base_learning_rate_classifier = lr,
-                           base_learning_rate_generator = lr*0.001,
-                           base_learning_rate_discriminator = lr*0.0005,
-                           learning_rate_fn_classifier = learning_rate_fn,
-                           learning_rate_fn_generator = learning_rate_fn,
-                           learning_rate_fn_discriminator = learning_rate_fn,
-                           optimizer_classifier = tf.keras.optimizers.SGD,
-                           optimizer_generator = tf.keras.optimizers.Adam,
-                           optimizer_discriminator = tf.keras.optimizers.Adam,
+                           EvalFunctions,
+                           models = [classifier_model,generator_model,discriminator_model],
+                           scalar_summaries = summaries,
+                           base_learning_rates = [lr,lr*0.0001,lr*0.002],
+                           learning_rate_fns = [learning_rate_fn],
+                           optimizer_types = [tf.keras.optimizers.SGD,tf.keras.optimizers.Adam,tf.keras.optimizers.Adam],
                            num_train_batches = int(n_train/batch_size),
                            load_model = load_model,
                            save_dir = model_save_dir,
+                           input_keys = ["input_features","false_sample"],
+                           label_keys = ["labels"],
+                           init_data = tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS]),
                            start_epoch = 0)
......
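The learning_rate_fn passed above is defined elsewhere in the script and outside the shown hunks, so its signature here is an assumption. A plausible epoch-indexed step decay, compatible with passing one shared function for all three models, might look like:

# Hypothetical sketch: the real learning_rate_fn is not in this diff.
# Assumed contract: maps (base learning rate, epoch) to the rate to use.
def learning_rate_fn(base_lr, epoch):
    scale = 1.0
    for boundary in (150, 225):   # assumed milestones for 300 epochs
        if epoch >= boundary:
            scale *= 0.1          # decay by 10x at each milestone
    return base_lr * scale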
@@ -12,7 +12,7 @@ from models.classifier import Classifier
 from models.res_block import ResBlockBasicLayer
 from models.discriminator import Discriminator
 from models.generator import Generator
-from models.loss_functions.separate_classifier_wgan_loss_fn import LossFunction
+from models.eval_functions.separate_classifier_wgan_eval_fns import EvalFunctions
 import tensorflow as tf  # pylint: disable=g-bad-import-order

 BINS = 1025
@@ -140,6 +140,7 @@ def main(argv):
                                  use_max_pool = True,
                                  kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                                  kernel_initializer = tf.keras.initializers.he_normal(),
+                                 name = "classifier",
                                  dropout=dropout_rate)
     #Generator model used to augment to false samples
     generator_model = Generator(8,
@@ -162,26 +163,21 @@ def main(argv):
                            None,
                            None,
                            epochs,
-                           LossFunction,
-                           classifier = classifier_model,
-                           generator = generator_model,
-                           discriminator = discriminator_model,
-                           base_learning_rate_classifier = lr,
-                           base_learning_rate_generator = lr*0.0001,
-                           base_learning_rate_discriminator = lr*0.002,
-                           learning_rate_fn_classifier = learning_rate_fn,
-                           learning_rate_fn_generator = learning_rate_fn,
-                           learning_rate_fn_discriminator = learning_rate_fn,
-                           optimizer_classifier = tf.keras.optimizers.SGD,
-                           optimizer_generator = tf.keras.optimizers.Adam,
-                           optimizer_discriminator = tf.keras.optimizers.Adam,
+                           EvalFunctions,
+                           models = [classifier_model,generator_model,discriminator_model],
+                           scalar_summaries = None,
+                           base_learning_rates = [lr,lr*0.0001,lr*0.002],
+                           learning_rate_fns = [learning_rate_fn],
+                           optimizer_types = [tf.keras.optimizers.SGD,tf.keras.optimizers.Adam,tf.keras.optimizers.Adam],
                            num_train_batches = int(n_train/batch_size),
                            load_model = load_model,
                            save_dir = model_save_dir,
+                           input_keys = ["input_features","false_sample"],
+                           label_keys = ["labels"],
+                           init_data = tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS]),
                            start_epoch = 0)

-    all_predictions = trainer.classify_dataset(data_gen)
+    all_predictions = trainer.predict_dataset(data_gen)
     np.save(os.path.join(data_dir,"train_set_predictions.npy"),all_predictions)

 if __name__ == '__main__':
......
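trainer.predict_dataset itself lives in the trainer module, whose diff is not shown here. In the spirit of the new EvalFunctions interface, a standalone sketch of such a dataset-wide prediction loop could look like the following (the batch format and iteration protocol of DataGenerator are assumptions; input_keys mirrors the trainer call above):

import numpy as np

def predict_dataset(eval_fns, data_gen, input_keys=("input_features", "false_sample")):
    """Runs EvalFunctions.predict over every batch and stacks the results."""
    all_predictions = []
    for batch in data_gen:
        # Assemble the positional input list the eval functions expect (x[0], x[1], ...).
        x = [batch[k] for k in input_keys]
        out = eval_fns.predict(x, training=False)      # {'predictions': softmax probs}
        all_predictions.append(out['predictions'].numpy())
    return np.concatenate(all_predictions, axis=0)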
import os
import tensorflow as tf

class ScalarSummaries(object):
    def __init__(self, scalar_summary_names, learning_rate_names, save_dir,
                 modes=["train","val","test"], summaries_to_print={}):
        self.scalar_summary_names = scalar_summary_names
        self.save_dir = save_dir
        self.modes = modes
        self.summaries_to_print = summaries_to_print
        self.scalar_summaries = {}
        self.lr_summaries = {}
        self.learning_rate_names = learning_rate_names
        self.summary_writers = {}
        self.create_summary_writers()
        self.define_scalar_summaries()
        self.define_learning_rate_summaries()

    def get_summary_list(self, mode='train'):
        summary_list = []
        if 'eval' in mode:
            for tmp_mode in self.modes:
                for key in self.summaries_to_print[mode]:
                    summary_list.append((key, self.scalar_summaries[tmp_mode+'_'+key].result()))
        else:
            for key in self.summaries_to_print[mode]:
                summary_list.append((key, self.scalar_summaries[mode+'_'+key].result()))
        return summary_list

    def create_summary_writers(self):
        for mode in self.modes:
            log_dir = os.path.join(self.save_dir, 'logs', mode)
            self.summary_writers[mode] = tf.summary.create_file_writer(log_dir)

    def define_learning_rate_summaries(self):
        for key in self.learning_rate_names:
            self.lr_summaries[key] = tf.keras.metrics.Mean(key, dtype=tf.float32)

    def define_scalar_summaries(self):
        for mode in self.modes:
            for key in self.scalar_summary_names:
                self.scalar_summaries[mode+'_'+key] = tf.keras.metrics.Mean(mode+'_'+key, dtype=tf.float32)

    def update(self, scalars, mode="train"):
        for key in self.scalar_summaries.keys():
            if mode in key.lower():
                #Get scalar key
                scalar_map_key = key.split(mode+"_")[-1]
                scalar = scalars[scalar_map_key]
                #Update summary
                self.scalar_summaries[key].update_state(scalar)

    def update_lr(self, lrs):
        for key in self.lr_summaries.keys():
            self.lr_summaries[key].update_state(lrs[key])

    def write(self, epoch, mode="train"):
        for key in self.scalar_summaries.keys():
            if mode in key.lower():
                scalar_map_key = key.split(mode+"_")[-1]
                # Write summaries
                with self.summary_writers[mode].as_default():
                    tf.summary.scalar(scalar_map_key,
                                      self.scalar_summaries[key].result(), step=epoch)

    def write_lr(self, epoch):
        for key in self.lr_summaries.keys():
            # Write summaries
            with self.summary_writers['train'].as_default():
                tf.summary.scalar(key, self.lr_summaries[key].result(), step=epoch)

    def reset_summaries(self):
        for key in self.scalar_summaries.keys():
            self.scalar_summaries[key].reset_states()
        for key in self.lr_summaries.keys():
            self.lr_summaries[key].reset_states()
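For reference, a minimal usage sketch of this class outside the trainer (paths and values are placeholders):

summaries = ScalarSummaries(
    scalar_summary_names=["total_loss", "accuracy"],
    learning_rate_names=["learning_rate_classifier"],
    save_dir="/tmp/run0",                       # placeholder path
    modes=["train", "val"],
    summaries_to_print={'train': ['total_loss', 'accuracy'],
                        'eval': ['total_loss']})

for epoch in range(2):
    # Per batch: accumulate scalars under the current mode. The dict must
    # contain every scalar_summary_name, since update() indexes by key.
    summaries.update({'total_loss': 0.7, 'accuracy': 0.5}, mode="train")
    summaries.update_lr({'learning_rate_classifier': 0.05})
    summaries.update({'total_loss': 0.9, 'accuracy': 0.4}, mode="val")
    # End of epoch: flush to the per-mode TensorBoard writers, then reset.
    summaries.write(epoch, mode="train")
    summaries.write(epoch, mode="val")
    summaries.write_lr(epoch)
    print(summaries.get_summary_list(mode='train'))
    summaries.reset_summaries()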
[One further file diff in this commit is collapsed and not shown; given the new ModelTrainer signature and predict_dataset call above, it is likely utils/trainer.py.]