Commit 76d8abc3 authored by Alexander Fuchs's avatar Alexander Fuchs
Browse files

Added pre-loading of false samples

parent a1536602
......@@ -10,6 +10,7 @@ import sys
import random
import copy
import tensorflow as tf
warnings.filterwarnings('ignore')
class Dataset(object):
......@@ -132,9 +133,10 @@ class Dataset(object):
class DataGenerator(object):
def __init__(self,dataset,augmentation,
shuffle=True,
is_training=True,
force_feature_recalc=False,
shuffle = True,
is_training = True,
force_feature_recalc = False,
preload_false_samples = True,
max_time = 5,
max_samples_per_audio = 6,
n_fft = 2048,
......@@ -149,11 +151,28 @@ class DataGenerator(object):
self.is_training = is_training
self.sampling_rate = sampling_rate
self.n_fft = n_fft
self.preload_false_samples = preload_false_samples
self.hop_length = hop_length
self.max_time = max_time
self.max_samples_per_audio = max_samples_per_audio
self.force_feature_recalc = force_feature_recalc
#Get paths of false samples
false_samples_mono = glob.glob(self.dataset.false_audio_path+ "/mono/*.npz",
recursive = True)
false_samples_stereo = glob.glob(self.dataset.false_audio_path+ "/stereo/*.npz",
recursive = True)
self.false_sample_paths = false_samples_mono + false_samples_stereo
#Pre load false samples
if self.preload_false_samples:
self.preloaded_false_samples = {}
for path in self.false_sample_paths:
with np.load(path,allow_pickle=True) as sample_file:
self.preloaded_false_samples[path] = copy.deepcopy(sample_file.f.arr_0)
print("Finished pre-loading")
def do_stft(self,y,channels):
spectra = []
#STFT for all channels
......@@ -302,63 +321,64 @@ class DataGenerator(object):
samples = self.dataset.train_samples
else:
samples = self.dataset.test_samples
#Get paths of false samples
false_samples_mono = glob.glob(self.dataset.false_audio_path+ "/mono/*.npz",
recursive = True)
false_samples_stereo = glob.glob(self.dataset.false_audio_path+ "/stereo/*.npz",
recursive = True)
false_samples = false_samples_mono + false_samples_stereo
stft_len = int(np.ceil(self.max_time*self.sampling_rate/self.hop_length))
for sample in samples:
try:
filename = sample['filename']
#If feature was already created load from file
if os.path.isfile(filename.replace("mp3","npz")) and not(self.force_feature_recalc):
spectra_npz = np.load(filename.replace("mp3","npz"),allow_pickle=True)
spec_keys = spectra_npz.f.arr_0.item().keys()
spec_keys = list(spec_keys)
rnd_key = spec_keys[np.random.randint(0,len(spec_keys))]
spectra = spectra_npz.f.arr_0.item()[rnd_key]
else:
#Create features via STFT if no file exists
spectra = self.create_feature(sample)
#Check for None type and shape
if np.any(spectra) == None or spectra.shape[-1] != stft_len:
filename = sample['filename']
#If feature was already created load from file
if os.path.isfile(filename.replace("mp3","npz")) and not(self.force_feature_recalc):
with np.load(filename.replace("mp3","npz"),allow_pickle=True) as sample_file:
spectra_npz = sample_file.f.arr_0
try:
spec_keys = spectra_npz.item().keys()
except:
continue
#Get false sample
rnd_false_sample = random.choice(false_samples)
false_spectra_npz = np.load(rnd_false_sample,allow_pickle=True)
false_spec_keys = false_spectra_npz.f.arr_0.item().keys()
false_spec_keys = list(false_spec_keys)
false_rnd_key = false_spec_keys[np.random.randint(0,len(false_spec_keys))]
false_spectra = false_spectra_npz.f.arr_0.item()[false_rnd_key]
#If only mono --> duplicate
if spectra.shape[0] == 1:
spectra = np.tile(spectra,[2,1,1])
spec_keys = list(spec_keys)
rnd_key = spec_keys[np.random.randint(0,len(spec_keys))]
spectra = spectra_npz.item()[rnd_key]
else:
#Create features via STFT if no file exists
spectra = self.create_feature(sample)
#Check for None type and shape
if np.any(spectra) == None or spectra.shape[-1] != stft_len:
continue
#Get false sample
rnd_false_sample = random.choice(self.false_sample_paths)
if self.preload_false_samples:
false_spectra_npz = self.preloaded_false_samples[rnd_false_sample]
else:
with np.load(rnd_false_sample,allow_pickle=True) as sample_file:
false_spectra_npz = sample_file.f.arr_0
#If false only mono --> duplicate
if false_spectra.shape[0] == 1:
false_spectra = np.tile(false_spectra,[2,1,1])
false_spec_keys = false_spectra_npz.item().keys()
false_spec_keys = list(false_spec_keys)
false_rnd_key = false_spec_keys[np.random.randint(0,len(false_spec_keys))]
false_spectra = false_spectra_npz.item()[false_rnd_key]
#If only mono --> duplicate
if spectra.shape[0] == 1:
spectra = np.tile(spectra,[2,1,1])
#Transpose spectrogramms for "channels_last"
spectra = tf.transpose(spectra,perm=[1,2,0])
false_spectra = tf.transpose(false_spectra,perm=[1,2,0])
yield {'input_features':spectra,
'labels':tf.one_hot(sample['bird_id'],self.dataset.n_classes+1),
'false_sample':false_spectra}
except:
continue
#If false only mono --> duplicate
if false_spectra.shape[0] == 1:
false_spectra = np.tile(false_spectra,[2,1,1])
#Transpose spectrogramms for "channels_last"
spectra = tf.transpose(spectra,perm=[1,2,0])
false_spectra = tf.transpose(false_spectra,perm=[1,2,0])
yield {'input_features':spectra,
'labels':tf.one_hot(sample['bird_id'],self.dataset.n_classes+1),
'false_sample':false_spectra}
# except:
# continue
if __name__ == "__main__":
ds = Dataset("/srv/TUG/datasets/cornell_birdcall_recognition")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment