Commit 21449e46 authored by Johanna Rock

Merge branch 'master' of git.spsc.tugraz.at:fuchs/kaggle_birdcall

parents 946c249c 5441bc86
@@ -8,7 +8,7 @@ from absl import flags
from utils.trainer import ModelTrainer
from utils.data_loader import Dataset
from utils.data_loader import DataGenerator
from utils.summary_utils import ScalarSummaries
from utils.summary_utils import Summaries
from models.classifier import Classifier
from models.res_block import ResBlockBasicLayer
from models.discriminator import Discriminator
@@ -182,12 +182,13 @@ def main(argv):
n_classes = ds_train.n_classes)
summaries = ScalarSummaries(scalar_summary_names=["class_loss",
summaries = Summaries(scalar_summary_names=["class_loss",
"generator_loss",
"discriminator_loss",
"weight_decay_loss",
"total_loss",
"accuracy"],
image_summary_settings = {'train':['fake_features'],'n_images':1},
learning_rate_names = ['learning_rate_'+str(classifier_model.model_name),
'learning_rate_'+str(generator_model.model_name),
'learning_rate_'+str(discriminator_model.model_name)],
@@ -209,15 +210,15 @@ def main(argv):
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])},
{'model':generator_model,
'optimizer_type':tf.keras.optimizers.Adam,
'base_learning_rate':lr*0.0001,
'base_learning_rate':lr*0.01,
'learning_rate_fn':learning_rate_fn,
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])},
{'model':discriminator_model,
'optimizer_type':tf.keras.optimizers.Adam,
'base_learning_rate':lr*0.002,
'base_learning_rate':lr*0.02,
'learning_rate_fn':learning_rate_fn,
'init_data':tf.random.normal([batch_size,BINS,N_FRAMES,N_CHANNELS])}],
scalar_summaries = summaries,
summaries = summaries,
num_train_batches = int(n_train/batch_size),
load_model = load_model,
save_dir = model_save_dir,
......
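For orientation on the retuned multipliers above: the trainer scales each base_learning_rate by learning_rate_fn(epoch) every epoch (see the ModelTrainer hunk further down). A minimal sketch of that arithmetic, where the concrete lr value and the decay schedule are assumptions and not taken from this commit:

# Sketch only: lr and the decay schedule below are assumptions, not values from
# this repository; the multipliers 0.01 and 0.02 are the ones introduced above.
lr = 1e-3

def learning_rate_fn(epoch):
    return 0.95 ** epoch  # hypothetical exponential decay

generator_base = lr * 0.01       # was lr * 0.0001 before this merge
discriminator_base = lr * 0.02   # was lr * 0.002 before this merge

for epoch in range(3):
    # ModelTrainer assigns base_learning_rate * learning_rate_fn(epoch) each epoch
    print(epoch, generator_base * learning_rate_fn(epoch),
          discriminator_base * learning_rate_fn(epoch))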
import os
import tensorflow as tf
class ScalarSummaries(object):
def __init__(self,scalar_summary_names,learning_rate_names,save_dir,modes = ["train","val","test"],summaries_to_print={}):
class Summaries(object):
def __init__(self,scalar_summary_names,learning_rate_names,image_summary_settings={},save_dir="/tmp",modes = ["train","val","test"],summaries_to_print={}):
self.scalar_summary_names = scalar_summary_names
self.image_summary_settings = image_summary_settings
self.save_dir = save_dir
self.modes = modes
self.summaries_to_print = summaries_to_print
self.scalar_summaries = {}
self.image_data = {}
self.lr_summaries = {}
self.learning_rate_names = learning_rate_names
self.summary_writers = {}
@@ -24,8 +26,9 @@ class ScalarSummaries(object):
for key in self.summaries_to_print[mode]:
summary_list.append((key,self.scalar_summaries[tmp_mode+'_'+key].result()))
else:
for key in self.summaries_to_print[mode]:
summary_list.append((key,self.scalar_summaries[mode+'_'+key].result()))
if mode in self.summaries_to_print.keys():
for key in self.summaries_to_print[mode]:
summary_list.append((key,self.scalar_summaries[mode+'_'+key].result()))
return summary_list
def create_summary_writers(self):
@@ -68,6 +71,26 @@ class ScalarSummaries(object):
tf.summary.scalar(scalar_map_key,
self.scalar_summaries[key].result(), step=epoch)
def update_image_data(self,outputs,mode="train"):
self.image_data = {}
for key in outputs.keys():
if mode in self.image_summary_settings.keys():
if key in self.image_summary_settings[mode]:
self.image_data[key] = outputs[key]
def write_image_summaries(self,epoch,mode="train"):
with self.summary_writers[mode].as_default():
if mode in self.image_summary_settings.keys():
try:
n_images = self.image_summary_settings["n_images"]
except:
n_images = 1
for image_summary in self.image_summary_settings[mode]:
tf.summary.image(image_summary, self.image_data[image_summary],
step=epoch,
max_outputs=n_images)
def write_lr(self,epoch):
for key in self.lr_summaries.keys():
......
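The renamed Summaries class now tracks image data next to the scalar metrics. A minimal usage sketch of the new image-summary path, using the constructor and method signatures shown above; the metric names, tensor shape, and log directory are placeholders, not values from this commit:

# Sketch only: exercises update_image_data / write_image_summaries directly;
# the fake batch shape, names, and save_dir below are assumptions.
import tensorflow as tf
from utils.summary_utils import Summaries

summaries = Summaries(scalar_summary_names=["class_loss", "accuracy"],
                      learning_rate_names=["learning_rate_classifier"],
                      image_summary_settings={'train': ['fake_features'], 'n_images': 1},
                      save_dir="/tmp/logs",
                      modes=["train"],
                      summaries_to_print={"train": ["class_loss", "accuracy"]})
summaries.create_summary_writers()

# Placeholder generator output: a batch of single-channel spectrogram-like images.
outputs = {'fake_features': tf.random.uniform([4, 128, 256, 1])}
summaries.update_image_data(outputs, mode="train")
summaries.write_image_summaries(epoch=0, mode="train")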
@@ -22,7 +22,7 @@ class ModelTrainer():
base_learning_rates = [],
learning_rate_fns = [],
init_data = [],
scalar_summaries = None,
summaries = None,
start_epoch = 0,
num_train_batches = None,
load_model = False,
@@ -53,7 +53,7 @@ class ModelTrainer():
self.initialize_models(init_data)
self.n_vars_models = [len(model.trainable_variables) for model in self.models]
#Set up summaries
self.scalar_summaries = scalar_summaries
self.summaries = summaries
#Set up loss function
self.eval_fns = eval_fns(self.models)
#Set up optimizers
@@ -170,7 +170,7 @@ class ModelTrainer():
self.learning_rates[i].assign(self.base_learning_rates[i]*self.learning_rate_fns[0](epoch))
else:
self.learning_rates[i].assign(self.base_learning_rates[i]*self.learning_rate_fns[i](epoch))
if self.scalar_summaries != None:
if self.summaries != None:
self.update_learning_rate_summaries()
self.write_learning_rate_summaries(epoch)
@@ -230,15 +230,15 @@ class ModelTrainer():
def update_learning_rate_summaries(self):
lr_dict = {}
for lr_name,lr in zip(self.scalar_summaries.learning_rate_names,self.learning_rates):
for lr_name,lr in zip(self.summaries.learning_rate_names,self.learning_rates):
lr_dict[lr_name] = lr
self.scalar_summaries.update_lr(lr_dict)
self.summaries.update_lr(lr_dict)
def update_summaries(self,losses,outputs,y=None,mode='train'):
if self.scalar_summaries != None:
if 'accuracy' in self.scalar_summaries.scalar_summary_names:
if self.summaries != None:
if 'accuracy' in self.summaries.scalar_summary_names:
pred = outputs['predictions']
accuracy = self.eval_fns.accuracy(pred,y)
scalars = losses
@@ -246,25 +246,29 @@ class ModelTrainer():
else:
scalars = losses
#Update summaries
self.scalar_summaries.update(scalars,mode)
#Update scalar summaries
self.summaries.update(scalars,mode)
#Update image summaries
self.summaries.update_image_data(outputs,mode)
def write_summaries(self,epoch,mode="train"):
if self.scalar_summaries != None:
# Write summaries
self.scalar_summaries.write(epoch,mode)
if self.summaries != None:
# Write scalar summaries
self.summaries.write(epoch,mode)
#Write image summaries
self.summaries.write_image_summaries(epoch,mode)
def write_learning_rate_summaries(self,epoch):
if self.scalar_summaries != None:
if self.summaries != None:
# Write summaries
self.scalar_summaries.write_lr(epoch)
self.summaries.write_lr(epoch)
def reset_summaries(self):
if self.scalar_summaries != None:
if self.summaries != None:
# Write summaries
self.scalar_summaries.reset_summaries()
self.summaries.reset_summaries()
def get_data_for_keys(self,xy,keys):
"""Returns list of input data"""
@@ -317,8 +321,8 @@ class ModelTrainer():
self.update_summaries(losses,outputs,train_y,'train')
batch += 1
if self.scalar_summaries != None:
summary_list = self.scalar_summaries.get_summary_list('train')
if self.summaries != None:
summary_list = self.summaries.get_summary_list('train')
else:
summary_list = []
summary_list += [("time / step", np.round(time() - start_time, 2))]
@@ -330,6 +334,7 @@ class ModelTrainer():
# Write train summaries
self.write_summaries(epoch+1, 'train')
if self.validation_data_generator != None:
for validation_xy in self.validation_data_generator:
validation_x = self.get_data_for_keys(validation_xy,self.input_keys)
@@ -364,14 +369,14 @@ class ModelTrainer():
self.write_summaries(epoch+1,"test")
if self.scalar_summaries != None:
if self.summaries != None:
template = 'Epoch {}, Loss: {}, Accuracy: {},Val Loss: {}, Val Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
summary_list = self.scalar_summaries.get_summary_list('eval')
summary_list = self.summaries.get_summary_list('eval')
scalars = [x[1] for x in summary_list]
print(template.format(epoch + 1,*scalars))
#Reset summaries after epoch
if self.scalar_summaries != None:
self.scalar_summaries.reset_summaries()
if self.summaries != None:
self.summaries.reset_summaries()
#Save the model weights
self.save_model(epoch+1)