Commit 33712b68 authored by Nicolas PERNOUD's avatar Nicolas PERNOUD
Browse files

chore: refactored gan_01.py

parent 58b55ea5
...@@ -7,12 +7,20 @@ from keras.models import Sequential ...@@ -7,12 +7,20 @@ from keras.models import Sequential
from keras.datasets import fashion_mnist from keras.datasets import fashion_mnist
from keras.optimizers import Adam from keras.optimizers import Adam
z_dim = 100
img_lines = 28
img_columns = 28
img_channels = 1
img_shape = (img_lines, img_columns, img_channels)
def build_generator(img_shape, z_dim): def build_generator(img_shape, z_dim):
model = Sequential() model = Sequential()
model.add(Dense(128, input_dim=z_dim)) model.add(Dense(128, input_dim=z_dim))
model.add(LeakyReLU(alpha=0.01)) model.add(LeakyReLU(alpha=0.01))
model.add(Dense(28*28*1, activation='tanh')) model.add(Dense(img_lines*img_columns*img_channels, activation='tanh'))
model.add(Reshape(img_shape)) model.add(Reshape(img_shape))
return model return model
...@@ -34,14 +42,6 @@ def build_gan(generator, discriminator): ...@@ -34,14 +42,6 @@ def build_gan(generator, discriminator):
return model return model
z_dim = 100
img_lines = 28
img_columns = 28
img_channels = 1
img_shape = (img_lines, img_columns, img_channels)
discriminator = build_discriminator(img_shape) discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', discriminator.compile(loss='binary_crossentropy',
optimizer=Adam(), metrics=['accuracy']) optimizer=Adam(), metrics=['accuracy'])
...@@ -91,7 +91,7 @@ def train(iterations, batch_size, sample_interval): ...@@ -91,7 +91,7 @@ def train(iterations, batch_size, sample_interval):
losses.append((d_loss, g_loss)) losses.append((d_loss, g_loss))
accuracies.append((100*accuracy)) accuracies.append((100*accuracy))
iteration_checkpoints.append(iteration+1) iteration_checkpoints.append(iteration+1)
status = 'iteration: {:} [D loss:{:}, acc.:{:2.2%}] [G loss: {:}]'.format( status = 'iteration: {:} [D loss: {:}, acc.: {:2.2%}] [G loss: {:}]'.format(
iteration+1, d_loss, accuracy, g_loss) iteration+1, d_loss, accuracy, g_loss)
print(status) print(status)
sample_images(generator, iteration+1) sample_images(generator, iteration+1)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment