Commit 45b3e48d authored by Alexander Fuchs's avatar Alexander Fuchs

Started implementing a training script for the birdcall challenge

parent 92d4ce96
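"""Training script for the birdcall challenge.

Example invocation (the script name and paths are placeholders):
    python train_birdcall.py --model_dir=/tmp/model --data_dir=/tmp/data \
        --epochs=40 --learning_rate=1e-3
"""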
import os
import numpy as np
from absl import logging
from absl import app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from utils.trainer import ModelTrainer
# NOTE: Dataset and DataGenerator are used in main() but were not imported in
# this commit; the module path below is an assumption.
from utils.dataset import Dataset, DataGenerator
from models import basic_dmoe_cnn

# logging.set_verbosity(logging.WARNING)

# MNIST-sized defaults carried over from the original script; note that the
# init_data tensor in main() uses a larger [1025, 432, 2] spectrogram shape.
HEIGHT = 28
WIDTH = 28
NUM_CHANNELS = 1
BATCH_SIZE = 16
NUM_CLASSES = 10
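
# Assumed output signature for the Python data generator: (feature, label)
# pairs, with the feature shape inferred from the init_data tensor passed to
# ModelTrainer below and an integer class id as label. This is a sketch;
# adjust it to whatever DataGenerator actually yields.
GENERATOR_SIGNATURE = (
    tf.TensorSpec(shape=(1025, 432, 2), dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.int32),
)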

def preprocess_feature(sample):
    """Preprocess a single image of layout [height, width, depth]."""
    return sample

def data_generator(generator_fn, batch_size, is_training, shuffle_buffer=128,
                   is_validation=False, take_n=None, skip_n=None):
    """Builds a batched tf.data pipeline around a Python generator.

    `is_validation` is accepted for call-site compatibility but unused.
    """
    # from_generator needs a description of the generator's output;
    # GENERATOR_SIGNATURE above is an assumed one.
    dataset = tf.data.Dataset.from_generator(
        generator_fn, output_signature=GENERATOR_SIGNATURE)
    if skip_n is not None:
        dataset = dataset.skip(skip_n)
    if take_n is not None:
        dataset = dataset.take(take_n)
    dataset = dataset.map(lambda feat, lbl: (preprocess_feature(feat), lbl))
    dataset = dataset.map(lambda feat, lbl: (feat, tf.one_hot(lbl, NUM_CLASSES)))
    if is_training:
        dataset = dataset.shuffle(shuffle_buffer)
    # The original ignored the batch_size argument (hard-coding BATCH_SIZE for
    # training and 10 for validation); use the argument consistently instead.
    dataset = dataset.batch(batch_size, drop_remainder=is_training)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

def learning_rate_fn(epoch):
    """Step-decay schedule; the returned value scales the base learning rate."""
    if epoch < 20:
        return 1.0
    elif epoch < 30:
        return 0.01
    else:
        return 0.001

FLAGS = flags.FLAGS
flags.DEFINE_string('model_dir', '/tmp', 'save directory name')
flags.DEFINE_string('data_dir', '/tmp', 'data directory name')
flags.DEFINE_integer('epochs', 40, 'number of epochs')
flags.DEFINE_float('dropout_rate', 0.0, 'dropout rate for the dense blocks')
flags.DEFINE_float('weight_decay', 1e-4, 'weight decay parameter')
flags.DEFINE_float('learning_rate', 1e-3, 'learning rate')
flags.DEFINE_float('training_percentage', 80, 'Percentage of the data used for training; the remaining (100 - training_percentage)% is used for validation.')
flags.DEFINE_boolean('load_model', False, 'Bool indicating if the model should be loaded')

def main(argv):
    # SLURM array task id, for cluster array jobs (read but currently unused).
    try:
        task_id = int(os.environ['SLURM_ARRAY_TASK_ID'])
    except KeyError:
        task_id = 0

    model_save_dir = FLAGS.model_dir
    data_dir = FLAGS.data_dir
    print("Saving model to: " + str(model_save_dir))
    print("Loading data from: " + str(data_dir))

    # Currently both point at the same directory.
    test_data_dir = data_dir
    train_data_dir = data_dir

    epochs = FLAGS.epochs
    dropout_rate = FLAGS.dropout_rate
    weight_decay = FLAGS.weight_decay
    lr = FLAGS.learning_rate
    load_model = FLAGS.load_model
    training_percentage = FLAGS.training_percentage

    # Encode the hyperparameters in the save directory name.
    model_save_dir += ("_dropout_rate_" + str(dropout_rate) +
                       "_learning_rate_" + str(lr) +
                       "_weight_decay_" + str(weight_decay))

    model = basic_dmoe_cnn.basic_dmoe_cnn_mnist()

    # Split a single dataset into training and validation parts via skip/take.
    ds_train = Dataset(data_dir, is_training_set=True)
    n_total = ds_train.n_samples
    dg_train = DataGenerator(ds_train, None)
    n_train = int(n_total * training_percentage / 100)
    n_val = n_total - n_train

    train_data_gen = data_generator(dg_train, BATCH_SIZE, is_training=True,
                                    take_n=n_train)
    # The original passed an undefined name `train_data` here; dg_train is meant.
    val_data_gen = data_generator(dg_train, 100, is_training=False,
                                  is_validation=True, skip_n=n_train,
                                  take_n=n_val)

    trainer = ModelTrainer(model,
                           train_data_gen,
                           val_data_gen,
                           None,
                           epochs,
                           learning_rate_fn=learning_rate_fn,
                           optimizer=tf.keras.optimizers.Adam,
                           num_train_batches=n_train // BATCH_SIZE,
                           base_learning_rate=lr,
                           load_model=load_model,
                           save_dir=model_save_dir,
                           init_data=tf.random.normal([BATCH_SIZE, 1025, 432, 2]),
                           start_epoch=0)
    trainer.train()

if __name__ == '__main__':
    app.run(main)