En aquest tema, aprendrem com desar i carregar models en PyTorch. Aquesta habilitat és essencial per a la reutilització de models entrenats, la compartició de models amb altres persones i la implementació de models en producció.

  1. Introducció

Desar i carregar models en PyTorch és un procés senzill però crucial. PyTorch proporciona dues maneres principals de desar models:

  1. Desar només els pesos del model (state_dict).
  2. Desar tot el model (estructura + pesos).

  1. Desament de Models

2.1 Desar només els pesos del model

Desar només els pesos del model és la manera més recomanada, ja que és més flexible i permet reconstruir el model amb diferents configuracions.

import torch
import torch.nn as nn

# Definim una xarxa neuronal simple
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Creem una instància del model
model = SimpleNN()

# Desem els pesos del model
torch.save(model.state_dict(), 'model_weights.pth')

2.2 Desar tot el model

Desar tot el model inclou tant l'estructura com els pesos. Això pot ser útil per a la reproducció exacta del model, però és menys flexible.

# Desem tot el model
torch.save(model, 'model_complete.pth')

  1. Càrrega de Models

3.1 Càrrega només dels pesos del model

Per carregar només els pesos del model, primer hem de crear una instància del model amb la mateixa estructura i després carregar els pesos.

# Creem una nova instància del model
model = SimpleNN()

# Carreguem els pesos del model
model.load_state_dict(torch.load('model_weights.pth'))

# Posem el model en mode d'avaluació
model.eval()

3.2 Càrrega de tot el model

Per carregar tot el model, simplement utilitzem la funció torch.load.

# Carreguem tot el model
model = torch.load('model_complete.pth')

# Posem el model en mode d'avaluació
model.eval()

  1. Exercicis Pràctics

Exercici 1: Desar i Carregar un Model Simple

  1. Crea una xarxa neuronal simple amb PyTorch.
  2. Entrena el model amb un conjunt de dades petit.
  3. Desa els pesos del model.
  4. Carrega els pesos del model en una nova instància.
  5. Comprova que el model carregat produeix les mateixes prediccions que el model original.

Solució

import torch
import torch.nn as nn
import torch.optim as optim

# Definim una xarxa neuronal simple
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Creem una instància del model
model = SimpleNN()

# Definim una pèrdua i un optimitzador
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Generem dades aleatòries per entrenar
inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)

# Entrenem el model
for epoch in range(100):
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

# Desa els pesos del model
torch.save(model.state_dict(), 'simple_model_weights.pth')

# Creem una nova instància del model
new_model = SimpleNN()

# Carreguem els pesos del model
new_model.load_state_dict(torch.load('simple_model_weights.pth'))

# Posem el model en mode d'avaluació
new_model.eval()

# Comprovem que el model carregat produeix les mateixes prediccions
with torch.no_grad():
    original_outputs = model(inputs)
    loaded_outputs = new_model(inputs)
    print(torch.allclose(original_outputs, loaded_outputs))

  1. Resum

En aquesta secció, hem après com desar i carregar models en PyTorch. Hem vist dues maneres principals de desar models: desar només els pesos i desar tot el model. També hem practicat com carregar models desats i hem realitzat un exercici pràctic per reforçar aquests conceptes. Aquestes habilitats són fonamentals per a la reutilització i la implementació de models en aplicacions reals.

© Copyright 2024. Tots els drets reservats