Generative Adversarial Networks (GANs) are powerful for generating realistic data, but training them can be unstable due to issues like vanishing gradients, mode collapse, and sensitivity to hyperparameters. Wasserstein GAN (WGAN) addresses these issues by introducing a new loss function based on the Earth Mover’s (Wasserstein) distance, leading to more stable training.
1. Challenges in Standard GAN Training
GANs consist of two networks:
- Generator (G): Generates fake samples.
- Discriminator (D): Distinguishes real samples from fake ones.
Key Challenges:
- Vanishing Gradients: When the discriminator becomes too strong, gradients for the generator diminish, slowing or halting training.
- Mode Collapse: The generator produces only limited variations of the data, reducing diversity.
- Oscillations: The adversarial nature of GANs often leads to unstable training dynamics.
- Poor Metric for Convergence: The Jensen-Shannon (JS) divergence used in standard GANs provides no meaningful feedback when the supports of the real and fake data distributions do not overlap (a concrete example follows this list).
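For example, consider two point masses \( P = \delta_0 \) and \( P_\theta = \delta_\theta \) on the real line. Whenever \( \theta \neq 0 \) the supports are disjoint, so \( JS(P, P_\theta) = \log 2 \) no matter how close \( \theta \) is to zero, which gives the generator a zero gradient; the Wasserstein-1 distance, by contrast, equals \( |\theta| \) and shrinks smoothly as the two distributions approach each other.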
2. What is Wasserstein GAN (WGAN)?
WGAN replaces the original GAN loss function with a new one based on the Wasserstein-1 distance (Earth Mover’s distance). This metric provides a meaningful measure of how two distributions differ and improves training stability.
Key Differences from Standard GANs:
- Loss Function: Standard GANs minimize the JS divergence; WGAN minimizes the Wasserstein distance, which is smoother and provides more informative gradients.
- Discriminator → Critic: WGAN uses a “critic” instead of a “discriminator” because it outputs an unbounded scalar score instead of a probability.
- Weight Clipping: Enforces the Lipschitz constraint (required for the Wasserstein formulation) by clipping the critic’s weights to a small range (e.g., [-0.01, 0.01]); a short sketch follows this list.
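As a minimal illustration of the clipping step (a sketch; `critic` here is assumed to be any `torch.nn.Module`, and the range matches the example above):

```python
# Clamp every critic parameter into [-0.01, 0.01] after each critic update
for p in critic.parameters():
    p.data.clamp_(-0.01, 0.01)
```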
3. WGAN Loss Function
For WGAN:
- Critic Loss:
\[
L_C = \mathbb{E}_{\mathbf{x} \sim P_r} [C(\mathbf{x})] - \mathbb{E}_{\mathbf{\tilde{x}} \sim P_g} [C(\mathbf{\tilde{x}})]
\]
where \( P_r \) is the real data distribution, \( P_g \) is the generator’s distribution, and \( C(\cdot) \) is the critic.
- Generator Loss:
\[
L_G = -\mathbb{E}_{\mathbf{\tilde{x}} \sim P_g} [C(\mathbf{\tilde{x}})]
\]
Here:
- The critic maximizes \( L_C \), widening the gap between its mean scores on real and fake data (in practice, one minimizes \( -L_C \)).
- The generator minimizes \( L_G \), i.e., it tries to raise the critic’s score on its fake data so that fakes look real.
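Translated directly into PyTorch, the two objectives look like the following sketch (`critic`, `generator`, `real_data`, and `noise` are placeholders for the objects defined in the implementation below):

```python
# Critic scores on real and generated samples
real_scores = critic(real_data)           # shape: (batch, 1)
fake_scores = critic(generator(noise))    # shape: (batch, 1)

# Critic: maximize E[C(real)] - E[C(fake)]  ->  minimize the negative
critic_loss = -(real_scores.mean() - fake_scores.mean())

# Generator: maximize E[C(fake)]  ->  minimize the negative
generator_loss = -critic(generator(noise)).mean()
```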
4. Implementation of WGAN in PyTorch
a. Model Definitions
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Generator
class Generator(nn.Module):
    def __init__(self, noise_dim, output_dim):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# Critic
class Critic(nn.Module):
    def __init__(self, input_dim):
        super(Critic, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        return self.model(x)
```
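A quick shape check for the definitions above (the dimensions simply mirror the MNIST-style setup used in the training loop below):

```python
z = torch.randn(4, 100)                      # batch of 4 noise vectors
G = Generator(noise_dim=100, output_dim=28 * 28)
C = Critic(input_dim=28 * 28)
print(G(z).shape)                            # torch.Size([4, 784])
print(C(G(z)).shape)                         # torch.Size([4, 1]) -- one scalar score per sample
```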
b. Training Loop
```python
# Hyperparameters
noise_dim = 100
data_dim = 28 * 28   # For MNIST
batch_size = 64
n_critic = 5         # Number of critic updates per generator update
weight_clip = 0.01   # Weight clipping range
epochs = 50          # Number of training epochs (example value)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models
generator = Generator(noise_dim, data_dim).to(device)
critic = Critic(data_dim).to(device)

# Optimizers
lr = 5e-5
gen_optimizer = optim.RMSprop(generator.parameters(), lr=lr)
critic_optimizer = optim.RMSprop(critic.parameters(), lr=lr)

# Training loop
for epoch in range(epochs):
    for real_data in dataloader:  # Assume `dataloader` yields batches of real data scaled to [-1, 1]
        real_data = real_data.view(real_data.size(0), -1).to(device)
        cur_batch_size = real_data.size(0)

        # Critic training: maximize E[C(real)] - E[C(fake)] by minimizing its negative
        for _ in range(n_critic):
            noise = torch.randn(cur_batch_size, noise_dim, device=device)
            fake_data = generator(noise).detach()
            critic_loss = -(torch.mean(critic(real_data)) - torch.mean(critic(fake_data)))

            critic_optimizer.zero_grad()
            critic_loss.backward()
            critic_optimizer.step()

            # Weight clipping to enforce the Lipschitz constraint
            for p in critic.parameters():
                p.data.clamp_(-weight_clip, weight_clip)

        # Generator training: maximize E[C(fake)] by minimizing its negative
        noise = torch.randn(cur_batch_size, noise_dim, device=device)
        fake_data = generator(noise)
        generator_loss = -torch.mean(critic(fake_data))

        gen_optimizer.zero_grad()
        generator_loss.backward()
        gen_optimizer.step()

    print(f"Epoch [{epoch+1}/{epochs}] | Critic Loss: {critic_loss.item():.4f} | Generator Loss: {generator_loss.item():.4f}")
```
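After training, samples can be drawn by feeding fresh noise through the generator (a sketch; the reshaping assumes the MNIST-style 28×28 layout used above):

```python
generator.eval()
with torch.no_grad():
    samples = generator(torch.randn(16, noise_dim, device=device))
samples = samples.view(-1, 1, 28, 28)   # back to image shape; values lie in [-1, 1] due to Tanh
```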
5. Improvements with Wasserstein GAN
- Stability: Replacing the JS divergence with the Wasserstein distance yields smoother gradients and reduces training oscillations.
- Mitigates Mode Collapse: The critic learns to estimate the Wasserstein distance between the full real and generated distributions, which indirectly encourages the generator to capture more diverse data.
- Better Convergence Metric: The critic’s estimate of the Wasserstein distance (the gap between its mean scores on real and fake data) correlates with sample quality; watching it shrink toward zero is a meaningful convergence signal (a monitoring sketch follows this list).
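A minimal monitoring sketch (reusing the `critic`, `generator`, `real_data`, `noise_dim`, and `device` names from the training loop above):

```python
# Critic's estimate of W(P_r, P_g): mean score on real data minus mean score on fakes.
# This gap should trend toward zero as the generated distribution approaches the real one.
with torch.no_grad():
    fake_batch = generator(torch.randn(real_data.size(0), noise_dim, device=device))
    w_estimate = critic(real_data).mean() - critic(fake_batch).mean()
print(f"Wasserstein estimate: {w_estimate.item():.4f}")
```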
6. Extensions of WGAN
a. WGAN-GP (Wasserstein GAN with Gradient Penalty)
Instead of weight clipping, WGAN-GP enforces the Lipschitz constraint using a gradient penalty:
\[
L_{GP} = \lambda \cdot \mathbb{E}_{\hat{x} \sim P_{\hat{x}}} \left[ \left( \| \nabla_{\hat{x}} C(\hat{x}) \|_2 - 1 \right)^2 \right]
\]
where \( \hat{x} \) is a random linear interpolation between a real and a fake sample, and \( \lambda \) is the penalty coefficient (10 in the original WGAN-GP paper). This term is added to the critic loss in place of weight clipping; a sketch follows the list below.
Advantages:
- Eliminates issues caused by weight clipping (e.g., reduced capacity of the critic).
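A minimal sketch of the penalty term, assuming flat `(batch, features)` tensors; `lambda_gp=10.0` follows the WGAN-GP paper’s default, and the returned value would be added to the critic loss instead of applying weight clipping:

```python
def gradient_penalty(critic, real_data, fake_data, device, lambda_gp=10.0):
    batch_size = real_data.size(0)
    # Random linear interpolation between real and fake samples
    eps = torch.rand(batch_size, 1, device=device)
    interpolated = (eps * real_data.detach() + (1 - eps) * fake_data.detach()).requires_grad_(True)

    scores = critic(interpolated)
    # Gradient of the critic's output with respect to the interpolated inputs
    gradients = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]

    # Penalize deviations of the gradient norm from 1 (two-sided penalty)
    gradient_norm = gradients.view(batch_size, -1).norm(2, dim=1)
    return lambda_gp * ((gradient_norm - 1) ** 2).mean()
```

In the training loop above, one would then drop the clipping step and add `gradient_penalty(critic, real_data, fake_data, device)` to `critic_loss` before calling `backward()`.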
7. Best Practices for WGAN Training
- Optimizer: The WGAN paper recommends RMSProp with a small learning rate (e.g., \( 5 \times 10^{-5} \)); Adam with a low \( \beta_1 \) works well for WGAN-GP (a setup sketch follows this list).
- Critic Updates: Train the critic more frequently than the generator (\( n_{\text{critic}} > 1 \), typically 5 critic steps per generator step).
- Avoid Over-Clipping: Excessive weight clipping reduces the capacity of the critic; consider WGAN-GP for improved performance.
- Batch Size: Use a sufficiently large batch size to improve gradient estimates.
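For reference, optimizer setups along these lines might look as follows (the Adam hyperparameters are the defaults reported in the WGAN-GP paper; `generator` and `critic` are the models defined earlier):

```python
# Clipped WGAN: RMSprop with a small learning rate, as in the original paper
gen_optimizer = optim.RMSprop(generator.parameters(), lr=5e-5)
critic_optimizer = optim.RMSprop(critic.parameters(), lr=5e-5)

# WGAN-GP: Adam with a low beta1 (lr=1e-4, betas=(0.0, 0.9))
gen_optimizer_gp = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
critic_optimizer_gp = optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
```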
8. Summary
| Feature | Standard GAN | Wasserstein GAN |
|---|---|---|
| Loss function | Jensen-Shannon divergence | Wasserstein distance |
| Discriminator/critic output | Probability (0 to 1) | Real-valued scalar |
| Training stability | Often unstable | More stable |
| Gradient issues | Vanishing gradients | Smoother, more informative gradients |
| Mode collapse | Frequent | Much less frequent |
Conclusion
Wasserstein GAN introduces fundamental improvements to GAN training by replacing the traditional loss function with the Wasserstein distance, stabilizing training and improving convergence. For more robust results, consider using WGAN-GP, which avoids the limitations of weight clipping.