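"""gan_01.py

A minimal fully connected GAN trained on Fashion-MNIST: the generator maps a
100-dimensional noise vector to a 28x28 grayscale image, and the discriminator
classifies images as real or generated. Grids of generated samples are saved
to PNG files at regular intervals during training.
"""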
import matplotlib.pyplot as plt
import numpy as np

from keras.layers import Dense, Flatten, LeakyReLU, Reshape
from keras.models import Sequential
from keras.datasets import fashion_mnist
from keras.optimizers import Adam

# Dimension of the random noise vector fed to the generator
z_dim = 100

# Fashion-MNIST images: 28x28 pixels, one grayscale channel
img_lines = 28
img_columns = 28
img_channels = 1

img_shape = (img_lines, img_columns, img_channels)


def build_generator(img_shape, z_dim):
    # Map a z_dim-dimensional noise vector to a flattened image
    model = Sequential()
    model.add(Dense(128, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.01))
    # tanh output matches the [-1, 1] range the training images are scaled to
    model.add(Dense(img_lines*img_columns*img_channels, activation='tanh'))
    model.add(Reshape(img_shape))
    return model


def build_discriminator(img_shape):
    # Binary classifier: outputs the probability that an image is real
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(1, activation='sigmoid'))
    return model


def build_gan(generator, discriminator):
    # Stack generator and discriminator; freezing the discriminator here means
    # that training the combined model only updates the generator's weights.
    # The discriminator still trains through its own compiled instance, since
    # `trainable` takes effect when a model is compiled.
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    discriminator.trainable = False
    model.compile(loss='binary_crossentropy', optimizer=Adam())
    return model


# Build and compile the discriminator for standalone training
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(), metrics=['accuracy'])

# Build the generator and the combined model used to train it
generator = build_generator(img_shape, z_dim)
gan = build_gan(generator, discriminator)


def sample_images(generator, iteration, img_per_l=4, img_per_c=4):
    # Generate a grid of images from random noise and save it to disk
    z = np.random.normal(0, 1, (img_per_l*img_per_c, z_dim))
    img_gen = generator.predict(z)
    # Rescale from [-1, 1] back to [0, 1] for display
    img_gen = 0.5*img_gen+0.5
    fig, ax = plt.subplots(img_per_l, img_per_c, figsize=(
        4, 4), sharey=True, sharex=True)
    cpt = 0
    for i in range(img_per_l):
        for j in range(img_per_c):
            ax[i, j].imshow(img_gen[cpt, :, :, 0], cmap='gray')
            ax[i, j].axis('off')
            cpt += 1
    plt.savefig("test_"+f'{iteration:05d}'+".png", dpi=150)
    plt.close(fig)


def train(iterations, batch_size, sample_interval):
    losses = []
    accuracies = []
    iteration_checkpoints = []

    # Load Fashion-MNIST and rescale pixel values from [0, 255] to [-1, 1]
    ((X_train, _), (_, _)) = fashion_mnist.load_data()
    X_train = X_train/127.5-1
    X_train = np.expand_dims(X_train, axis=3)

    # Labels: 1 for real images, 0 for generated ones
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):
        # Train the discriminator on a real batch and a generated batch
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        z = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = generator.predict(z)
        d_loss_real = discriminator.train_on_batch(imgs, real)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss, accuracy = 0.5*np.add(d_loss_real, d_loss_fake)

        # Train the generator through the combined model (discriminator frozen)
        z = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = gan.train_on_batch(z, real)

        if (iteration+1) % sample_interval == 0 or iteration == 0:
            losses.append((d_loss, g_loss))
            accuracies.append(100*accuracy)
            iteration_checkpoints.append(iteration+1)
            status = 'iteration: {} [D loss: {:.4f}, acc.: {:.2%}] [G loss: {:.4f}]'.format(
                iteration+1, d_loss, accuracy, g_loss)
            print(status)
            sample_images(generator, iteration+1)


# Training hyperparameters
iterations = 20000
batch_size = 128
sample_interval = 1000
train(iterations, batch_size, sample_interval)
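
# The values above reproduce the original run (20000 iterations). For a quick
# smoke test, smaller settings also work, e.g. (illustrative values, not from
# the original script): train(iterations=2000, batch_size=64, sample_interval=500)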