Below is a simplified implementation of a GAN using PyTorch and torchvision to generate CIFAR-10 images.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
# Select the compute device: CUDA when a GPU is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Preprocessing: convert each image to a tensor, then apply (x - 0.5) / 0.5
# independently to each of the three channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split, stored under ./data (download=True fetches it).
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 samples.
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)
# Hyperparameters
num_epochs = 10              # number of passes over the training set
latent_dim = 100             # length of the noise vector fed to the generator
lr = 2e-4                    # Adam learning rate
beta1, beta2 = 0.5, 0.999    # Adam betas, passed as betas=(beta1, beta2)
# Define the generator
class Generator(nn.Module):
    """Map a latent noise vector to a 3-channel 32x32 image.

    A linear layer projects the noise to a 128x8x8 feature map, which is
    upsampled by a factor of 2 twice (8 -> 16 -> 32) with a 3x3 convolution
    after each upsampling step. A final 3x3 convolution reduces the map to
    3 channels and ``Tanh`` bounds the output to [-1, 1].

    Args:
        latent_dim: Length of the input noise vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.model = nn.Sequential(
            # Project the noise vector and reshape it into a 128 x 8 x 8 map.
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            # 8x8 -> 16x16
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            # NOTE(review): in PyTorch, `momentum` is the weight given to the
            # *new* batch statistics (default 0.1), so 0.78 makes the running
            # estimates follow the latest batch closely. Confirm this is the
            # intended value and not a decay factor carried over from
            # another framework.
            nn.BatchNorm2d(128, momentum=0.78),
            nn.ReLU(),
            # 16x16 -> 32x32
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64, momentum=0.78),
            nn.ReLU(),
            # Reduce to 3 output channels; Tanh keeps values in [-1, 1].
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 3, 32, 32) with values in [-1, 1].
        """
        return self.model(z)
# Define the discriminator
class Discriminator(nn.Module):
    """Score a 3-channel image with a value in (0, 1) for "real".

    The layer sizes are chosen for 32x32 inputs: three stride-2 convolutions
    shrink the image (32 -> 16 -> 8, zero-padded to 9, -> 5), a stride-1
    convolution widens it to 256 channels, and the flattened 256x5x5 map is
    reduced to a single sigmoid output.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 3 x 32 x 32 -> 32 x 16 x 16
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            # 32 x 16 x 16 -> 64 x 8 x 8, then pad right/bottom by one -> 9 x 9
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.BatchNorm2d(64, momentum=0.82),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            # 64 x 9 x 9 -> 128 x 5 x 5
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128, momentum=0.82),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            # 128 x 5 x 5 -> 256 x 5 x 5 (stride 1 keeps the spatial size)
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=0.8),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            # Collapse to one score per image.
            nn.Flatten(),
            nn.Linear(256 * 5 * 5, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: Tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 1) holding sigmoid outputs in (0, 1).
        """
        return self.model(img)
# Build both networks and move them to the selected device.
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy between the discriminator output and the 0/1 targets.
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas.
adam_kwargs = dict(lr=lr, betas=(beta1, beta2))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)
# Training loop: alternate one discriminator update and one generator update
# per mini-batch.
for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        # batch[0] holds the images; the rest of the batch is not used.
        real_images = batch[0].to(device)
        batch_size = real_images.size(0)

        # Adversarial ground truths: 1 for real images, 0 for generated ones.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Sample noise and generate one fake image per real image.
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(z)

        # detach() stops this backward pass from reaching the generator.
        real_loss = adversarial_loss(discriminator(real_images), valid)
        fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Reuse fake_images instead of running generator(z) again: the
        # generator's weights have not changed since that forward pass, so
        # the images would be the same. The generator is rewarded when the
        # (just updated) discriminator labels them as real.
        g_loss = adversarial_loss(discriminator(fake_images), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Progress Monitoring
        # ---------------------
        if (i + 1) % 100 == 0:
            # Adjacent f-strings (not a backslash inside the literal) so the
            # separator between the epoch and batch fields is a single space.
            print(
                f"Epoch [{epoch+1}/{num_epochs}] "
                f"Batch {i+1}/{len(dataloader)} "
                f"Discriminator Loss: {d_loss.item():.4f} "
                f"Generator Loss: {g_loss.item():.4f}"
            )

    # Every 10th epoch, display (not save) a 4x4 grid of generated samples.
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated = generator(z).cpu()
        grid = torchvision.utils.make_grid(generated, nrow=4, normalize=True)
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.axis("off")
        plt.show()
Explanation of the Code
Imports and Device Setup
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
# Select the compute device: CUDA when a GPU is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
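- torch (with torch.nn and torch.optim): Library for building and training neural networks.
- torchvision: Provides the CIFAR-10 dataset, image transforms, and the make_grid utility used for visualization.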
- matplotlib.pyplot and numpy: Libraries for plotting and numerical operations.
- device: Sets the computation device to GPU if available, otherwise CPU.
Data Preparation
# Preprocessing: convert each image to a tensor, then apply (x - 0.5) / 0.5
# independently to each of the three channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split, stored under ./data (download=True fetches it).
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 samples.
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)
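- transform: Converts images to tensors with values in [0, 1] (ToTensor), then normalizes each channel to the [-1, 1] range, matching the generator's Tanh output.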
- train_dataset: Loads CIFAR-10 dataset.
- dataloader: Creates batches of training data.
Hyperparameters
num_epochs = 10              # number of passes over the training set
latent_dim = 100             # length of the noise vector fed to the generator
lr = 2e-4                    # Adam learning rate
beta1, beta2 = 0.5, 0.999    # Adam betas, passed as betas=(beta1, beta2)
- latent_dim: Size of the random noise vector input to the generator.
- lr, beta1, beta2: Learning rate and beta values for the Adam optimizer.
- num_epochs: Number of training epochs.
Generator Model
class Generator(nn.Module):
    """Map a latent noise vector to a 3-channel 32x32 image.

    A linear layer projects the noise to a 128x8x8 feature map, which is
    upsampled by a factor of 2 twice (8 -> 16 -> 32) with a 3x3 convolution
    after each upsampling step. A final 3x3 convolution reduces the map to
    3 channels and ``Tanh`` bounds the output to [-1, 1].

    Args:
        latent_dim: Length of the input noise vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.model = nn.Sequential(
            # Project the noise vector and reshape it into a 128 x 8 x 8 map.
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            # 8x8 -> 16x16
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            # NOTE(review): in PyTorch, `momentum` is the weight given to the
            # *new* batch statistics (default 0.1), so 0.78 makes the running
            # estimates follow the latest batch closely. Confirm this is the
            # intended value and not a decay factor carried over from
            # another framework.
            nn.BatchNorm2d(128, momentum=0.78),
            nn.ReLU(),
            # 16x16 -> 32x32
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64, momentum=0.78),
            nn.ReLU(),
            # Reduce to 3 output channels; Tanh keeps values in [-1, 1].
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 3, 32, 32) with values in [-1, 1].
        """
        return self.model(z)
- Generator class: Defines the generator network.
- model: A sequence of layers that transform a random noise vector into an image.
- nn.Linear: Fully connected layer.
- nn.ReLU: Activation function.
- nn.Unflatten: Reshapes the tensor.
- nn.Upsample: Upsampling layers to increase image dimensions.
- nn.Conv2d: Convolutional layers.
- nn.BatchNorm2d: Batch normalization layers.
- nn.Tanh: Activation function for the output layer.
Discriminator Model
class Discriminator(nn.Module):
    """Score a 3-channel image with a value in (0, 1) for "real".

    The layer sizes are chosen for 32x32 inputs: three stride-2 convolutions
    shrink the image (32 -> 16 -> 8, zero-padded to 9, -> 5), a stride-1
    convolution widens it to 256 channels, and the flattened 256x5x5 map is
    reduced to a single sigmoid output.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 3 x 32 x 32 -> 32 x 16 x 16
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            # 32 x 16 x 16 -> 64 x 8 x 8, then pad right/bottom by one -> 9 x 9
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.BatchNorm2d(64, momentum=0.82),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            # 64 x 9 x 9 -> 128 x 5 x 5
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128, momentum=0.82),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            # 128 x 5 x 5 -> 256 x 5 x 5 (stride 1 keeps the spatial size)
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=0.8),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            # Collapse to one score per image.
            nn.Flatten(),
            nn.Linear(256 * 5 * 5, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: Tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 1) holding sigmoid outputs in (0, 1).
        """
        return self.model(img)
- Discriminator class: Defines the discriminator network.
- model: A sequence of layers that transform an image into a single scalar value representing its validity.
- nn.Conv2d: Convolutional layers.
- nn.LeakyReLU: Activation function.
- nn.Dropout: Dropout layers for regularization.
- nn.ZeroPad2d: Padding layer.
- nn.BatchNorm2d: Batch normalization layers.
- nn.Flatten: Flattens the tensor.
- nn.Linear: Fully connected layer.
- nn.Sigmoid: Activation function for the output layer.
Initialize Models, Loss, and Optimizers
# Build both networks and move them to the selected device.
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

# Binary cross-entropy between the discriminator output and the 0/1 targets.
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas.
adam_kwargs = dict(lr=lr, betas=(beta1, beta2))
optimizer_G = optim.Adam(generator.parameters(), **adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_kwargs)
- generator, discriminator: Instantiates the generator and discriminator models.
- adversarial_loss: Binary cross-entropy loss function.
- optimizer_G, optimizer_D: Adam optimizers for the generator and discriminator.
Training Loop
for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        # batch[0] holds the images; the rest of the batch is not used.
        real_images = batch[0].to(device)
        batch_size = real_images.size(0)
        # Targets: 1 for real images, 0 for generated ones.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Train Discriminator
        optimizer_D.zero_grad()
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(z)
        real_loss = adversarial_loss(discriminator(real_images), valid)
        # detach() stops this backward pass from reaching the generator.
        fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        # Reuse fake_images: the generator's weights have not changed since
        # that forward pass, so calling generator(z) again would be redundant.
        g_loss = adversarial_loss(discriminator(fake_images), valid)
        g_loss.backward()
        optimizer_G.step()

        # Progress Monitoring
        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] Batch {i+1}/{len(dataloader)} "
                f"Discriminator Loss: {d_loss.item():.4f} "
                f"Generator Loss: {g_loss.item():.4f}"
            )

    # Every 10th epoch, display (not save) a 4x4 grid of generated samples.
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated = generator(z).cpu()
        grid = torchvision.utils.make_grid(generated, nrow=4, normalize=True)
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.axis("off")
        plt.show()
- Epoch loop: Loops through epochs.
- Batch loop: Loops through batches of real images.
- real_images: Real images from the dataset.
- valid, fake: Labels for real and fake images.
- Train Discriminator:
- z: Random noise vector.
- fake_images: Generated fake images.
- real_loss, fake_loss: Losses for real and fake images.
- d_loss: Discriminator loss, computed as the average of real_loss and fake_loss.
- optimizer_D.step(): Updates discriminator weights.
- Train Generator:
- gen_images: Generated fake images.
- g_loss: Generator loss.
- optimizer_G.step(): Updates generator weights.
- Progress Monitoring: Prints the losses every 100 batches.
- Generated image preview: Every 10 epochs, generates 16 images and displays them as a 4x4 grid with matplotlib (the images are shown on screen, not written to disk).