Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Alexander Fuchs
tensorflow2_trainer_template
Commits
db96e7cd
Commit
db96e7cd
authored
Aug 06, 2020
by
Alexander Fuchs
Browse files
Fixed normalization using augment_fn
parent
c09d16bf
Changes
3
Hide whitespace changes
Inline
Side-by-side
src/scripts/birdsong_simple_main.py
View file @
db96e7cd
...
...
@@ -18,48 +18,29 @@ BINS = 1025
N_FRAMES
=
216
N_CHANNELS
=
2
def
preprocess_input
(
sample
,
n_classes
,
is_training
):
def
augment_input
(
sample
,
n_classes
,
training
):
"""Preprocess a single image of layout [height, width, depth]."""
input_features
=
sample
[
'input_features'
]
labels
=
sample
[
'labels'
]
false_sample
=
sample
[
'false_sample'
]
batch_size
=
tf
.
shape
(
labels
)[
0
]
if
is_training
:
rnd
=
tf
.
random
.
uniform
([
batch_size
])
rnd_tiled_feat
=
tf
.
tile
(
tf
.
reshape
(
rnd
,[
batch_size
,
1
,
1
,
1
]),[
1
,
BINS
,
if
training
:
rnd
=
tf
.
random
.
uniform
([
1
])
rnd_tiled_feat
=
tf
.
tile
(
tf
.
reshape
(
rnd
,[
1
,
1
,
1
]),[
BINS
,
N_FRAMES
,
N_CHANNELS
])
rnd_tiled_lbl
=
tf
.
tile
(
tf
.
reshape
(
rnd
,[
batch_size
,
1
]),[
1
,
n_classes
+
1
])
false_lbl
=
tf
.
cast
(
tf
.
one_hot
(
n_classes
+
1
,
n_classes
+
1
),
tf
.
int32
)
false_lbl
=
tf
.
tile
(
tf
.
reshape
(
false_lbl
,[
1
,
n_classes
+
1
]),[
batch_size
,
1
])
input_features
=
tf
.
where
(
rnd_tiled_feat
>
1
/
n_classes
,
input_features
,
false_sample
)
labels
=
tf
.
where
(
rnd_tiled_lbl
>
1
/
n_classes
,
labels
,
false_lbl
)
return
{
'input_features'
:
input_features
,
'labels'
:
labels
,
'false_sample'
:
false_sample
}
rnd_tiled_lbl
=
tf
.
tile
(
tf
.
reshape
(
rnd
,[
1
]),[
n_classes
+
1
])
def
augment_input
(
sample
,
n_classes
):
"""Preprocess a single image of layout [height, width, depth]."""
input_features
=
sample
[
'input_features'
]
labels
=
sample
[
'labels'
]
false_sample
=
sample
[
'false_sample'
]
false_lbl
=
tf
.
cast
(
tf
.
one_hot
(
n_classes
+
1
,
n_classes
+
1
),
labels
.
dtype
)
rnd
=
tf
.
random
.
uniform
([
1
])
rnd_tiled_feat
=
tf
.
tile
(
tf
.
reshape
(
rnd
,[
1
,
1
,
1
]),[
BINS
,
N_FRAMES
,
N_CHANNELS
])
rnd_tiled_lbl
=
tf
.
tile
(
tf
.
reshape
(
rnd
,[
1
]),[
n_classes
+
1
])
false_lbl
=
tf
.
cast
(
tf
.
one_hot
(
n_classes
+
1
,
n_classes
+
1
),
labels
.
dtype
)
input_features
=
tf
.
where
(
rnd_tiled_feat
>
1
/
n_classes
,
input_features
,
false_sample
)
labels
=
tf
.
where
(
rnd_tiled_lbl
>
1
/
n_classes
,
labels
,
false_lbl
)
input_features
=
tf
.
where
(
rnd_tiled_feat
>
1
/
n_classes
,
input_features
,
false_sample
)
labels
=
tf
.
where
(
rnd_tiled_lbl
>
1
/
n_classes
,
labels
,
false_lbl
)
input_features
=
tf
.
image
.
per_image_standardization
(
input_features
)
false_sample
=
tf
.
image
.
per_image_standardization
(
false_sample
)
return
{
'input_features'
:
input_features
,
'labels'
:
labels
,
'false_sample'
:
false_sample
}
...
...
@@ -148,15 +129,15 @@ def main(argv):
ds_train
=
Dataset
(
data_dir
,
is_training_set
=
True
)
n_total
=
ds_train
.
n_samples
def
augment_fn
(
sample
):
return
augment_input
(
sample
,
ds_train
.
n_classes
)
def
augment_fn
(
sample
,
training
):
return
augment_input
(
sample
,
ds_train
.
n_classes
,
training
)
dg_train
=
DataGenerator
(
ds_train
,
augment_fn
,
training_percentage
=
training_percentage
,
preload_samples
=
preload_samples
,
is_training
=
True
)
dg_val
=
DataGenerator
(
ds_train
,
None
,
dg_val
=
DataGenerator
(
ds_train
,
augment_fn
,
training_percentage
=
training_percentage
,
is_validation
=
True
,
preload_samples
=
preload_samples
,
...
...
@@ -212,7 +193,7 @@ def main(argv):
discriminator
=
generator_model
,
base_learning_rate_classifier
=
lr
,
base_learning_rate_generator
=
lr
*
0.001
,
base_learning_rate_discriminator
=
lr
*
0.00
0
1
,
base_learning_rate_discriminator
=
lr
*
0.001
,
learning_rate_fn_classifier
=
learning_rate_fn
,
learning_rate_fn_generator
=
learning_rate_fn
,
learning_rate_fn_discriminator
=
learning_rate_fn
,
...
...
src/utils/data_loader.py
View file @
db96e7cd
...
...
@@ -405,7 +405,7 @@ class DataGenerator(object):
'labels'
:
tf
.
one_hot
(
sample
[
'bird_id'
],
self
.
dataset
.
n_classes
+
1
),
'false_sample'
:
false_spectra
}
if
self
.
augmentation
!=
None
:
yield
self
.
augmentation
(
sample
)
yield
self
.
augmentation
(
sample
,
self
.
is_training
)
else
:
yield
sample
...
...
src/utils/trainer.py
View file @
db96e7cd
...
...
@@ -186,8 +186,9 @@ class ModelTrainer():
axis
=-
1
,
))
gen_loss
=
tf
.
reduce_mean
((
fake_features
-
x
[
1
])
**
2
)
#Wasserstein losses
gen_loss
=
tf
.
reduce_mean
(
false
)
gen_loss
+
=
tf
.
reduce_mean
(
false
)
discr_loss
=
tf
.
reduce_mean
(
true
)
-
tf
.
reduce_mean
(
false
)
if
len
(
self
.
classifier
.
losses
)
>
0
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment