import torch
from torch import nn

# Prefer the CUDA backend when a GPU is present; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")


class NeuralNetwork(nn.Module):
    """A small fully-connected classifier for flattened 28x28 inputs.

    Architecture: Flatten -> Linear(784, 512) -> ReLU -> Linear(512, 512)
    -> ReLU -> Linear(512, 10). The forward pass returns raw logits
    (no softmax applied).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        hidden = 512  # width of both hidden layers
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        """Flatten each sample in the batch and return per-class logits."""
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)


# Instantiate the network and place its parameters on the chosen device
# (the model moves here — input tensors must be moved by the caller).
model = NeuralNetwork().to(device)
print(model)

# Persist only the learned parameters (state_dict), not the whole module.
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch model state to model.pth")