Commit 283ce360 authored by Alexander Fuchs's avatar Alexander Fuchs
Browse files

Modified trainer, now takes models and optimizer settings as dict

parent fdeb9f59
......@@ -207,11 +207,22 @@ def main(argv):
None,
epochs,
EvalFunctions,
models = [classifier_model,generator_model,discriminator_model],
model_settings = [{'model':classifier_model,
'optimizer_type':tf.keras.optimizers.SGD,
'base_learning_rate':lr,
'learning_rate_fn':learning_rate_fn,
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])},
{'model':generator_model,
'optimizer_type':tf.keras.optimizers.Adam,
'base_learning_rate':lr*0.0001,
'learning_rate_fn':learning_rate_fn,
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])},
{'model':discriminator_model,
'optimizer_type':tf.keras.optimizers.Adam,
'base_learning_rate':lr*0.002,
'learning_rate_fn':learning_rate_fn,
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])}],
scalar_summaries = summaries,
base_learning_rates = [lr,lr*0.0001,lr*0.002],
learning_rate_fns = [learning_rate_fn],
optimizer_types = [tf.keras.optimizers.SGD,tf.keras.optimizers.Adam,tf.keras.optimizers.Adam],
num_train_batches = int(n_train/batch_size),
load_model = load_model,
save_dir = model_save_dir,
......
......@@ -164,11 +164,22 @@ def main(argv):
None,
epochs,
EvalFunctions,
models = [classifier_model,generator_model,discriminator_model],
model_settings = [{'model':classifier_model,
'optimizer_type':tf.keras.optimizers.SGD,
'base_learning_rate':lr,
'learning_rate_fn':learning_rate_fn,
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])},
{'model':generator_model,
'optimizer_type':tf.keras.optimizers.Adam,
'base_learning_rate':lr*0.0001,
'learning_rate_fn':learning_rate_fn,
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])},
{'model':discriminator_model,
'optimizer_type':tf.keras.optimizers.Adam,
'base_learning_rate':lr*0.002,
'learning_rate_fn':learning_rate_fn,
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])}],
scalar_summaries = None,
base_learning_rates = [lr,lr*0.0001,lr*0.002],
learning_rate_fns = [learning_rate_fn],
optimizer_types = [tf.keras.optimizers.SGD,tf.keras.optimizers.Adam,tf.keras.optimizers.Adam],
num_train_batches = int(n_train/batch_size),
load_model = load_model,
save_dir = model_save_dir,
......
......@@ -17,7 +17,7 @@ class ModelTrainer():
test_data_generator,
epochs,
eval_fns,
models = [],
model_settings = [],
optimizer_types = [],
base_learning_rates = [],
learning_rate_fns = [],
......@@ -36,28 +36,20 @@ class ModelTrainer():
self.epoch_variable = tf.Variable(start_epoch)
self.start_epoch = start_epoch
if type(models) == list:
self.models = models
if type(model_settings) == list:
self.model_settings = model_settings
else:
self.models = [models]
self.model_settings = [model_settings]
self.models = [setting['model'] for setting in self.model_settings]
self.save_dir = save_dir
#Set up learning rates
if type(learning_rate_fns) == list:
self.learning_rate_fns = learning_rate_fns
else:
self.learning_rate_fns = [learning_rate_fns]
if type(base_learning_rates) == list:
self.base_learning_rates = base_learning_rates
else:
self.base_learning_rates = [base_learning_rates]
self.learning_rate_fns = [setting['learning_rate_fn'] for setting in self.model_settings]
self.base_learning_rates = [setting['base_learning_rate'] for setting in self.model_settings]
self.set_up_learning_rates()
#Initialize models
if type(init_data) != list:
init_data = [init_data]
init_data = [setting['init_data'] for setting in self.model_settings]
self.initialize_models(init_data)
self.n_vars_models = [len(model.trainable_variables) for model in self.models]
#Set up summaries
......@@ -65,10 +57,7 @@ class ModelTrainer():
#Set up loss function
self.eval_fns = eval_fns(self.models)
#Set up optimizers
if type(optimizer_types) != list:
self.optimizer_types = [optimizer_types]
else:
self.optimizer_types = optimizer_types
self.optimizer_types = [setting['optimizer_type'] for setting in self.model_settings]
self.set_up_optimizers()
#Set up function to process the gradients
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment