Commit 8c3e0003 authored by Alexander Fuchs

Prepared models for more complicated architectures

parent 76d8abc3
from models.layers.dense_moe import DenseMoE
import tensorflow as tf


class Classifier(tf.keras.Model):

    def __init__(self,
                 network_block,
                 n_blocks,
                 n_layers,
                 strides,
                 channel_base,
                 n_classes,
                 init_ch,
                 init_ksize,
                 init_stride,
                 use_max_pool = True,
                 kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                 kernel_initializer = tf.keras.initializers.he_normal(),
                 dropout=0.2):
        super(Classifier, self).__init__()
        self.network_block = network_block
        self.n_blocks = n_blocks
        self.n_layers = n_layers
        self.strides = strides
        self.channel_base = channel_base
        self.n_classes = n_classes
        self.dropout = dropout
        self.init_ch = init_ch
        self.init_ksize = init_ksize
        self.init_stride = init_stride
        self.use_max_pool = use_max_pool
        self.kernel_regularizer = kernel_regularizer
        self.kernel_initializer = kernel_initializer

    def build(self, input_shape):
        # Initial stem: convolution -> batch norm -> ReLU (-> optional max pool)
        self.init_conv = tf.keras.layers.Conv2D(self.init_ch,
                                                self.init_ksize,
                                                self.init_stride,
                                                padding = "same",
                                                use_bias = False,
                                                name = 'initial_conv',
                                                kernel_regularizer = self.kernel_regularizer,
                                                kernel_initializer = self.kernel_initializer)
        self.init_bn = tf.keras.layers.BatchNormalization(axis=-1)
        self.init_relu = tf.keras.layers.Activation("relu")
        if self.use_max_pool:
            self.init_max_pool = tf.keras.layers.MaxPool2D(pool_size=(3, 3),
                                                           strides=(2, 2),
                                                           padding="same")
        # Stack of network blocks with per-block depth, width and stride.
        self.network_blocks = []
        for i_block in range(self.n_blocks):
            self.network_blocks.append(self.network_block(self.n_layers[i_block],
                                                          self.channel_base[i_block],
                                                          stride = self.strides[i_block],
                                                          kernel_regularizer = self.kernel_regularizer,
                                                          kernel_initializer = self.kernel_initializer))
        self.last_bn = tf.keras.layers.BatchNormalization(axis=-1)
        self.last_relu = tf.keras.layers.Activation("relu")
        self.avg_pool = tf.keras.layers.GlobalAveragePooling2D()
        self.dense = tf.keras.layers.Dense(self.n_classes,
                                           name = 'dense_layer',
                                           kernel_regularizer = self.kernel_regularizer,
                                           kernel_initializer = self.kernel_initializer)

    def call(self, input, training=False):
        """Returns logits"""
        x = self.init_conv(input)
        x = self.init_bn(x, training)
        x = self.init_relu(x)
        if self.use_max_pool:
            x = self.init_max_pool(x)
        for block in self.network_blocks:
            x = block(x, training)
        x = self.last_bn(x, training)
        x = self.last_relu(x)
        x = self.avg_pool(x)
        x = self.dense(x)
        return x
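
A minimal usage sketch for Classifier (not part of the commit). `ResidualBlock` is a hypothetical block class assumed to match the `network_block(...)` call in build(); all sizes are illustrative:

# Hypothetical example; `ResidualBlock` is assumed to implement
# network_block(n_layers, channels, stride=..., kernel_regularizer=..., kernel_initializer=...)
classifier = Classifier(network_block=ResidualBlock,
                        n_blocks=3,
                        n_layers=[2, 2, 2],
                        strides=[1, 2, 2],
                        channel_base=[64, 128, 256],
                        n_classes=10,
                        init_ch=64,
                        init_ksize=(7, 7),
                        init_stride=(2, 2))
logits = classifier(tf.zeros([1, 224, 224, 3]), training=False)  # -> shape (1, 10)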
import tensorflow as tf
from models.layers.spectral_normalization import SpectralNormalization


class Discriminator(tf.keras.Model):

    def __init__(self,
                 n_layers,
                 n_channels,
                 strides,
                 name = "",
                 kernel_regularizer = None,
                 kernel_initializer = tf.keras.initializers.he_normal()):
        super(Discriminator, self).__init__()
        self.n_layers = n_layers
        self.strides = strides
        self.n_channels = n_channels
        # `name` is reserved by tf.keras.Model, so the prefix is stored separately.
        self.name_op = name
        self.kernel_regularizer = kernel_regularizer
        self.kernel_initializer = kernel_initializer

    def build(self, input_shape):
        # Stack of spectrally normalized convolutions followed by a single linear unit.
        self.model_layers = []
        for i in range(self.n_layers):
            self.model_layers.append(SpectralNormalization(tf.keras.layers.Conv2D(self.n_channels[i],
                                                                                  kernel_size=(3, 3),
                                                                                  strides=self.strides[i],
                                                                                  padding="same",
                                                                                  use_bias=False,
                                                                                  name=self.name_op+'_conv_' + str(i),
                                                                                  kernel_regularizer = self.kernel_regularizer,
                                                                                  kernel_initializer = self.kernel_initializer,
                                                                                  activation='relu')))
        # Note: there is no flatten/pooling step, so Dense(1) acts on the channel
        # axis and yields one score per spatial location.
        self.dense = SpectralNormalization(tf.keras.layers.Dense(1,
                                                                 kernel_initializer = self.kernel_initializer,
                                                                 use_bias = False,
                                                                 name = self.name_op+"_dense",
                                                                 activation = None))

    def call(self, input, training=False):
        x = input
        # Pass `training` through so the spectral norm is updated during training.
        for layer in self.model_layers:
            x = layer(x, training)
        x = self.dense(x, training)
        return x
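
A usage sketch for the Discriminator (illustrative values, not from the commit). Because there is no flatten step, the spectrally normalized Dense(1) produces one score per spatial location:

# Hypothetical example with three stride-2 spectrally normalized convolutions.
discriminator = Discriminator(n_layers=3,
                              n_channels=[64, 128, 256],
                              strides=[(2, 2), (2, 2), (2, 2)],
                              name="disc")
scores = discriminator(tf.zeros([1, 64, 64, 1]), training=True)  # -> shape (1, 8, 8, 1)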
import tensorflow as tf


class Generator(tf.keras.Model):

    def __init__(self,
                 n_layers,
                 n_channels,
                 name = "",
                 kernel_regularizer = tf.keras.regularizers.l2(2e-4),
                 kernel_initializer = tf.keras.initializers.he_normal()):
        super(Generator, self).__init__()
        self.n_layers = n_layers
        self.n_channels = n_channels
        self.name_op = name
        self.kernel_regularizer = kernel_regularizer
        self.kernel_initializer = kernel_initializer

    def build(self, input_shape):
        # Each entry is a (layer, flag) tuple; flag 1 marks layers that take a
        # `training` argument (batch normalization), flag 0 marks those that don't.
        self.model_layers = []
        for i in range(self.n_layers):
            if i == self.n_layers-1:
                # Final 1x1 convolution projects back to the input channel count
                # and is left linear (no batch norm or ReLU).
                self.model_layers.append((tf.keras.layers.Conv2D(input_shape[-1],
                                                                 kernel_size=(1, 1),
                                                                 strides = (1,1),
                                                                 padding="same",
                                                                 use_bias=False,
                                                                 name=self.name_op+'_conv_' + str(i),
                                                                 kernel_regularizer = None,
                                                                 kernel_initializer = self.kernel_initializer,
                                                                 activation=None),0))
            else:
                self.model_layers.append((tf.keras.layers.Conv2D(self.n_channels[i],
                                                                 kernel_size=(3, 3),
                                                                 strides=(1,1),
                                                                 padding="same",
                                                                 use_bias=False,
                                                                 name=self.name_op+'_conv_' + str(i),
                                                                 kernel_regularizer = self.kernel_regularizer,
                                                                 kernel_initializer = self.kernel_initializer,
                                                                 activation=None),0))
                self.model_layers.append((tf.keras.layers.BatchNormalization(axis=-1),1))
                self.model_layers.append((tf.keras.layers.Activation("relu"),0))

    def call(self, input, training=False):
        x = input
        for layer, needs_training in self.model_layers:
            if needs_training:
                x = layer(x, training)
            else:
                x = layer(x)
        return x
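
A usage sketch for the Generator (illustrative values, not from the commit). Since every stride is (1,1) and the last layer is a 1x1 convolution back to the input channel count, the output shape matches the input shape:

# Hypothetical example; output shape matches the input shape.
generator = Generator(n_layers=4,
                      n_channels=[32, 32, 32, 32],
                      name="gen")
fake = generator(tf.zeros([1, 64, 64, 1]), training=True)  # -> shape (1, 64, 64, 1)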
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import layers
from tensorflow.python.keras import initializers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
import tensorflow as tf


class SpectralNormalization(layers.Wrapper):
    """
    Attributes:
        layer: tensorflow keras layer (with a `kernel` attribute)
    """

    def __init__(self, layer, **kwargs):
        super(SpectralNormalization, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build `Layer`"""
        if not self.layer.built:
            self.layer.build(input_shape)
        if not hasattr(self.layer, 'kernel'):
            raise ValueError(
                '`SpectralNormalization` must wrap a layer that'
                ' contains a `kernel` for weights')
        self.w = self.layer.kernel
        self.w_shape = self.w.shape.as_list()
        # Persistent estimate of the leading singular vector, refined by one
        # power-iteration step per training forward pass.
        self.u = self.add_weight(
            shape=tuple([1, self.w_shape[-1]]),
            initializer=initializers.TruncatedNormal(stddev=0.02),
            name='sn_u',
            aggregation=tf.VariableAggregation.MEAN,
            trainable=False)
        #super(SpectralNormalization, self).build(input_shape)

    #@tf.function(experimental_relax_shapes=True)
    def call(self, inputs, training=None):
        """Call `Layer`"""
        #if training==None:
        #    training = K.learning_phase()
        if training:
            # Recompute the normalized weights for each training forward pass.
            self._compute_weights()
        output = self.layer(inputs)
        return output

    def _compute_weights(self):
        """Generate normalized weights.

        This method updates the value of self.layer.kernel with the
        normalized value, so that the layer is ready for call().
        """
        w_reshaped = array_ops.reshape(self.w, [-1, self.w_shape[-1]])
        eps = 1e-12
        # One step of power iteration to estimate the largest singular value
        # sigma of the reshaped kernel, then rescale the kernel by 1/sigma.
        _u = array_ops.identity(self.u)
        _v = math_ops.matmul(_u, array_ops.transpose(w_reshaped))
        _v = _v / math_ops.maximum(math_ops.reduce_sum(_v**2)**0.5, eps)
        _u = math_ops.matmul(_v, w_reshaped)
        _u = _u / math_ops.maximum(math_ops.reduce_sum(_u**2)**0.5, eps)
        self.u.assign(_u)
        sigma = math_ops.matmul(math_ops.matmul(_v, w_reshaped), array_ops.transpose(_u))
        self.layer.kernel.assign(self.w / sigma)

    def compute_output_shape(self, input_shape):
        return tensor_shape.TensorShape(
            self.layer.compute_output_shape(input_shape).as_list())
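
A usage sketch for SpectralNormalization (not from the commit): wrapping a layer divides its kernel by an estimate of its largest singular value, refined by one power-iteration step per training forward pass:

# Hypothetical example: spectrally normalized 3x3 convolution.
sn_conv = SpectralNormalization(tf.keras.layers.Conv2D(16, (3, 3), padding="same"))
y = sn_conv(tf.zeros([1, 32, 32, 3]), training=True)  # kernel is rescaled, then applied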
@@ -2,7 +2,7 @@ from models.layers.dense_moe import DenseMoE
 import tensorflow as tf

-class Network(tf.keras.Model):
+class Classifier(tf.keras.Model):
     def __init__(self,
                  network_block,
                  n_blocks,
@@ -70,7 +70,6 @@ class Network(tf.keras.Model):
     def call(self,input,training=False):
         """Returns logits"""
         x = self.init_conv(input)
         x = self.init_bn(x,training)
         x = self.init_relu(x)
@@ -331,11 +331,8 @@ class DataGenerator(object):
         if os.path.isfile(filename.replace("mp3","npz")) and not(self.force_feature_recalc):
             with np.load(filename.replace("mp3","npz"),allow_pickle=True) as sample_file:
                 spectra_npz = sample_file.f.arr_0
-                try:
-                    spec_keys = spectra_npz.item().keys()
-                except:
-                    continue
+                spec_keys = spectra_npz.item().keys()
                 spec_keys = list(spec_keys)
                 rnd_key = spec_keys[np.random.randint(0,len(spec_keys))]
                 spectra = spectra_npz.item()[rnd_key]
@@ -377,8 +374,6 @@ class DataGenerator(object):
             yield {'input_features':spectra,
                    'labels':tf.one_hot(sample['bird_id'],self.dataset.n_classes+1),
                    'false_sample':false_spectra}
-            # except:
-            #     continue

 if __name__ == "__main__":
     ds = Dataset("/srv/TUG/datasets/cornell_birdcall_recognition")
@@ -8,11 +8,14 @@ def clip_by_value_10(grad):
     return tf.clip_by_value(grad, -10, 10)

 class ModelTrainer():
-    def __init__(self,model,
+    def __init__(self,
                  training_data_generator,
                  validation_data_generator,
                  test_data_generator,
                  epochs,
+                 classifier,
+                 generator = None,
+                 discriminator = None,
                  optimizer = tf.keras.optimizers.Adam,
                  learning_rate_fn = None,
                  base_learning_rate=1e-3,
@@ -25,14 +28,20 @@ class ModelTrainer():
                  gradient_processesing_fn = clip_by_value_10,
                  save_dir = 'tmp'):
-        self.model = model
+        self.classifier = classifier
+        self.generator = generator
+        self.discriminator = discriminator
         self.learning_rate_fn = learning_rate_fn
         self.save_dir = save_dir
         self.base_learning_rate = base_learning_rate
         self.learning_rate = tf.Variable(self.base_learning_rate*self.learning_rate_fn(start_epoch))

-        #Initialize model
-        self.model(init_data,training = True)
-        self.model.summary()
+        #Initialize models
+        self.classifier(init_data,training = True)
+        self.classifier.summary()
+        self.generator(init_data,training = True)
+        self.generator.summary()
+        self.discriminator(init_data,training = True)
+        self.discriminator.summary()

         #Set up optimizer
         if optimizer == tf.keras.optimizers.SGD:
             self.optimizer = optimizer(self.learning_rate,0.9,True)
......@@ -84,44 +93,47 @@ class ModelTrainer():
self.scalar_summaries['val_' + name] = tf.keras.metrics.Mean('val_' + name, dtype=tf.float32)
self.scalar_summaries['test_' + name] = tf.keras.metrics.Mean('test_' + name, dtype=tf.float32)
def get_latest_model_file(self,model_type):
files = [os.path.join(self.save_dir, f) for f in os.listdir(self.save_dir) if
os.path.isfile(os.path.join(self.save_dir, f)) and "h5" in f and model_type in f]
files.sort(key=lambda x: os.path.getmtime(x))
latest = files[-1]
return latest
def load_model(self):
def get_model(self,model_type,print_model_name=True):
"""If the self.start_epoch is 0 the loader loads the latest file
from the directory, otherwise the file from the specified epoch is loaded"""
latest = self.get_latest_model_file(model_type)
split = latest.split("_")
if self.start_epoch == 0:
files = [os.path.join(self.save_dir, f) for f in os.listdir(self.save_dir) if
os.path.isfile(os.path.join(self.save_dir, f))and "h5" in f]
files.sort(key=lambda x: os.path.getmtime(x))
latest = files[-1]
split = latest.split("_")
str_epoch = split[-2]
self.start_epoch = int(str_epoch)
print("____________________________________")
print("Loading model: "+latest)
print("____________________________________")
self.model.load_weights(latest)
model_to_load = latest
else:
files = [os.path.join(self.save_dir, f) for f in os.listdir(self.save_dir) if
os.path.isfile(os.path.join(self.save_dir, f)) and "h5" in f]
files.sort(key=lambda x: os.path.getmtime(x))
latest = files[-1]
split = latest.split("_")
split[-2] = str(start_epoch)
model_to_load = "_".join(split)
if print_model_name:
print("____________________________________")
print("Loading model: "+latest)
print("____________________________________")
self.model.load_weights(latest)
return model_to_load
def load_model(self):
"""Loads weights for all models"""
self.classifier.load_weights(self.get_model("classifier"))
self.generator.load_weights(self.get_model("generator"))
self.discriminator.load_weights(self.get_model("discriminator"))
def classify(self,x,training = False):
logits_classifier = self.model(x, training=training)
logits_classifier = self.classifier(x, training=training)
return tf.nn.softmax(logits_classifier, axis=-1)
def compute_loss(self, x, y , training = True):
logits_classifier = self.model(x,training=training)
logits_classifier = self.classifier(x,training=training)
# Cross entropy losses
class_loss = tf.reduce_mean(
@@ -131,8 +143,13 @@ class ModelTrainer():
                 axis=-1,
             ))

-        if len(self.model.losses) > 0:
-            weight_decay_loss = tf.add_n(self.model.losses)
+        if len(self.classifier.losses) > 0:
+            weight_decay_loss = tf.add_n(self.classifier.losses)
         else:
             weight_decay_loss = 0.0
+
+        if len(self.generator.losses) > 0:
+            weight_decay_loss = tf.add_n(self.generator.losses)
+        else:
+            weight_decay_loss = 0.0
@@ -151,7 +168,7 @@ class ModelTrainer():
                 x, y, training=True)

-        gradients = tape.gradient(total_loss, self.model.trainable_variables)
+        gradients = tape.gradient(total_loss, self.classifier.trainable_variables)

         if self.gradient_processing_fn != None:
             out_grads = []
@@ -178,7 +195,7 @@ class ModelTrainer():
         #Compute gradients
         gradients,class_loss,weight_decay_loss,total_loss,predictions = self.compute_gradients(x,y)
         #Apply gradients
-        self.apply_gradients(gradients, self.model.trainable_variables)
+        self.apply_gradients(gradients, self.classifier.trainable_variables)
         return (class_loss,weight_decay_loss,total_loss,predictions)
@@ -210,7 +227,9 @@ class ModelTrainer():
     def train(self):
         if self.start_epoch == 0:
             print("Saving...")
-            self.model.save_weights(os.path.join(self.save_dir, "capsule_network_" + str(0) + "_.h5"))
+            self.classifier.save_weights(os.path.join(self.save_dir, "classifier_" + str(0) + "_.h5"))
+            self.generator.save_weights(os.path.join(self.save_dir, "generator_" + str(0) + "_.h5"))
+            self.discriminator.save_weights(os.path.join(self.save_dir, "discriminator_" + str(0) + "_.h5"))
         print("Starting training...")
@@ -234,7 +253,12 @@ class ModelTrainer():
                 outputs = self.train_step(train_x, train_y)
                 class_loss,weight_decay_loss,total_loss,predictions = outputs

                 # Update summaries
-                self.update_summaries(class_loss,weight_decay_loss,total_loss,predictions,train_y,"train")
+                self.update_summaries(class_loss,
+                                      weight_decay_loss,
+                                      total_loss,
+                                      predictions,
+                                      train_y,
+                                      "train")
                 batch += 1
@@ -259,7 +283,12 @@ class ModelTrainer():
                                                     training=False)

                 # Update summaries
-                self.update_summaries(class_loss,weight_decay_loss,total_loss,predictions,validation_y,"val")
+                self.update_summaries(class_loss,
+                                      weight_decay_loss,
+                                      total_loss,
+                                      predictions,
+                                      validation_y,
+                                      "val")

             # Write validation summaries
             self.write_summaries(epoch,"val")
@@ -275,7 +304,12 @@ class ModelTrainer():
                                                     training=False)

                 # Update summaries
-                self.update_summaries(class_loss, weight_decay_loss, total_loss, predictions, test_y, "test")
+                self.update_summaries(class_loss,
+                                      weight_decay_loss,
+                                      total_loss,
+                                      predictions,
+                                      test_y,
+                                      "test")

             # Write test summaries
             self.write_summaries(epoch,"test")
@@ -292,4 +326,6 @@ class ModelTrainer():
             #Reset summaries after epoch
             self.reset_summaries()

             #Save the model weights
-            self.model.save_weights(os.path.join(self.save_dir, "capsule_network_" + str(epoch + 1) + "_.h5"))
+            self.classifier.save_weights(os.path.join(self.save_dir, "classifier_" + str(epoch+1) + "_.h5"))
+            self.generator.save_weights(os.path.join(self.save_dir, "generator_" + str(epoch+1) + "_.h5"))
+            self.discriminator.save_weights(os.path.join(self.save_dir, "discriminator_" + str(epoch+1) + "_.h5"))
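
Given the new constructor signature, wiring the three models into the trainer could look roughly like the sketch below; `train_gen`, `val_gen`, `test_gen` and the three model instances are hypothetical placeholders, and constructor arguments not visible in the hunks (e.g. init_data, start_epoch) are left at whatever the full file defines:

# Rough wiring sketch based on the new __init__ signature shown above;
# all variables here are hypothetical placeholders.
trainer = ModelTrainer(training_data_generator=train_gen,
                       validation_data_generator=val_gen,
                       test_data_generator=test_gen,
                       epochs=100,
                       classifier=classifier,
                       generator=generator,
                       discriminator=discriminator)
trainer.load_model()  # optional: resume from the latest classifier/generator/discriminator .h5 files
trainer.train()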