Commit 8ff8923d authored by Alexander Fuchs's avatar Alexander Fuchs
Browse files

Fixed discriminator and wrong model assignment

parent db96e7cd
......@@ -23,8 +23,8 @@ class Discriminator(tf.keras.Model):
for i in range(self.n_layers):
self.model_layers.append(SpectralNormalization(tf.keras.layers.Conv2D(self.n_channels[i],
kernel_size=strides[i]+1,
strides=strides[i],
kernel_size=self.strides[i]+1,
strides=self.strides[i],
padding="same",
use_bias=False,
name=self.name_op+'_conv_' + str(i),
......@@ -32,15 +32,15 @@ class Discriminator(tf.keras.Model):
kernel_initializer = self.kernel_initializer,
activation='relu')))
self.dense = SpectralNormalization(Dense(1,
self.dense = SpectralNormalization(tf.keras.layers.Dense(1,
kernel_initializer = self.kernel_initializer,
use_bias = False,
name = self.name_op+"_dense",
activation = None))
activation = tf.nn.sigmoid))
def call(self,input,training=False):
x = input
x = tf.image.per_image_standardization(input)
for layer in self.model_layers:
x = layer(x)
......
......@@ -42,7 +42,7 @@ class Generator(tf.keras.Model):
self.model_layers.append((tf.keras.layers.BatchNormalization(axis=-1),1))
self.model_layers.append((tf.keras.layers.Activation("relu"),0))
def call(self,input,training=False):
x = input
......
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import librosa
from absl import logging
from absl import app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from utils.trainer import ModelTrainer
from utils.data_loader import Dataset
from utils.data_loader import DataGenerator
......@@ -12,7 +12,7 @@ from models.classifier import Classifier
from models.res_block import ResBlockBasicLayer
from models.discriminator import Discriminator
from models.generator import Generator
#logging.set_verbosity(logging.WARNING)
import tensorflow as tf # pylint: disable=g-bad-import-order
BINS = 1025
N_FRAMES = 216
......@@ -190,10 +190,10 @@ def main(argv):
epochs,
classifier = classifier_model,
generator = generator_model,
discriminator = generator_model,
discriminator = discriminator_model,
base_learning_rate_classifier = lr,
base_learning_rate_generator = lr*0.001,
base_learning_rate_discriminator = lr*0.001,
base_learning_rate_discriminator = lr*0.0005,
learning_rate_fn_classifier = learning_rate_fn,
learning_rate_fn_generator = learning_rate_fn,
learning_rate_fn_discriminator = learning_rate_fn,
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment