En aquest tema, aprendrem a crear una Xarxa Generativa Adversarial (GAN) per a la generació d'imatges. Les GAN són una de les tècniques més emocionants i innovadores en el camp del Deep Learning, utilitzades per generar dades noves i realistes a partir de dades d'entrenament.

Objectius

  • Comprendre la teoria darrere de les GAN.
  • Aprendre a implementar una GAN bàsica utilitzant PyTorch.
  • Generar imatges sintètiques a partir d'un conjunt de dades d'entrenament.

Continguts

  1. Introducció a les GAN
  2. Arquitectura d'una GAN
  3. Implementació d'una GAN amb PyTorch
  4. Entrenament de la GAN
  5. Generació d'imatges
  6. Exercicis pràctics

  1. Introducció a les GAN

Què és una GAN?

Una Xarxa Generativa Adversarial (GAN) és un tipus de model de deep learning compost per dues xarxes neuronals que competeixen entre si: el generador i el discriminador.

  • Generador: Crea dades sintètiques que semblen reals.
  • Discriminador: Avalua si les dades són reals (provinents del conjunt de dades d'entrenament) o falses (generades pel generador).

Funcionament

El generador intenta enganyar el discriminador creant dades cada vegada més realistes, mentre que el discriminador millora la seva capacitat per distingir entre dades reals i falses. Aquest procés de competició millora ambdues xarxes.

  1. Arquitectura d'una GAN

Components

  • Generador: Una xarxa neuronal que pren un vector de soroll com a entrada i genera una imatge.
  • Discriminador: Una xarxa neuronal que pren una imatge com a entrada i classifica si és real o generada.

Diagrama de Flux

Vector de Soroll -> Generador -> Imatge Generada -> Discriminador -> Classificació (Real/Falsa)

  1. Implementació d'una GAN amb PyTorch

Instal·lació de PyTorch

Abans de començar, assegura't de tenir PyTorch instal·lat. Pots instal·lar-lo amb pip:

pip install torch torchvision

Codi d'Implementació

Importació de Llibreries

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

Definició del Generador

class Generador(nn.Module):
    def __init__(self):
        super(Generador, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

Definició del Discriminador

class Discriminador(nn.Module):
    def __init__(self):
        super(Discriminador, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

Preparació del Conjunt de Dades

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

mnist = dsets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = DataLoader(dataset=mnist, batch_size=100, shuffle=True)

  1. Entrenament de la GAN

Definició de la Funció de Pèrdua i Optimitzadors

generador = Generador()
discriminador = Discriminador()

criterion = nn.BCELoss()
optimizer_g = optim.Adam(generador.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminador.parameters(), lr=0.0002)

Bucle d'Entrenament

num_epochs = 50
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Preparació de les dades
        images = images.view(images.size(0), -1)
        real_labels = torch.ones(images.size(0), 1)
        fake_labels = torch.zeros(images.size(0), 1)

        # Entrenament del Discriminador
        outputs = discriminador(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(images.size(0), 100)
        fake_images = generador(z)
        outputs = discriminador(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # Entrenament del Generador
        z = torch.randn(images.size(0), 100)
        fake_images = generador(z)
        outputs = discriminador(fake_images)

        g_loss = criterion(outputs, real_labels)

        optimizer_g.zero_grad()
        g_loss.backward()
        optimizer_g.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}, D(x): {real_score.mean().item()}, D(G(z)): {fake_score.mean().item()}')

  1. Generació d'Imatges

Visualització de les Imatges Generades

import matplotlib.pyplot as plt

z = torch.randn(64, 100)
fake_images = generador(z)
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
fake_images = fake_images.data

grid = torchvision.utils.make_grid(fake_images, nrow=8, normalize=True)
plt.imshow(grid.permute(1, 2, 0))
plt.show()

  1. Exercicis Pràctics

Exercici 1: Modificar l'Arquitectura del Generador

  • Prova a afegir més capes al generador i observa com afecta la qualitat de les imatges generades.

Exercici 2: Entrenar la GAN amb un Conjunt de Dades Diferent

  • Utilitza un conjunt de dades diferent, com CIFAR-10, i entrena la GAN per generar imatges de diferents categories.

Exercici 3: Ajustar els Hiperparàmetres

  • Experimenta amb diferents valors per als hiperparàmetres com la taxa d'aprenentatge, el nombre d'epochs, i la mida del batch.

Conclusió

En aquesta secció, hem après a crear una GAN per a la generació d'imatges utilitzant PyTorch. Hem cobert la teoria darrere de les GAN, la seva arquitectura, i hem implementat una GAN des de zero. A més, hem explorat com entrenar la GAN i generar imatges sintètiques. Els exercicis pràctics proporcionats t'ajudaran a aprofundir en els conceptes i millorar les teves habilitats en la creació de GANs.

© Copyright 2024. Tots els drets reservats