Wasserstein GAN w\ Gradient Penalty
이번 페이지에서는 기존 Wasserstein GAN 에서 Exploding and vanishing gradients 문제를 해결한 WGAN-GP 대해서 구현합니다.
Paper
논문에서 제시하는 수정된 최적화 알고리즘은 다음과 같습니다.
기존 WGAN 에서 사용한 학습 코드를 수도코드로 작성하면 다음과 같습니다.
for epoch in range(n_epochs):
for step, batch in enumerate(loader):
#
# 1. get discriminator loss from real data
#
...
#
# 2. get discriminator loss from fake data
#
...
#
# 3. get discriminator loss and update discriminator
#
...
#
# 4. clip weight
#
...
if step % n_critic == 0: # in paper suggested n_critic is 5
#
# 5. get generator loss and update generator
#
...
WGAN-GP 알고리즘을 Python 수도 코드로 작성하면 다음과 같습니다.
for epoch in range(n_epochs):
for step, batch in enumerate(loader):
#
# 1. get discriminator loss from real data
#
...
#
# 2. get discriminator loss from fake data
#
...
#
# 3. calculate penalty
#
...
#
# 4. get discriminator loss and update discriminator
#
...
if step % n_critic == 0: # in paper suggested n_critic is 5
#
# 5. get generator loss and update generator
#
...
cliping 을 해주던 WGAN 과 다르게 gradient 의 패널티를 계산하는 부분이 추가되었습니다.
for epoch in range(n_epochs):
for batch in loader:
for n in range(n_critic): # in paper suggested n_critic is 5
...
#
# 3. calculate penalty
#
...
Dataset
우선 튜토리얼에 들어가기에 앞서 사용할 데이터셋을 선언합니다. 데이터셋에 대한 자세한 설명은 CelebA 페이지에서 확인할 수 있습니다.
우선 배치 사이즈를 논문에서 제시하는 mini batch 숫자와 맞추기 위해 64로 설정합니다.
또한 배치 사이즈가 다를 경우 의도한 대로 학습이 안 될 수 있기 때문에 64가 안될 수 있는 마지막 배치는 사용하지 않도록 drop_last=True
로 선언합니다.
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets.celeba import CelebA
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
transform = T.Compose(
[
T.Resize(64),
T.CenterCrop(64),
T.ToTensor(),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
dataset = CelebA(
"./datasets", download=True, transform=transform
)
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
Model
Generator
WGAN 에서는 생성기는 DCGAN 에서 사용한 모델을 그대로 사용합니다. 생성기에 대한 자세한 설명은 DCGAN 페이지에서 확인할 수 있습니다.
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, num_channel=3, latent_dim=100, feature_dim=64):
super().__init__()
self.layer_1 = nn.Sequential(
nn.ConvTranspose2d(latent_dim, feature_dim * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(feature_dim * 8),
nn.ReLU(True),
)
self.layer_2 = nn.Sequential(
nn.ConvTranspose2d(feature_dim * 8, feature_dim * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_dim * 4),
nn.ReLU(True),
)
self.layer_3 = nn.Sequential(
nn.ConvTranspose2d(feature_dim * 4, feature_dim * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_dim * 2),
nn.ReLU(True),
)
self.layer_4 = nn.Sequential(
nn.ConvTranspose2d(feature_dim * 2, feature_dim, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_dim),
nn.ReLU(True),
)
self.last_layer = nn.Sequential(
nn.ConvTranspose2d(feature_dim, num_channel, 4, 2, 1, bias=False),
nn.Tanh(),
)
def forward(self, z):
# decoding
layer_1_out = self.layer_1(z) # (N, 512, 4, 4)
layer_2_out = self.layer_2(layer_1_out) # (N, 256, 8, 8)
layer_3_out = self.layer_3(layer_2_out) # (N, 128, 16, 16)
layer_4_out = self.layer_4(layer_3_out) # (N, 64, 32, 32)
# transform to rgb
out = self.last_layer(layer_4_out) # (N, 3, 64, 64)
return out
Discriminator
분류기에 대한 부분을 작성합니다.
class Discriminator(nn.Module):
def __init__(self, num_channel=3, feature_dim=64):
super().__init__()
self.layer_1 = nn.Sequential(
nn.Conv2d(num_channel, feature_dim, 4, 2, 1, bias=False),
nn.InstanceNorm2d(feature_dim, affine=True),
nn.LeakyReLU(0.2, inplace=True),
)
self.layer_2 = nn.Sequential(
nn.Conv2d(feature_dim, feature_dim * 2, 4, 2, 1, bias=False),
nn.InstanceNorm2d(feature_dim * 2, affine=True),
nn.LeakyReLU(0.2, inplace=True),
)
self.layer_3 = nn.Sequential(
nn.Conv2d(feature_dim * 2, feature_dim * 4, 4, 2, 1, bias=False),
nn.InstanceNorm2d(feature_dim * 4, affine=True),
nn.LeakyReLU(0.2, inplace=True),
)
self.layer_4 = nn.Sequential(
nn.Conv2d(feature_dim * 4, feature_dim * 8, 4, 2, 1, bias=False),
nn.InstanceNorm2d(feature_dim * 8, affine=True),
nn.LeakyReLU(0.2, inplace=True),
)
self.last_layer = nn.Sequential(
nn.Conv2d(feature_dim * 8, 1, 4, 1, 0, bias=False)
)
def forward(self, x):
# encoding
layer_1_out = self.layer_1(x) # (N, 64, 32, 32)
layer_2_out = self.layer_2(layer_1_out) # (N, 128, 16, 16)
layer_3_out = self.layer_3(layer_2_out) # (N, 256, 8, 8)
layer_4_out = self.layer_4(layer_3_out) # (N, 512, 4, 4)
# classify
out = self.last_layer(layer_4_out).squeeze() # (N)
return out
분류기에서는 더 이상 Bathcnorm 을 사용하지 않습니다.
No critic batch normalization
Most prior GAN implementations [22, 23, 2] use batch normalization in both the generator and the discriminator to help stabilize training, but batch normalization changes the form of the discriminator’s problem from mapping a single input to a single output to mapping from an entire batch of inputs to a batch of outputs [23]. Our penalized training objective is no longer valid in this setting, since we penalize the norm of the critic’s gradient with respect to each input independently, and not the entire batch.
- To resolve this, we simply omit batch normalization in the critic in our models, finding that they perform well without it. Our method works with normalization schemes which don’t introduce correlations between examples.
- In particular, werecommend layer normalization [3] as a drop-in replacement for batch normalization.
- 추후 설명할 학습 과정의 gradient penalty 부분이 이미 정규화 부분을 처리함으로 BatchnNorm 을 사용하지 않습니다.
- 대신 Layer normalization 을 제시하지만, pytorch 에서 적절한 부분이 없어서
InstanceNorm2d
로 대체합니다.
또한 WGAN 과 마찬가지로 마지막 layer 에서 sigmoid 를 사용하지 않습니다.
WGAN-GP Train
이제 위에서 작성한 수도 코드의 내용을 채워서 학습을 진행해 보겠습니다.
Weight Initialization
DCGAN 에서 적용했던 내용을 같이 설정합니다.
파라미터 초기화를 위한 함수를 작성합니다.
def weights_init(m):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
각 모델에 적용합니다.
_ = generator.apply(weights_init)
_ = discriminator.apply(weights_init)
Convolution Network 는 원활한 학습을 위해서는 gpu 가 필요합니다. GPU 가 없는 경우 학습에 다소 시간이 소요될 수 있습니다.
아래 코드를 이용해 device 를 선언합니다.
만약 gpu 가 사용 가능한 경우 device(type='cuda')
메세지가 나옵니다.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
위에서 선언한 모델을 gpu 메모리로 옮기겠습니다.
_ = discriminator.to(device)
_ = generator.to(device)
Logger
학습 과정을 tensorboard 에 저장하기 위한 writer 입니다.
from torch.utils.tensorboard import SummaryWriter
# tensorboard logger
writer = SummaryWriter()
Loss
위의 Discriminator 에서 설명한 것 과 같이 WGAN 에서는 loss term 을 분류기의 거리의 평균을 이용하기 때문에 따로 선언하지 않습니다.
Optimizer
논문에서는 제시하는 Adam 을 사용하며 learning_rate는 제시된 0.0001, beta는 0.5, 0.9 를 사용합니다.
import torch.optim as optim
# optimizer
discriminator_opt = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.9))
generator_opt = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.9))
Gradient Penalty
이번에는 WGAN-GP 에서 GP 를 뜻하는 gradient penalty 를 계산하는 부분을 작성해야 합니다.
논문에서는 해당 term 을 계산하기 위해서 다음과 같이 제시합니다.