En aquest tema, aprendrem com ajustar finament les Xarxes Neuronals Convolucionals (CNNs) per millorar el seu rendiment en tasques específiques. L'ajust fi és una tècnica que implica prendre un model preentrenat i adaptar-lo a una nova tasca. Aquesta tècnica és especialment útil quan tenim una quantitat limitada de dades per entrenar un model des de zero.
Objectius
- Entendre el concepte d'ajust fi.
- Aprendre a carregar un model preentrenat.
- Modificar l'arquitectura del model per adaptar-lo a una nova tasca.
- Entrenar el model ajustat finament amb un nou conjunt de dades.
- Què és l'Ajust Fi?
L'ajust fi (fine-tuning) és el procés de prendre un model que ha estat preentrenat en una gran base de dades (com ImageNet) i adaptar-lo a una nova tasca amb un conjunt de dades més petit. Els avantatges d'aquesta tècnica inclouen:
- Reducció del temps d'entrenament: Com que el model ja ha après característiques generals de les imatges, només cal ajustar-lo a les especificitats de la nova tasca.
- Millor rendiment amb menys dades: Els models preentrenats ja tenen una bona comprensió de les característiques visuals, la qual cosa permet obtenir bons resultats fins i tot amb menys dades.
- Carregar un Model Preentrenat
PyTorch proporciona una varietat de models preentrenats a través del mòdul torchvision.models
. A continuació, carregarem un model ResNet preentrenat.
import torch import torchvision.models as models # Carregar un model ResNet preentrenat model = models.resnet18(pretrained=True)
- Modificar l'Arquitectura del Model
Per adaptar el model a una nova tasca, hem de modificar l'última capa de la xarxa per ajustar-la al nombre de classes de la nostra nova tasca. Suposem que volem classificar imatges en 10 categories.
import torch.nn as nn # Modificar l'última capa de la xarxa num_classes = 10 model.fc = nn.Linear(model.fc.in_features, num_classes)
- Entrenar el Model Ajustat Finament
Ara que hem modificat el model, podem entrenar-lo amb el nostre nou conjunt de dades. A continuació, es mostra un exemple de com fer-ho.
4.1 Preparar les Dades
from torchvision import datasets, transforms # Definir les transformacions per a les dades d'entrenament i validació transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # Carregar les dades d'entrenament i validació train_dataset = datasets.ImageFolder(root='path/to/train_data', transform=transform) val_dataset = datasets.ImageFolder(root='path/to/val_data', transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)
4.2 Definir la Funció de Pèrdua i l'Optimitzador
criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
4.3 Entrenar el Model
num_epochs = 10 for epoch in range(num_epochs): model.train() running_loss = 0.0 for inputs, labels in train_loader: optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader)}') # Validació model.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, labels in val_loader: outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f'Validation Accuracy: {100 * correct / total}%')
- Resum
En aquesta secció, hem après com ajustar finament una CNN preentrenada per a una nova tasca. Hem cobert els passos per carregar un model preentrenat, modificar la seva arquitectura, preparar les dades i entrenar el model. L'ajust fi és una tècnica poderosa que pot millorar significativament el rendiment dels models amb menys dades i menys temps d'entrenament.
En el següent mòdul, explorarem les Xarxes Neuronals Recurrents (RNNs) i com utilitzar-les per a tasques seqüencials.
PyTorch: De Principiant a Avançat
Mòdul 1: Introducció a PyTorch
- Què és PyTorch?
- Configuració de l'Entorn
- Operacions Bàsiques amb Tensor
- Autograd: Diferenciació Automàtica
Mòdul 2: Construcció de Xarxes Neuronals
- Introducció a les Xarxes Neuronals
- Creació d'una Xarxa Neuronal Simple
- Funcions d'Activació
- Funcions de Pèrdua i Optimització
Mòdul 3: Entrenament de Xarxes Neuronals
- Càrrega i Preprocessament de Dades
- Bucle d'Entrenament
- Validació i Prova
- Desament i Càrrega de Models
Mòdul 4: Xarxes Neuronals Convolucionals (CNNs)
- Introducció a les CNNs
- Construcció d'una CNN des de Zero
- Aprenentatge per Transferència amb Models Preentrenats
- Ajust Fi de les CNNs
Mòdul 5: Xarxes Neuronals Recurrents (RNNs)
- Introducció a les RNNs
- Construcció d'una RNN des de Zero
- Xarxes de Memòria a Llarg i Curt Termini (LSTM)
- Unitats Recurrents Gated (GRUs)
Mòdul 6: Temes Avançats
- Xarxes Generatives Adversàries (GANs)
- Aprenentatge per Reforç amb PyTorch
- Desplegament de Models PyTorch
- Optimització del Rendiment