[AI] Variational Autoencoder Implementation
Simple Implementation for Variational Autoencoder
[AI] Variational Autoencoder Implementation
1. Variational Autoencoder
1) Import Libraries
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# AE.ipynb
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
2) CUDA Setting
1
2
3
4
5
6
# CUDA setting: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Report the name of GPU 0 so the run log shows which card was used.
    print(f"CUDA GPU : {torch.cuda.get_device_name(0)}")
1
CUDA GPU : NVIDIA GeForce RTX 4070 Ti SUPER
3) Dataset Download & Parse
1
2
3
4
5
6
7
8
9
# Dataset download (MNIST): images are converted to tensors by ToTensor().
BATCH_SIZE = 512
to_tensor = transforms.ToTensor()
train_data = datasets.MNIST('./data', train=True, download=True, transform=to_tensor)
test_data = datasets.MNIST('./data', train=False, download=True, transform=to_tensor)

# Both loaders share the same batching settings (shuffled, 8 worker processes).
loader_options = dict(batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
train_loader = torch.utils.data.DataLoader(dataset=train_data, **loader_options)
test_loader = torch.utils.data.DataLoader(dataset=test_data, **loader_options)
4) Model Architecture & Structure
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Model architecture
class VAE(nn.Module):
    """Convolutional variational autoencoder for 1x28x28 images.

    The encoder reduces an image to a 64-value feature vector, two linear
    heads turn that into the mean and log-variance of the latent Gaussian,
    and the decoder maps a latent sample back to a 1x28x28 image in [0, 1].

    Args:
        latent_dimension: Size of the latent vector (default 16).
    """

    def __init__(self, latent_dimension = 16):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1),  # 1x28x28 -> 16x14x14
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),  # 16x14x14 -> 32x7x7
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7),  # 32x7x7 -> 64x1x1
            nn.ReLU(),
        )
        # Two independent linear heads on the 64 encoder features.
        self.latent_mean = nn.Sequential(
            nn.Linear(in_features=64, out_features=latent_dimension, bias=True),
        )
        self.latent_logvar = nn.Sequential(
            nn.Linear(in_features=64, out_features=latent_dimension, bias=True),
        )
        # Projects a latent vector back to the 64 features the decoder expects.
        self.latent_dec = nn.Sequential(
            nn.Linear(in_features=latent_dimension, out_features=64),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7),  # 64x1x1 -> 32x7x7
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1),  # 32x7x7 -> 16x14x14
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),  # 16x14x14 -> 1x28x28
            nn.Sigmoid()
        )

    def reparameterize(self, mean, logvar):
        """Sample ``mean + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created with ``randn_like`` so it always has the shape,
        dtype and device of the inputs; no module-level ``device`` is needed.
        """
        std = torch.exp(logvar * 0.5)
        eps = torch.randn_like(std)
        return mean + std * eps

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(recon_x, latent_vector, means, logvars)``.
        """
        x = self.encoder(x)
        x = x.view(-1, 64)  # flatten 64x1x1 feature maps to rows of 64
        means = self.latent_mean(x)  # BATCH, latent_dimension
        logvars = self.latent_logvar(x)
        latent_vector = self.reparameterize(means, logvars)
        recon_x = self.latent_dec(latent_vector)
        recon_x = recon_x.view(-1, 64, 1, 1)  # back to 64x1x1 maps for the decoder
        recon_x = self.decoder(recon_x)
        return recon_x, latent_vector, means, logvars
5) Training defined model (dimension = 16)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Model Training
from IPython import display

# Hyper-parameters for this run.
epoch = 300
learning_rate = 1e-3

# Model with the default latent size, its Adam optimizer, and the list that
# collects the loss values to plot.
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
losses = []
def loss_fn(x, pred, mu, logvar):
    """Return the VAE loss: summed reconstruction error plus summed KL term.

    Args:
        x: Target images.
        pred: Reconstructed images, same shape as ``x``.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor ``recon_loss + kld``.
    """
    recon_loss = F.mse_loss(pred, x, reduction='sum')
    # The KL term is summed, not averaged, so it is on the same scale as the
    # summed reconstruction error. Averaging would shrink it by
    # 1 / (batch * latent_dim) and make the regulariser almost inactive.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld
# Train for `epoch` epochs; after each epoch redraw the loss curve in place.
for epoch_idx in trange(epoch):
    model.train()
    running_loss = 0.0
    total_batches = 0
    for x_train, _ in train_loader:
        x_train = x_train.to(device)
        output, _, mean, logvar = model(x_train)
        loss = loss_fn(x_train, output, mean, logvar)
        running_loss += loss.item()
        total_batches += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Mean of the per-batch losses for this epoch.
    losses.append(running_loss / total_batches)

    # Plot at most the 30 most recent epochs so the curve stays readable.
    last_epoch = epoch_idx + 1
    window = min(30, last_epoch)
    shown_epochs = list(range(last_epoch - window + 1, last_epoch + 1))
    shown_losses = losses[last_epoch - window:last_epoch]

    display.clear_output(wait=True)
    plt.figure(figsize=(8, 5))
    plt.plot(shown_epochs, shown_losses, linestyle='--', linewidth=2, c='r')
    plt.scatter(shown_epochs, shown_losses, c='red', s=40)
    # One point per epoch, so the axis is labelled in epochs (was 'Batch number').
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.show()
1
100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [08:51<00:00, 1.77s/it]
6) Visualization of my model
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import matplotlib.pyplot as plt
import torch

# The checkpoint is a fully pickled module (written with torch.save(model, ...)),
# so it must be loaded with weights_only=False. PyTorch >= 2.6 defaults to
# weights_only=True and rejects such files.
# NOTE: unpickling can execute arbitrary code; only load checkpoints you created.
model = torch.load('./VAE_MNIST.pt', map_location=device, weights_only=False)
model.eval()

start_idx = 200
num_pairs = 100
pairs_per_row = 10
# Each pair takes two grid cells: original on the left, reconstruction on the right.
n_cols = 2 * pairs_per_row
n_rows = -(-num_pairs // pairs_per_row)  # ceiling division

plt.figure(figsize=(20, 10))
for i in range(num_pairs):
    row, col = divmod(i, pairs_per_row)
    left_cell = row * n_cols + 2 * col + 1  # subplot indices are 1-based

    # Original Image
    plt.subplot(n_rows, n_cols, left_cell)
    data = test_data[start_idx + i][0]
    plt.imshow(data.reshape(28, 28), cmap='gray')
    plt.axis('off')

    # Reconstructed Image
    plt.subplot(n_rows, n_cols, left_cell + 1)
    with torch.no_grad():
        reconstructed_images, _, _, _ = model(data.unsqueeze(0).to(device))
    reconstructed_images = reconstructed_images.cpu()
    plt.imshow(reconstructed_images.squeeze().reshape(28, 28), cmap='gray')
    plt.axis('off')

plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()
7) Saving Model Weights
1
2
# Report the most recently recorded loss, then pickle the whole module to disk.
final_loss = losses[-1]
print(f"Result loss : {final_loss}")
torch.save(model, "./VAE_MNIST.pt")
1
Result loss : 3218.4372475834216
2. 2D Modeling (decomposition)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.patches as mpatches

# The checkpoint is a fully pickled module, so weights_only=False is required
# (PyTorch >= 2.6 defaults to weights_only=True and rejects it).
# NOTE: unpickling can execute arbitrary code; only load checkpoints you created.
model = torch.load("./VAE_MNIST.pt", map_location=device, weights_only=False)
model.eval()

latent_vectors = []
label = []
colors = ['#fe7c73', '#2471A3', '#3498DB', '#27AE60', '#82E0AA', '#D35400', '#5D6D7E', '#E74C3C', '#21618C', '#B7950B']

# Collect the latent mean (third model output) of every test sample.
with torch.no_grad():
    for X, y in tqdm(test_data):
        _, _, mean, _ = model(X.to(device))
        # reshape(-1) flattens whatever latent size the loaded model has.
        latent_vectors.append(mean.reshape(-1).cpu().numpy())
        label.append(y)

color_labeled = [colors[y] for y in label]
all_latent_vectors = np.vstack(latent_vectors)

# decomposition_model = PCA(n_components=2) # Not that good performance..
decomposition_model = TSNE(n_components=2)
latent_2d = decomposition_model.fit_transform(all_latent_vectors)

patches = [mpatches.Patch(color=color, label=f'{i}') for i, color in enumerate(colors)]
plt.scatter(latent_2d[:, 0], latent_2d[:, 1], s=1, c = color_labeled)
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title('2D decomposition of Latent Vectors')
plt.show()
1
100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:07<00:00, 1302.40it/s]
3. Can we use a VAE as a generative model?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# The checkpoint is a fully pickled module, so weights_only=False is required
# (PyTorch >= 2.6 defaults to weights_only=True and rejects it).
# NOTE: unpickling can execute arbitrary code; only load checkpoints you created.
model = torch.load("./VAE_MNIST.pt", map_location=device, weights_only=False)
model.eval()

# Read the latent size from the model's mean head instead of hard-coding 16.
latent_dim = model.latent_mean[0].out_features
num_labels = 10

# Running sums of latent means / log-variances per digit label, plus counts.
label_latent_means = [torch.zeros(latent_dim, device=device) for _ in range(num_labels)]
label_latent_logvars = [torch.zeros(latent_dim, device=device) for _ in range(num_labels)]
label_num = [0 for _ in range(num_labels)]

with torch.no_grad():
    for X, y in tqdm(test_data):
        _, _, mean, logvar = model(X.to(device))
        label_latent_means[y] += mean.view(latent_dim)
        label_latent_logvars[y] += logvar.view(latent_dim)
        label_num[y] += 1

    # One latent sample per label, drawn from the label's averaged mean / log-variance.
    label_latent_vectors = [
        model.reparameterize(mean_sum / count, logvar_sum / count)
        for mean_sum, logvar_sum, count in zip(label_latent_means, label_latent_logvars, label_num)
    ]

    # Decode each per-label latent sample back to an image.
    label_mean_img = []
    for latent in label_latent_vectors:
        recon_x = model.latent_dec(latent)
        recon_x = recon_x.view(-1, 64, 1, 1)  # batch of one for the decoder
        label_mean_img.append(model.decoder(recon_x).cpu())

fig, axes = plt.subplots(1, num_labels, figsize=(15, 2))
for i, image in enumerate(label_mean_img):
    axes[i].imshow(image.view(28, 28), cmap='gray')
    axes[i].axis('off')
    axes[i].set_title(str(i))
plt.show()
1
100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:06<00:00, 1442.23it/s]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
assert len(all_latent_vectors) == 10000

# Stack the per-label latent samples into one (n_labels, latent_dim) array.
# torch.stack + numpy() avoids building an array element by element from a
# list of tensors.
reshaped_label_latent_vectors = (
    torch.stack([llv.detach().reshape(-1) for llv in label_latent_vectors]).cpu().numpy()
)
n_label_points = len(reshaped_label_latent_vectors)

# The label samples go last so they can be split off again after the embedding.
new_latent_vectors = np.vstack([all_latent_vectors, reshaped_label_latent_vectors])
decomposition_model = TSNE(n_components=2)
latent_2d = decomposition_model.fit_transform(new_latent_vectors)

patches = [mpatches.Patch(color=color, label=f'{i}') for i, color in enumerate(colors)]
plt.scatter(latent_2d[:-n_label_points, 0], latent_2d[:-n_label_points, 1], s=1, c = color_labeled)
# Per-label samples are drawn larger and in black on top of the test points.
plt.scatter(latent_2d[-n_label_points:, 0], latent_2d[-n_label_points:, 1], s = 10, c = 'black')
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title('2D decomposition of Latent Vectors')
plt.show()
4. And how about…?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# The checkpoint is a fully pickled module, so weights_only=False is required
# (PyTorch >= 2.6 defaults to weights_only=True and rejects it).
# NOTE: unpickling can execute arbitrary code; only load checkpoints you created.
model = torch.load('./VAE_MNIST.pt', map_location=device, weights_only=False)
model.eval()

# can define any label
sample1 = label_latent_vectors[2]
sample2 = label_latent_vectors[4]
sample3 = label_latent_vectors[5]
sample4 = label_latent_vectors[6]

# linspace gives exactly `steps` evenly spaced weights including both ends;
# arange with a float step depends on rounding for its length.
steps = 11
alphas = np.linspace(0.0, 1.0, steps).tolist()  # plain Python floats

fig, axes = plt.subplots(steps, steps, figsize=(10, 10))
recon_imgs = list()
with torch.no_grad():
    for i, alpha1 in enumerate(alphas):
        for j, alpha2 in enumerate(alphas):
            # Bilinear blend of the four corner latents.
            new_vector = (alpha1 * sample1 + (1 - alpha1) * sample2) * alpha2 + (1-alpha2) * (alpha1 * sample3 + (1-alpha1) * sample4)
            recon_x = model.latent_dec(new_vector)
            recon_x = recon_x.view(-1, 64, 1, 1)  # batch of one for the decoder
            reconstructed_img = model.decoder(recon_x).cpu()
            recon_imgs.append(reconstructed_img)
            # Row follows alpha1, column follows alpha2.
            axes[i, j].imshow(reconstructed_img.view(28,28), cmap='gray')
            axes[i, j].axis('off')

plt.tight_layout()
plt.show()
AMAZING, ISN'T IT?
5. What if the latent dimension is 2D?
1) Model Training
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Model Training
from IPython import display

# Hyper-parameters for the 2-dimensional latent run.
epoch = 300
learning_rate = 3e-4

# Fresh model with a 2-d latent space, its Adam optimizer, and the list that
# collects the loss values to plot.
model = VAE(latent_dimension=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
losses = []
def loss_fn(x, pred, mu, logvar):
    """Return the VAE loss: summed reconstruction error plus summed KL term.

    Args:
        x: Target images.
        pred: Reconstructed images, same shape as ``x``.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor ``recon_loss + kld``.
    """
    recon_loss = F.mse_loss(pred, x, reduction='sum')
    # The KL term is summed, not averaged, so it is on the same scale as the
    # summed reconstruction error. Averaging would shrink it by
    # 1 / (batch * latent_dim) and make the regulariser almost inactive.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld
# Train for `epoch` epochs; after each epoch redraw the loss curve in place.
for epoch_idx in trange(epoch):
    model.train()
    running_loss = 0.0
    total_batches = 0
    for x_train, _ in train_loader:
        x_train = x_train.to(device)
        output, _, mean, logvar = model(x_train)
        loss = loss_fn(x_train, output, mean, logvar)
        running_loss += loss.item()
        total_batches += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Mean of the per-batch losses for this epoch.
    losses.append(running_loss / total_batches)

    # Plot at most the 30 most recent epochs so the curve stays readable.
    last_epoch = epoch_idx + 1
    window = min(30, last_epoch)
    shown_epochs = list(range(last_epoch - window + 1, last_epoch + 1))
    shown_losses = losses[last_epoch - window:last_epoch]

    display.clear_output(wait=True)
    plt.figure(figsize=(8, 5))
    plt.plot(shown_epochs, shown_losses, linestyle='--', linewidth=2, c='r')
    plt.scatter(shown_epochs, shown_losses, c='red', s=40)
    # One point per epoch, so the axis is labelled in epochs (was 'Batch number').
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.show()
1
100%|█████████████████████████████████████████████████████████████████████████████████| 300/300 [09:08<00:00, 1.83s/it]
2) Visualization
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import matplotlib.pyplot as plt
import torch

start_idx = 200
num_pairs = 100
pairs_per_row = 10
# 10 x 20 grid: every pair takes two neighbouring cells in a row.
grid_rows, grid_cols = 10, 20

plt.figure(figsize=(20, 10))
for i in range(num_pairs):
    row, col = divmod(i, pairs_per_row)
    left_cell = 20 * row + 2 * col + 1  # 1-based subplot index of the original

    # Left cell: the original test image.
    data = test_data[start_idx + i][0]
    plt.subplot(grid_rows, grid_cols, left_cell)
    plt.imshow(data.reshape(28, 28), cmap='gray')
    plt.axis('off')

    # Right cell: the model's reconstruction of the same image.
    plt.subplot(grid_rows, grid_cols, left_cell + 1)
    with torch.no_grad():
        outputs = model(data.unsqueeze(0).to(device))
    reconstructed_images = outputs[0].cpu()
    plt.imshow(reconstructed_images.squeeze().reshape(28, 28), cmap='gray')
    plt.axis('off')

plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()

# Pickle the whole 2-d-latent module for the later cells that load it.
torch.save(model, './VAE_MNIST_2d.pt')
3) Visualization in 2d latent space
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.patches as mpatches

# The checkpoint is a fully pickled module, so weights_only=False is required
# (PyTorch >= 2.6 defaults to weights_only=True and rejects it).
# NOTE: unpickling can execute arbitrary code; only load checkpoints you created.
model = torch.load("./VAE_MNIST_2d.pt", map_location=device, weights_only=False)
model.eval()

latent_vectors = []
label = []
colors = ['#fe7c73', '#2471A3', '#3498DB', '#27AE60', '#82E0AA', '#D35400', '#5D6D7E', '#E74C3C', '#21618C', '#B7950B']

# Collect the latent mean (third model output) of every test sample.
with torch.no_grad():
    for X, y in tqdm(test_data):
        _, _, latent_x, _ = model(X.to(device))
        latent_vectors.append(latent_x.reshape(-1).cpu().numpy())
        label.append(y)

color_labeled = [colors[y] for y in label]
all_latent_vectors = np.vstack(latent_vectors)

# The latent space is already 2-d, so it is plotted directly without t-SNE.
plt.scatter(all_latent_vectors[:, 0], all_latent_vectors[:, 1], s=1, c = color_labeled)
patches = [mpatches.Patch(color=color, label=f'{i}') for i, color in enumerate(colors)]
# The original had a stray trailing backslash here, which joined this call
# with the next line and caused a SyntaxError.
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title('2D decomposition of Latent Vectors')
plt.show()
1
100%|████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:24<00:00, 408.61it/s]
The VAE's results are not that good with a 2D latent space, but when the latent dimension is large enough, it outperforms the plain AE.
Moreover, because the VAE generates images through its decoder, we can create new pictures for each of the n labels by feeding latent vectors to the decoder.
This post is licensed under CC BY 4.0 by the author.