生成式人工智能实战 | 变分自编码器(Variational Auto-Encoder, VAE)
本文介绍了变分自编码器 (VAE) 的原理与实现,阐述了其相较于传统自编码器的改进。VAE通过编码器将输入数据映射为潜在空间的概率分布(均值和方差),利用重参数化技巧实现可微采样,使解码器能生成新样本。通过引入 KL 散度损失,强制潜在变量服从标准正态分布,从而确保潜空间的连续性和可解释性,支持向量运算和样本插值。本文详细介绍了 VAE 的 PyTorch 实现,并使用 CIFAR-10 数据集组
生成式人工智能实战 | 变分自编码器
0. 前言
虽然自编码器 (AutoEncoder, AE) 在重建输入数据方面表现良好,但通常在生成训练集中不存在的新样本时表现不佳。更重要的是,自编码器在输入插值方面同样表现不佳,无法生成两个输入数据点之间的中间表示。这就引出了变分自编码器 (Variational Auto-Encoder
, VAE
),变分自编码器是一种生成模型,结合了深度学习和概率图模型的优点,通过学习数据的潜在概率分布来生成新的数据样本。本节将从零开始构建和训练一个 VAE
,使用 cifar-10
数据集训练 VAE
。
1. 潜空间运算
使用变分自编码器 (Variational Auto-Encoder
, VAE
) 可以进行向量运算和输入插值。操作不同输入的编码表示(潜向量),以在解码时实现特定的结果(例如,图像中是否具有某些特征)。潜向量控制解码图像中的不同特征,如性别、图像中是否有眼镜等。例如,可以首先获得戴眼镜的男性的潜向量 (z1
)、戴眼镜的女性的潜向量 (z2
) 和不戴眼镜的女性的潜向量 (z3
)。然后,计算一个新的潜向量 z4 = z1 – z2 + z3
。由于 z1
和 z2
解码后都会出现眼镜,z1 – z2
会在结果图像中去除眼镜特征。类似地,由于 z2
和 z3
都会解码为女性面孔,z3 – z2
会去除结果图像中的女性特征。因此,如果使用训练好的 VAE
解码 z4
将得到一张没有不戴眼镜的男性图像。
2. 变分自编码器
虽然自编码器 (AutoEncoder
, AE
) 擅长重建原始图像,但它们在生成训练集中没有出现的新图像方面表现不佳。此外,自编码器通常无法将相似的输入映射到潜空间中的相邻点。因此,AE
的潜空间既不连续,也不容易解释。例如,无法通过插值两个输入数据点来生成有意义的中间表示。基于这些原因,我们将学习自编码器的改进模型,变分自编码器 (Variational Auto-Encoder
, VAE
)。
2.1 VAE 工作原理
VAE
使用深度学习构建概率模型,将输入数据映射到一个低维度的潜空间中,并通过解码器将潜空间中的分布转换回数据空间中,以生成与原始数据相似的数据。与传统的自编码器相比,VAE
更加稳定,生成样本的质量更高。VAE
的核心思想是利用概率模型来描述高维的输入数据,将输入数据采样于一个低维度的潜变量分布中,并通过解码器生成与原始数据相似的输出。具体来说,VAE
同样是由编码器和解码器组成:
- 编码器将数据 x x x 映射到一个潜在空间 z z z 中,该空间定义在低维正态分布中,即 z ∼ N ( 0 , I ) z∼N(0,I) z∼N(0,I),编码器由两个部分组成:一是将数据映射到均值和方差,即 z ∼ N ( μ , σ 2 ) z∼N(μ,σ^2) z∼N(μ,σ2);二是通过重参数化技巧,将均值和方差的采样过程分离出来,并引入随机变量 ϵ ∼ N ( 0 , I ) ϵ∼N(0,I) ϵ∼N(0,I),使得 z = μ + ϵ σ z=μ+ϵσ z=μ+ϵσ
- 解码器将潜在变量 z z z 映射回数据空间中,生成与原始数据 x x x 相似的数据 x ′ x′ x′,为了使生成的数据 x ′ x′ x′ 能够与原始数据 x x x 较高的相似度,
VAE
在损失函数中使用重构误差和正则化项,重构误差表示生成数据与原始数据之间的差异,正则化项用于约束潜在变量的分布,使其满足高斯正态分布,使得VAE
从潜空间中生成的样本质量更高
VAE
具有广泛的应用场景,如图像生成、语音、自然语言处理等领域,它能够通过有限的数据样本学习到输入数据背后的潜在规律,生成与原始数据类似的新数据,具有很强的潜数据的可解释性。
2.2 VAE 构建策略
在 VAE
中,基于预定义分布获得的随机向量生成逼真图像,而在传统自编码器中并未指定在网络中生成图像的数据分布。可以通过以下策略,实现 VAE:
- 编码器的输出包括两个向量:
- 输入图像平均值
- 输入图像标准差
- 根据以上两个向量,通过在均值和标准差之和中引入随机变量 ( ϵ ∼ N ( 0 , I ) ϵ∼N(0,I) ϵ∼N(0,I)) 获取随机向量 ( z = μ + ϵ σ z=μ+ϵσ z=μ+ϵσ)
- 将上一步得到的随机向量作为输入传递给解码器以重构图像
- 损失函数是均方误差和 KL 散度损失的组合:
KL
散度损失衡量由均值向量 μ \mu μ 和标准差向量 σ \sigma σ 构建的分布与 N ( 0 , I ) N(0,I) N(0,I) 分布的偏差- 均方损失用于优化重建(解码)图像
通过训练网络,指定输入数据满足由均值向量 μ \mu μ 和标准差向量 σ \sigma σ 构建的 N ( 0 , 1 ) N(0,1) N(0,1) 分布,当我们生成均值为 0
且标准差为 1
的随机噪声时,解码器将能够生成逼真的图像。
需要注意的是,如果只最小化 KL
散度,编码器将预测均值向量为 0
,标准差为 1
。因此,需要同时最小化 KL
散度损失和均方损失。在下一节中,让我们介绍 KL
散度,以便将其纳入模型的损失值计算中。
2.3 KL 散度
KL
散度(也称相对熵)可以用于衡量两个概率分布之间的差异:
K L ( P ∣ ∣ Q ) = ∑ x ∈ X P ( x ) l n ( P ( i ) Q ( i ) ) KL(P||Q) = \sum_{x∈X} P(x) ln(\frac {P(i)}{Q(i)}) KL(P∣∣Q)=x∈X∑P(x)ln(Q(i)P(i))
其中, P P P 和 Q Q Q 为两个概率分布,KL
散度的值越小,两个分布的相似性就越高,当且仅当 P P P 和 Q Q Q 两个概率分布完全相同时,KL
散度等于 0
。在 VAE
中,我们希望瓶颈特征值遵循平均值为 0
和标准差为 1
的正态分布。因此,我们可以使用 KL
散度衡量变分自编码器中编码器输出的分布与标准高斯分布 N ( 0 , 1 ) N(0,1) N(0,1) 之间的差异。
可以通过以下公式计算 KL
散度损失:
∑ i = 1 n σ i 2 + μ i 2 − l o g ∗ ( σ i ) − 1 \sum_{i=1}^n\sigma_i^2+\mu_i^2-log*(\sigma_i)-1 i=1∑nσi2+μi2−log∗(σi)−1
在上式中, σ σ σ 和 μ μ μ 表示每个输入图像的均值和标准差值:
- 确保均值向量分布在
0
附近:- 最小化上式中的均方误差 ( μ i 2 \mu_i^2 μi2) 可确保 μ \mu μ 尽可能接近
0
- 最小化上式中的均方误差 ( μ i 2 \mu_i^2 μi2) 可确保 μ \mu μ 尽可能接近
- 确保标准差向量分布在
1
附近:- 上式中其余部分(除了 μ i 2 \mu_i^2 μi2 )用于确保标准差 ( s i g m a sigma sigma) 分布在
1
附近
- 上式中其余部分(除了 μ i 2 \mu_i^2 μi2 )用于确保标准差 ( s i g m a sigma sigma) 分布在
当均值 ( μ μ μ) 为 0
且标准差为 1
时,以上损失函数值达到最小,通过引入标准差的对数,确保 σ \sigma σ 值不为负。通过最小化以上损失可以确保编码器输出遵循预定义分布。
2.4 重参数化技巧
下图左侧显示了 VAE
网络。编码器获取输入 x x x,并估计潜矢量 z z z 的多元高斯分布的均值 μ μ μ 和标准差 σ σ σ,解码器从潜矢量 z z z 采样,以将输入重构为 x x x:
但是反向传播梯度不会通过随机采样块。虽然可以为神经网络提供随机输入,但梯度不可能穿过随机层。解决此问题的方法是将“采样”过程作为输入,如图右侧所示。 采样计算为:
S a m p l e = μ + ε σ Sample=\mu + εσ Sample=μ+εσ
如果 ε ε ε 和 σ σ σ 以矢量形式表示,则 ε σ εσ εσ 是逐元素乘法,使用上式,令采样好像直接来自于潜空间。 这种技术被称为重参数化技巧 (Reparameterization trick
)。
3. 实现 VAE
在本节中,使用 PyTorch
实现 VAE
模型生成 cifar-10
图像。
3.1 数据加载
(1) 首先导入所需的库:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
(2) 定义数据预处理转换:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 将像素值归一化到[-1,1]
])
(3) 加载 CIFAR-10
训练集和测试集:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
(4) 创建数据加载器:
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
3.2 模型构建
(1) 定义 VAE
模型,由编码器和解码器构成:
class VAE(nn.Module):
def __init__(self, latent_dim=128):
super(VAE, self).__init__()
self.latent_dim = latent_dim
# 编码器
self.encoder = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), # 32x16x16
nn.ReLU(),
nn.BatchNorm2d(32),
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), # 64x8x8
nn.ReLU(),
nn.BatchNorm2d(64),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 128x4x4
nn.ReLU(),
nn.BatchNorm2d(128),
nn.Flatten(), # 128*4*4=2048
nn.Linear(2048, 1024),
nn.ReLU()
)
# 潜在空间的均值和对数方差
self.fc_mu = nn.Linear(1024, latent_dim)
self.fc_logvar = nn.Linear(1024, latent_dim)
# 解码器
self.decoder_input = nn.Linear(latent_dim, 1024)
self.decoder = nn.Sequential(
nn.Linear(1024, 2048),
nn.ReLU(),
nn.Unflatten(1, (128, 4, 4)), # 128x4x4
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # 64x8x8
nn.ReLU(),
nn.BatchNorm2d(64),
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # 32x16x16
nn.ReLU(),
nn.BatchNorm2d(32),
nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1), # 3x32x32
nn.Tanh() # 输出在[-1,1]之间,与输入归一化一致
)
def encode(self, x):
"""
编码输入图像x,返回潜在空间的均值和方差
"""
h = self.encoder(x)
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
def reparameterize(self, mu, logvar):
"""
重参数化技巧,从N(mu, var)采样
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
"""
从潜在变量z解码重构图像
"""
h = self.decoder_input(z)
x_recon = self.decoder(h)
return x_recon
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
x_recon = self.decode(z)
return x_recon, mu, logvar
(2) 定义损失函数,由重建损失和 KL 散度组成:
def vae_loss(recon_x, x, mu, logvar):
"""
VAE损失函数 = 重构损失 + KL散度
"""
# 重构损失
recon_loss = F.mse_loss(recon_x, x, reduction='sum')
# KL散度:-0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_loss
3.3 模型训练
(1) 定义模型训练和测试函数:
def train(model, device, train_loader, optimizer, epoch):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
# 前向传播
recon_batch, mu, logvar = model(data)
# 计算损失
loss = vae_loss(recon_batch, data, mu, logvar)
# 反向传播和优化
loss.backward()
train_loss += loss.item()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')
avg_loss = train_loss / len(train_loader.dataset)
print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')
return avg_loss
def test(model, device, test_loader):
model.eval()
test_loss = 0
with torch.no_grad():
for data, _ in test_loader:
data = data.to(device)
recon_batch, mu, logvar = model(data)
test_loss += vae_loss(recon_batch, data, mu, logvar).item()
test_loss /= len(test_loader.dataset)
print(f'====> Test set loss: {test_loss:.4f}')
return test_loss
(2) 定义可视化函数,用于可视化原始图像和重构图像:
def visualize_reconstruction(model, device, test_loader, num_images=8):
model.eval()
with torch.no_grad():
# 获取一批测试图像
data, _ = next(iter(test_loader))
data = data[:num_images].to(device)
# 重构图像
recon_data, _, _ = model(data)
# 将图像从[-1,1]转换回[0,1]以便显示
data = data.cpu().numpy().transpose(0, 2, 3, 1)
data = (data + 1) / 2 # 从[-1,1]到[0,1]
recon_data = recon_data.cpu().numpy().transpose(0, 2, 3, 1)
recon_data = (recon_data + 1) / 2 # 从[-1,1]到[0,1]
# 绘制图像
fig, axes = plt.subplots(2, num_images, figsize=(num_images * 2, 4))
for i in range(num_images):
axes[0, i].imshow(data[i])
axes[0, i].axis('off')
axes[1, i].imshow(recon_data[i])
axes[1, i].axis('off')
axes[0, 0].set_ylabel('Original')
axes[1, 0].set_ylabel('Reconstructed')
plt.show()
(3) 定义 generate_samples()
,从潜空间随机采样生成新图像:
def generate_samples(model, device, latent_dim, num_samples=16):
model.eval()
with torch.no_grad():
# 从标准正态分布采样
z = torch.randn(num_samples, latent_dim).to(device)
# 生成样本
samples = model.decode(z).cpu()
samples = samples.numpy().transpose(0, 2, 3, 1)
samples = (samples + 1) / 2 # 从[-1,1]到[0,1]
# 绘制生成的样本
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
ax.imshow(samples[i])
ax.axis('off')
plt.show()
(4) 训练模型 50
个 epoch
,训练完成后,可视化模型生成效果,并绘制训练和测试损失变化曲线:
def main():
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 初始化模型
latent_dim = 128
model = VAE(latent_dim=latent_dim).to(device)
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# 训练参数
epochs = 50
train_losses = []
test_losses = []
# 训练循环
for epoch in range(1, epochs + 1):
train_loss = train(model, device, train_loader, optimizer, epoch)
test_loss = test(model, device, test_loader)
train_losses.append(train_loss)
test_losses.append(test_loss)
# 每5个epoch可视化一次
if epoch % 5 == 0:
visualize_reconstruction(model, device, test_loader)
# 训练完成后可视化
generate_samples(model, device, latent_dim)
# 绘制训练和测试损失曲线
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Test Loss')
plt.show()
main()
重建效果:
生成结果:
更多推荐
所有评论(0)