DL
[Generation] Generative Adversarial Network (GAN)
해서
2024. 8. 20. 19:41
0. 적대적 생성 신경망(Generative Adversarial Network, GAN)
- Generative(생성) , Adversarial(적대적), Network(신경망)
- 적대적으로 학습하는 신경망들로 구성되며, 생성 모델로써 활용한다.
- 생성모델 관점에서 VAE와 GAN의 차이
- VAE의 생성 방식: 데이터의 latent space를 가우시안 분포로 학습해 그 분포에서 샘플링하여 데이터를 생성하는 모델
- GAN의 생성 방식: 생성된 데이터와 실제 데이터를 판별하고 속이는 과정을 거치며 성성 모델을 개선
1. GAN의 구조
- 데이터를 생성하는 생성 모델 (Generator)와 데이터의 진위를 구별하는 판별 모델(Discriminator)로 구성된다.
- 생성 모델(Generator) : 랜덤 한 노이즈 벡터를 입력받아 가짜 데이터를 생성, 목표는 판별 모델이 진짜라고 믿을 만큼 실제 데이터와 유사한 데이터를 만드는 것
- 판별 모델(Discriminator) : 입력받은 데이터가 실제 데이터인지, 가짜 데이터인지 구별, 목표는 진짜 데이터와 가짜 데이터를 정확히 구별하는 것
- GAN 목적
- GAN은 생성모델의 분포를 통해 판별 모델이 예측을 반복적으로 업데이트하며 학습이 진행된다.
- 학습 과정에서 두 분포가 점점 더 겹쳐지면서, 판별자는 더 이상 두 분포를 완전히 구분할 수 없게 되고, 생성자는 실제 데이터와 유사한 데이터를 생성하게 된다.
- 검정색 점선: 실제 데이터 분포, 파란색 점선: 판별 모델의 결정경계, 초록색실선: 생성된 데이터의 분포
2. GAN의 목적 함수
- 판별 모델(Discriminator): 실제 데이터와 생성된 데이터를 정확히 구별해야 한다.
- logD(x) 부분: 실제 데이터를 입력받았을 때는 1을 리턴
- log(1-D(G(z))) 부분: 생성된 데이터를 입력 받았을 때 D(G(z))가 0 이 되도록 즉, 가짜 이미지를 받았을 때 가짜로 분류하도록 한다.
- 생성 모델(Generator): 실제와 유사한 데이터를 생성하여 판별자를 속여야 한다.
- log(1-D(G(z))) 부분: 생성 모델은 자신이 생성한 결과물이 판별모델로 하여금 1에 가깝게 높은 값을 할당하도록 한다. 이는 판별모델이 진짜 사진을 보고 1에 가깝게 최대로 만드는 것과 동치가 된다.
3. GAN의 학습 방식
- 판별 모델의 손실함수
- 생성 모델의 손실 함수
- GAN의 목적 함수는 실제로 잘 동작하지 않는다 -> D(G(z))의 기울기 문제
- 생성 모델이 더 훈련되어야 하는 구간에서 더 평평한 기울기를 가진다. -> 판별 모델의 기울기를 조정하여 학습이 잘되게 하자
4. 코드
- FashionMNIST 를 활용하여 DCGAN을 이용, 생성자와 판별자의 구현과 손실함수는 어떻게 구현됐는지 알아보자
- Generator
- DCGAN에서 latent vector를 받아 이미지를 생성하기 위한 생성자 구조이다.
- 본 코드에서는 latent 차원을 64로, kernel_size는 4를 이용(pytorch에서는 커널크기가 4일 때 크기를 2배로 만들 수 있다)
- 마지막에 tanh()사용이유: 사전에 데이터 범위를 [-1, 1]로 변환하였기에 일관된 범위를 유지하기 위해서이다.
class Generator(nn.Module):
def __init__(self,
latent_dim: int=64,
base_channels: int=64,
out_channels: int=1):
super().__init__()
self.layers = nn.Sequential(
# latent_dim x 1 x 1 -> base_channels*8 x 4 x 4
nn.ConvTranspose2d(latent_dim, base_channels*8, kernel_size=4, stride=1, bias=False),
nn.BatchNorm2d(base_channels*8), # 채널 단위로 정규화를 해줍니다
nn.ReLU(inplace=True),
# base_channels*8 x 4 x 4 -> base_channels*4 x 8 x 8
nn.ConvTranspose2d(base_channels*8, base_channels*4, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(base_channels*4),
nn.ReLU(inplace=True),
# base_channels*4 x 8 x 8 -> base_channels*2 x 16 x 16
nn.ConvTranspose2d(base_channels*4, base_channels*2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(base_channels*2),
nn.ReLU(inplace=True),
# base_channels*2 x 16 x 16 -> base_channels x 32 x 32
nn.ConvTranspose2d(base_channels*2, base_channels, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(base_channels),
nn.ReLU(inplace=True),
# base_channels x 32 x 32 -> output_channels x 64 x 64
nn.ConvTranspose2d(base_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh() # 최종 출력 값의 범위 [-1, 1]
)
def forward(self, z):
return self.layers(z)
- Discriminator
- 판별자는 이미지를 입력으로 받아 진짜인지 가짜인지 구별한다. Generator와 반대로 Conv를 이용하여 이미지 크기를 점차 줄인다.
- 마지막 layer에서는 fc_layer가 아닌 Conv2d를 이용하여 1x1x1 tensor를 만든다.
- 이를 입력으로 Sigmoide를 통과하여 0~1 사이 값으로 만든다.
- 1에 가까운 값: 진짜일 확률이 높음
- 0에 가까운 값: 가짜일 확률이 높음
class Discriminator(nn.Module):
def __init__(self, in_channels: int=1, base_channels: int=64):
super().__init__()
self.layers = nn.Sequential(
# input_channels x 64 x 64
nn.Conv2d(in_channels, base_channels, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# base_channels x 32 x 32
nn.Conv2d(base_channels, base_channels*2, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_channels*2),
nn.LeakyReLU(0.2, inplace=True),
# base_channels * 2 x 16 x 16
nn.Conv2d(base_channels*2, base_channels*4, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_channels*4),
nn.LeakyReLU(0.2, inplace=True),
# base_channels * 4 x 8 x 8
nn.Conv2d(base_channels*4, base_channels*8, 4, 2, 1, bias=False),
nn.BatchNorm2d(base_channels*8),
nn.LeakyReLU(0.2, inplace=True),
# 크기: base_channels * 8 x 4 x 4
nn.Conv2d(base_channels*8, 1, 4, 1, 0, bias=False),
nn.Sigmoid() # 판별자 출력값을 [0, 1] 사이로 제한
)
def forward(self, x):
return layers.model(x)
- DCGAN 학습
- GAN의 손실 함수
- 이진 분류에서의 크로스 엔트로피 (Binary Cross Entropy) 손실함수 (데이터 x, 레이블 y에 대해)
- 이를 활용하여 GAN의 손실함수 구현을 할 수 있다.
- Discriminator의 손실 함수
- BCE에서 첫번째항 ylogx에서 x부분을 D(X)로 계산하고 이는 label이 1인 경우이다. 즉, real image를 받았을 땐 1로 label 분류한다.
- BCE에서 두번째 항 (1-y) log(1-x)에 대해 x부분은 fake image를 받아 D(G(z))를 계산하고 이는 레이블이 0인 경우로 분류한다.
- Generator의 손실 함수
- log(1-x)에 대해 기울기 소실문제로 이를 개선한 -logD(G(z))를 이용한다. 이는 fake image가 식별자로 하여금 real image처럼 판별되면 좋겠기 때문에 BCE loss에서 label이 1인 경우로 계산하는 것과 같다.
- torch.detach()
- 판별자(Discriminator) 학습과정에서 생성자(Generator)의 가중치 업데이트를 방지하기 위해 사용된다.
- 이유: 판별자는 가짜이미지를 0으로 만들어야 하는 기울기를 계산하는데, 생성자에 반영된다면 생성자입장에서 내가만든 이미지가 fake_image로 분류됐으면 좋겠어라는 결과를 가져오기 때문이다.
- GAN 학습과정
- 1. Generator: latent 벡터(random noise)를 입력받아 fake_image 생성
- 2. Discriminator: real_image와 fake_image를 입력받아 진짜인지 가짜인지 판별
- 3. Discriminator 학습: real_image를 진짜로, fake_image를 가짜로 구분하도록 학습
- 4. Generator 학습: fake_image를 Discriminator가 진짜라고 속일 수 있도록 학습
for epoch in range(num_epochs):
##
for idx, data in enumerate(train_loader):
real_image = data[0].to(device)
num_data = real_image.size(0)
# latent vector 샘플링, 이미지 생성
sampling_z = torch.randn((num_data, latent_dim, 1, 1), device=device)
fake_image = generator(sampling_z)
# BCELoss 레이블 생성 (real_image = 1, fake_image = 0)
real_label = torch.ones(num_data, device=device)
fake_label = torch.zeros(num_data, device=device)
# Discriminator 학습
real_prob = discriminator(real_image).view(-1) # D(x) 계산
fake_prob = discriminator(fake_image.detach()).view(-1) # D(G(z)) 계산
d_optim.zero_grad()
# Discriminator 손실 함수 계산
d_loss = criterion(real_prob, real_label) + criterion(fake_prob, fake_label)
d_loss.backward()
d_optim.step()
# Generator 학습
fake_prob = discriminator(fake_image).view(-1) # D(G(z)) 계산
g_optim.zero_grad()
# Generator 손실 함수 계산
g_loss = criterion(fake_prob, real_label)
g_loss.backward()
g_optim.step()
- 학습 결과
- 생성 모델 평가 : Inception Score
- Inception Score : 5.958016395568848
- Inception Score의 만점은 총 class 개수와 비슷하다고 하니 10점으로 봤을 때, 준수한 성적을 얻었다고 볼 수 있다.