Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Alexander Fuchs
tensorflow2_trainer_template
Commits
76d8abc3
Commit
76d8abc3
authored
Aug 05, 2020
by
Alexander Fuchs
Browse files
Added pre-loading of false samples
parent
a1536602
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/utils/data_loader.py
View file @
76d8abc3
...
...
@@ -10,6 +10,7 @@ import sys
import
random
import
copy
import
tensorflow
as
tf
warnings
.
filterwarnings
(
'ignore'
)
class
Dataset
(
object
):
...
...
@@ -132,9 +133,10 @@ class Dataset(object):
class
DataGenerator
(
object
):
def
__init__
(
self
,
dataset
,
augmentation
,
shuffle
=
True
,
is_training
=
True
,
force_feature_recalc
=
False
,
shuffle
=
True
,
is_training
=
True
,
force_feature_recalc
=
False
,
preload_false_samples
=
True
,
max_time
=
5
,
max_samples_per_audio
=
6
,
n_fft
=
2048
,
...
...
@@ -149,11 +151,28 @@ class DataGenerator(object):
self
.
is_training
=
is_training
self
.
sampling_rate
=
sampling_rate
self
.
n_fft
=
n_fft
self
.
preload_false_samples
=
preload_false_samples
self
.
hop_length
=
hop_length
self
.
max_time
=
max_time
self
.
max_samples_per_audio
=
max_samples_per_audio
self
.
force_feature_recalc
=
force_feature_recalc
#Get paths of false samples
false_samples_mono
=
glob
.
glob
(
self
.
dataset
.
false_audio_path
+
"/mono/*.npz"
,
recursive
=
True
)
false_samples_stereo
=
glob
.
glob
(
self
.
dataset
.
false_audio_path
+
"/stereo/*.npz"
,
recursive
=
True
)
self
.
false_sample_paths
=
false_samples_mono
+
false_samples_stereo
#Pre load false samples
if
self
.
preload_false_samples
:
self
.
preloaded_false_samples
=
{}
for
path
in
self
.
false_sample_paths
:
with
np
.
load
(
path
,
allow_pickle
=
True
)
as
sample_file
:
self
.
preloaded_false_samples
[
path
]
=
copy
.
deepcopy
(
sample_file
.
f
.
arr_0
)
print
(
"Finished pre-loading"
)
def
do_stft
(
self
,
y
,
channels
):
spectra
=
[]
#STFT for all channels
...
...
@@ -302,63 +321,64 @@ class DataGenerator(object):
samples
=
self
.
dataset
.
train_samples
else
:
samples
=
self
.
dataset
.
test_samples
#Get paths of false samples
false_samples_mono
=
glob
.
glob
(
self
.
dataset
.
false_audio_path
+
"/mono/*.npz"
,
recursive
=
True
)
false_samples_stereo
=
glob
.
glob
(
self
.
dataset
.
false_audio_path
+
"/stereo/*.npz"
,
recursive
=
True
)
false_samples
=
false_samples_mono
+
false_samples_stereo
stft_len
=
int
(
np
.
ceil
(
self
.
max_time
*
self
.
sampling_rate
/
self
.
hop_length
))
for
sample
in
samples
:
try
:
filename
=
sample
[
'filename'
]
#If feature was already created load from file
if
os
.
path
.
isfile
(
filename
.
replace
(
"mp3"
,
"npz"
))
and
not
(
self
.
force_feature_recalc
):
spectra_npz
=
np
.
load
(
filename
.
replace
(
"mp3"
,
"npz"
),
allow_pickle
=
True
)
spec_keys
=
spectra_npz
.
f
.
arr_0
.
item
().
keys
()
spec_keys
=
list
(
spec_keys
)
rnd_key
=
spec_keys
[
np
.
random
.
randint
(
0
,
len
(
spec_keys
))]
spectra
=
spectra_npz
.
f
.
arr_0
.
item
()[
rnd_key
]
else
:
#Create features via STFT if no file exists
spectra
=
self
.
create_feature
(
sample
)
#Check for None type and shape
if
np
.
any
(
spectra
)
==
None
or
spectra
.
shape
[
-
1
]
!=
stft_len
:
filename
=
sample
[
'filename'
]
#If feature was already created load from file
if
os
.
path
.
isfile
(
filename
.
replace
(
"mp3"
,
"npz"
))
and
not
(
self
.
force_feature_recalc
):
with
np
.
load
(
filename
.
replace
(
"mp3"
,
"npz"
),
allow_pickle
=
True
)
as
sample_file
:
spectra_npz
=
sample_file
.
f
.
arr_0
try
:
spec_keys
=
spectra_npz
.
item
().
keys
()
except
:
continue
#Get false sample
rnd_false_sample
=
random
.
choice
(
false_samples
)
false_spectra_npz
=
np
.
load
(
rnd_false_sample
,
allow_pickle
=
True
)
false_spec_keys
=
false_spectra_npz
.
f
.
arr_0
.
item
().
keys
()
false_spec_keys
=
list
(
false_spec_keys
)
false_rnd_key
=
false_spec_keys
[
np
.
random
.
randint
(
0
,
len
(
false_spec_keys
))]
false_spectra
=
false_spectra_npz
.
f
.
arr_0
.
item
()[
false_rnd_key
]
#If only mono --> duplicate
if
spectra
.
shape
[
0
]
==
1
:
spectra
=
np
.
tile
(
spectra
,[
2
,
1
,
1
])
spec_keys
=
list
(
spec_keys
)
rnd_key
=
spec_keys
[
np
.
random
.
randint
(
0
,
len
(
spec_keys
))]
spectra
=
spectra_npz
.
item
()[
rnd_key
]
else
:
#Create features via STFT if no file exists
spectra
=
self
.
create_feature
(
sample
)
#Check for None type and shape
if
np
.
any
(
spectra
)
==
None
or
spectra
.
shape
[
-
1
]
!=
stft_len
:
continue
#Get false sample
rnd_false_sample
=
random
.
choice
(
self
.
false_sample_paths
)
if
self
.
preload_false_samples
:
false_spectra_npz
=
self
.
preloaded_false_samples
[
rnd_false_sample
]
else
:
with
np
.
load
(
rnd_false_sample
,
allow_pickle
=
True
)
as
sample_file
:
false_spectra_npz
=
sample_file
.
f
.
arr_0
#If false only mono --> duplicate
if
false_spectra
.
shape
[
0
]
==
1
:
false_spectra
=
np
.
tile
(
false_spectra
,[
2
,
1
,
1
])
false_spec_keys
=
false_spectra_npz
.
item
().
keys
()
false_spec_keys
=
list
(
false_spec_keys
)
false_rnd_key
=
false_spec_keys
[
np
.
random
.
randint
(
0
,
len
(
false_spec_keys
))]
false_spectra
=
false_spectra_npz
.
item
()[
false_rnd_key
]
#If only mono --> duplicate
if
spectra
.
shape
[
0
]
==
1
:
spectra
=
np
.
tile
(
spectra
,[
2
,
1
,
1
])
#Transpose spectrogramms for "channels_last"
spectra
=
tf
.
transpose
(
spectra
,
perm
=
[
1
,
2
,
0
])
false_spectra
=
tf
.
transpose
(
false_spectra
,
perm
=
[
1
,
2
,
0
])
yield
{
'input_features'
:
spectra
,
'labels'
:
tf
.
one_hot
(
sample
[
'bird_id'
],
self
.
dataset
.
n_classes
+
1
),
'false_sample'
:
false_spectra
}
except
:
continue
#If false only mono --> duplicate
if
false_spectra
.
shape
[
0
]
==
1
:
false_spectra
=
np
.
tile
(
false_spectra
,[
2
,
1
,
1
])
#Transpose spectrogramms for "channels_last"
spectra
=
tf
.
transpose
(
spectra
,
perm
=
[
1
,
2
,
0
])
false_spectra
=
tf
.
transpose
(
false_spectra
,
perm
=
[
1
,
2
,
0
])
yield
{
'input_features'
:
spectra
,
'labels'
:
tf
.
one_hot
(
sample
[
'bird_id'
],
self
.
dataset
.
n_classes
+
1
),
'false_sample'
:
false_spectra
}
# except:
# continue
if
__name__
==
"__main__"
:
ds
=
Dataset
(
"/srv/TUG/datasets/cornell_birdcall_recognition"
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment