A Generative Adversarial Network (GAN) trained to synthesize 64x64 pixel-art-style Pokemon sprites. The model was trained on the huggan/pokemon dataset, with its hyperparameters tuned using Optuna.
This model uses a custom DCGAN-style architecture with several changes aimed at training stability and image quality:
| Component | Details |
|---|---|
| Discriminator | A convolutional neural network that reduces a 64x64x3 RGB image to a single 1x1 score. Spectral Normalization is applied to every Conv2d layer to bound its Lipschitz constant and stabilize training, and the layers use LeakyReLU activations. |
| Generator | Takes a latent noise vector shaped (noise_dim, 1, 1) and projects it to a 4x4 feature map with a ConvTranspose2d layer. To reduce checkerboard artifacts, each following stage doubles the resolution with nearest-neighbor Upsample followed by a 3x3 Conv2d, BatchNorm2d, and ReLU, until the output reaches 64x64. The last Conv2d produces 3 channels and uses a Tanh activation to scale pixels to [-1, 1]. |
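The card only includes generator code. As a rough illustration of the discriminator described in the table, here is a minimal PyTorch sketch. It is not the released architecture: the class name `PokemonDiscriminator`, the channel widths (`features_d`), the 4x4 stride-2 kernels, and the LeakyReLU slope of 0.2 are all assumptions borrowed from common DCGAN defaults. What it does follow from the table is spectral normalization on every `Conv2d`, LeakyReLU activations, and a 64x64x3 input reduced to a 1x1 score.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm


class PokemonDiscriminator(nn.Module):
    """Hypothetical sketch: 64x64x3 image -> 1x1 score.

    Widths, kernel sizes and the LeakyReLU slope are assumed DCGAN-style
    defaults, not values taken from the trained model.
    """

    def __init__(self, features_d=64, channels=3):
        super().__init__()

        def block(in_ch, out_ch):
            # A 4x4 conv with stride 2 halves the spatial size; spectral_norm
            # rescales the weight by its largest singular value (estimated by
            # power iteration) to bound the layer's Lipschitz constant.
            return nn.Sequential(
                spectral_norm(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.net = nn.Sequential(
            block(channels, features_d),             # 64x64 -> 32x32
            block(features_d, features_d * 2),       # 32x32 -> 16x16
            block(features_d * 2, features_d * 4),   # 16x16 -> 8x8
            block(features_d * 4, features_d * 8),   # 8x8 -> 4x4
            # A final 4x4 conv collapses the 4x4 map into one 1x1 score
            spectral_norm(nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=1, padding=0)),
        )

    def forward(self, x):
        return self.net(x)


disc = PokemonDiscriminator()
scores = disc(torch.randn(2, 3, 64, 64))  # a batch of two random "images"
print(scores.shape)  # torch.Size([2, 1, 1, 1])
```

There is no BatchNorm in this sketch. Spectral normalization alone is a common choice for hinge-loss discriminators, but the card does not state whether the trained model uses any other normalization.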
Training uses the hinge loss: the discriminator minimizes `ReLU(1.0 - D(real)) + ReLU(1.0 + D(fake))`, and the generator minimizes `-D(fake)`.

The following example downloads the trained generator weights from the Hub and samples a single sprite:

```python
import torch
from torch import nn
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

NOISE_DIM = 128  # latent size the released checkpoint was trained with


class PokemonGenerator(nn.Module):
    def __init__(self, noise_dim=100, features_g=64, channels=3):
        super().__init__()
        # Project the (noise_dim, 1, 1) latent vector to a 4x4 feature map
        self.initial_block = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, features_g * 8, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True)
        )
        # Nearest-neighbor upsampling + 3x3 convs instead of transposed convs
        # to reduce checkerboard artifacts
        self.upsample_blocks = nn.Sequential(
            # 4x4 -> 8x8
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 8, features_g * 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 4, features_g * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g * 2, features_g, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            # 32x32 -> 64x64 RGB, pixels in [-1, 1]
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(features_g, channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.initial_block(x)
        return self.upsample_blocks(out)


model = PokemonGenerator(noise_dim=NOISE_DIM)
weights_path = hf_hub_download(repo_id="VioletaR/pokemon-gan", filename="pokemon_generator.pth")
# map_location="cpu" lets the checkpoint load on machines without a GPU
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()

# The generator expects a 4D noise tensor: (batch, noise_dim, 1, 1)
eval_noise = torch.randn(1, NOISE_DIM, 1, 1)
with torch.no_grad():
    generated_img = model(eval_noise)

generated_img = (generated_img + 1) / 2.0  # Scale from [-1, 1] to [0, 1]
# Drop the batch dimension and convert CHW -> HWC for matplotlib
img_numpy = generated_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
plt.imshow(img_numpy)
plt.title("Generated Pokemon")
plt.axis('off')
plt.show()
```
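The two hinge-loss terms used for training can be written as small PyTorch functions. The discriminator term is `ReLU(1.0 - D(real)) + ReLU(1.0 + D(fake))` and the generator term is `-D(fake)`. This is a minimal sketch that assumes each term is averaged over the batch, since the card does not say how the scores are reduced. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def discriminator_hinge_loss(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    # Real scores are pushed above +1 and fake scores below -1; scores that
    # are already past their margin contribute zero (assumed mean reduction).
    return F.relu(1.0 - d_real).mean() + F.relu(1.0 + d_fake).mean()


def generator_hinge_loss(d_fake: torch.Tensor) -> torch.Tensor:
    # The generator tries to raise the discriminator's score on its samples.
    return -d_fake.mean()


d_real = torch.tensor([2.0, 0.5])   # one real score past the margin, one inside it
d_fake = torch.tensor([-3.0, 0.0])  # one fake score past the margin, one inside it
print(discriminator_hinge_loss(d_real, d_fake).item())  # 0.75
print(generator_hinge_loss(d_fake).item())              # 1.5
```

In a typical training step, the generated batch is detached (`fake.detach()`) before computing the discriminator loss, so that term only updates the discriminator. This is standard GAN practice, not something the card specifies.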