0. 前言

虽然自编码器 (AutoEncoder, AE) 在重建输入数据方面表现良好,但通常在生成训练集中不存在的新样本时表现不佳。更重要的是,自编码器在输入插值方面同样表现不佳,无法生成两个输入数据点之间的中间表示。这就引出了变分自编码器 (Variational Auto-Encoder, VAE),变分自编码器是一种生成模型,结合了深度学习和概率图模型的优点,通过学习数据的潜在概率分布来生成新的数据样本。本节将从零开始构建和训练一个 VAE,使用 cifar-10 数据集训练 VAE

1. 潜空间运算

使用变分自编码器 (Variational Auto-Encoder, VAE) 可以进行向量运算和输入插值。操作不同输入的编码表示(潜向量),以在解码时实现特定的结果(例如,图像中是否具有某些特征)。潜向量控制解码图像中的不同特征,如性别、图像中是否有眼镜等。例如,可以首先获得戴眼镜的男性的潜向量 (z1)、戴眼镜的女性的潜向量 (z2) 和不戴眼镜的女性的潜向量 (z3)。然后,计算一个新的潜向量 z4 = z1 – z2 + z3。由于 z1z2 解码后都会出现眼镜,z1 – z2 会在结果图像中去除眼镜特征。类似地,由于 z2z3 都会解码为女性面孔,z3 – z2 会去除结果图像中的女性特征。因此,如果使用训练好的 VAE 解码 z4 将得到一张没有不戴眼镜的男性图像。

2. 变分自编码器

虽然自编码器 (AutoEncoder, AE) 擅长重建原始图像,但它们在生成训练集中没有出现的新图像方面表现不佳。此外,自编码器通常无法将相似的输入映射到潜空间中的相邻点。因此,AE 的潜空间既不连续,也不容易解释。例如,无法通过插值两个输入数据点来生成有意义的中间表示。基于这些原因,我们将学习自编码器的改进模型,变分自编码器 (Variational Auto-Encoder, VAE)。

2.1 VAE 工作原理

VAE 使用深度学习构建概率模型,将输入数据映射到一个低维度的潜空间中,并通过解码器将潜空间中的分布转换回数据空间中,以生成与原始数据相似的数据。与传统的自编码器相比,VAE 更加稳定,生成样本的质量更高。
VAE 的核心思想是利用概率模型来描述高维的输入数据,将输入数据采样于一个低维度的潜变量分布中,并通过解码器生成与原始数据相似的输出。具体来说,VAE 同样是由编码器和解码器组成:

  • 编码器将数据 x x x 映射到一个潜在空间 z z z 中,该空间定义在低维正态分布中,即 z ∼ N ( 0 , I ) z∼N(0,I) zN(0,I),编码器由两个部分组成:一是将数据映射到均值和方差,即 z ∼ N ( μ , σ 2 ) z∼N(μ,σ^2) zN(μ,σ2);二是通过重参数化技巧,将均值和方差的采样过程分离出来,并引入随机变量 ϵ ∼ N ( 0 , I ) ϵ∼N(0,I) ϵN(0,I),使得 z = μ + ϵ σ z=μ+ϵσ z=μ+ϵσ
  • 解码器将潜在变量 z z z 映射回数据空间中,生成与原始数据 x x x 相似的数据 x ′ x′ x,为了使生成的数据 x ′ x′ x 能够与原始数据 x x x 较高的相似度,VAE 在损失函数中使用重构误差和正则化项,重构误差表示生成数据与原始数据之间的差异,正则化项用于约束潜在变量的分布,使其满足高斯正态分布,使得 VAE 从潜空间中生成的样本质量更高

VAE 具有广泛的应用场景,如图像生成、语音、自然语言处理等领域,它能够通过有限的数据样本学习到输入数据背后的潜在规律,生成与原始数据类似的新数据,具有很强的潜数据的可解释性。

2.2 VAE 构建策略

VAE 中,基于预定义分布获得的随机向量生成逼真图像,而在传统自编码器中并未指定在网络中生成图像的数据分布。可以通过以下策略,实现 VAE:

  1. 编码器的输出包括两个向量:
    • 输入图像平均值
    • 输入图像标准差
  2. 根据以上两个向量,通过在均值和标准差之和中引入随机变量 ( ϵ ∼ N ( 0 , I ) ϵ∼N(0,I) ϵN(0,I)) 获取随机向量 ( z = μ + ϵ σ z=μ+ϵσ z=μ+ϵσ)
  3. 将上一步得到的随机向量作为输入传递给解码器以重构图像
  4. 损失函数是均方误差和 KL 散度损失的组合:
    • KL 散度损失衡量由均值向量 μ \mu μ 和标准差向量 σ \sigma σ 构建的分布与 N ( 0 , I ) N(0,I) N(0,I) 分布的偏差
    • 均方损失用于优化重建(解码)图像

通过训练网络,指定输入数据满足由均值向量 μ \mu μ 和标准差向量 σ \sigma σ 构建的 N ( 0 , 1 ) N(0,1) N(0,1) 分布,当我们生成均值为 0 且标准差为 1 的随机噪声时,解码器将能够生成逼真的图像。
需要注意的是,如果只最小化 KL 散度,编码器将预测均值向量为 0,标准差为 1。因此,需要同时最小化 KL 散度损失和均方损失。在下一节中,让我们介绍 KL 散度,以便将其纳入模型的损失值计算中。

2.3 KL 散度

KL 散度(也称相对熵)可以用于衡量两个概率分布之间的差异:

K L ( P ∣ ∣ Q ) = ∑ x ∈ X P ( x ) l n ( P ( i ) Q ( i ) ) KL(P||Q) = \sum_{x∈X} P(x) ln(\frac {P(i)}{Q(i)}) KL(P∣∣Q)=xXP(x)ln(Q(i)P(i))

其中, P P P Q Q Q 为两个概率分布,KL 散度的值越小,两个分布的相似性就越高,当且仅当 P P P Q Q Q 两个概率分布完全相同时,KL 散度等于 0。在 VAE 中,我们希望瓶颈特征值遵循平均值为 0 和标准差为 1 的正态分布。因此,我们可以使用 KL 散度衡量变分自编码器中编码器输出的分布与标准高斯分布 N ( 0 , 1 ) N(0,1) N(0,1) 之间的差异。
可以通过以下公式计算 KL 散度损失:

∑ i = 1 n σ i 2 + μ i 2 − l o g ∗ ( σ i ) − 1 \sum_{i=1}^n\sigma_i^2+\mu_i^2-log*(\sigma_i)-1 i=1nσi2+μi2log(σi)1

在上式中, σ σ σ μ μ μ 表示每个输入图像的均值和标准差值:

  • 确保均值向量分布在 0 附近:
    • 最小化上式中的均方误差 ( μ i 2 \mu_i^2 μi2) 可确保 μ \mu μ 尽可能接近 0
  • 确保标准差向量分布在 1 附近:
    • 上式中其余部分(除了 μ i 2 \mu_i^2 μi2 )用于确保标准差 ( s i g m a sigma sigma) 分布在 1 附近

当均值 ( μ μ μ) 为 0 且标准差为 1 时,以上损失函数值达到最小,通过引入标准差的对数,确保 σ \sigma σ 值不为负。通过最小化以上损失可以确保编码器输出遵循预定义分布。

2.4 重参数化技巧

下图左侧显示了 VAE 网络。编码器获取输入 x x x,并估计潜矢量 z z z 的多元高斯分布的均值 μ μ μ 和标准差 σ σ σ,解码器从潜矢量 z z z 采样,以将输入重构为 x x x
VAE
但是反向传播梯度不会通过随机采样块。虽然可以为神经网络提供随机输入,但梯度不可能穿过随机层。解决此问题的方法是将“采样”过程作为输入,如图右侧所示。 采样计算为:

S a m p l e = μ + ε σ Sample=\mu + εσ Sample=μ+εσ

如果 ε ε ε σ σ σ 以矢量形式表示,则 ε σ εσ εσ 是逐元素乘法,使用上式,令采样好像直接来自于潜空间。 这种技术被称为重参数化技巧 (Reparameterization trick)。

3. 实现 VAE

在本节中,使用 PyTorch 实现 VAE 模型生成 cifar-10 图像。

3.1 数据加载

(1) 首先导入所需的库:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

(2) 定义数据预处理转换:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 将像素值归一化到[-1,1]
])

(3) 加载 CIFAR-10 训练集和测试集:

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

(4) 创建数据加载器:

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

3.2 模型构建

(1) 定义 VAE 模型,由编码器和解码器构成:

class VAE(nn.Module):
    def __init__(self, latent_dim=128):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        
        # 编码器
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # 32x16x16
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 64x8x8
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 128x4x4
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Flatten(),  # 128*4*4=2048
            nn.Linear(2048, 1024),
            nn.ReLU()
        )
        
        # 潜在空间的均值和对数方差
        self.fc_mu = nn.Linear(1024, latent_dim)
        self.fc_logvar = nn.Linear(1024, latent_dim)
        
        # 解码器
        self.decoder_input = nn.Linear(latent_dim, 1024)
        
        self.decoder = nn.Sequential(
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),  # 128x4x4
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 64x8x8
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 32x16x16
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),  # 3x32x32
            nn.Tanh()  # 输出在[-1,1]之间,与输入归一化一致
        )
    
    def encode(self, x):
        """
        编码输入图像x,返回潜在空间的均值和方差
        """
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """
        重参数化技巧,从N(mu, var)采样
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        """
        从潜在变量z解码重构图像
        """
        h = self.decoder_input(z)
        x_recon = self.decoder(h)
        return x_recon
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

(2) 定义损失函数,由重建损失和 KL 散度组成:

def vae_loss(recon_x, x, mu, logvar):
    """
    VAE损失函数 = 重构损失 + KL散度
    """
    # 重构损失
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    
    # KL散度:-0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_loss + kl_loss

3.3 模型训练

(1) 定义模型训练和测试函数:

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        
        # 前向传播
        recon_batch, mu, logvar = model(data)
        
        # 计算损失
        loss = vae_loss(recon_batch, data, mu, logvar)
        
        # 反向传播和优化
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')
    
    avg_loss = train_loss / len(train_loader.dataset)
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')
    return avg_loss

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += vae_loss(recon_batch, data, mu, logvar).item()
    
    test_loss /= len(test_loader.dataset)
    print(f'====> Test set loss: {test_loss:.4f}')
    return test_loss

(2) 定义可视化函数,用于可视化原始图像和重构图像:

def visualize_reconstruction(model, device, test_loader, num_images=8):
    model.eval()
    with torch.no_grad():
        # 获取一批测试图像
        data, _ = next(iter(test_loader))
        data = data[:num_images].to(device)
        
        # 重构图像
        recon_data, _, _ = model(data)
        
        # 将图像从[-1,1]转换回[0,1]以便显示
        data = data.cpu().numpy().transpose(0, 2, 3, 1)
        data = (data + 1) / 2  # 从[-1,1]到[0,1]
        
        recon_data = recon_data.cpu().numpy().transpose(0, 2, 3, 1)
        recon_data = (recon_data + 1) / 2  # 从[-1,1]到[0,1]
        
        # 绘制图像
        fig, axes = plt.subplots(2, num_images, figsize=(num_images * 2, 4))
        for i in range(num_images):
            axes[0, i].imshow(data[i])
            axes[0, i].axis('off')
            axes[1, i].imshow(recon_data[i])
            axes[1, i].axis('off')
        axes[0, 0].set_ylabel('Original')
        axes[1, 0].set_ylabel('Reconstructed')
        plt.show()

(3) 定义 generate_samples(),从潜空间随机采样生成新图像:

def generate_samples(model, device, latent_dim, num_samples=16):
    model.eval()
    with torch.no_grad():
        # 从标准正态分布采样
        z = torch.randn(num_samples, latent_dim).to(device)
        
        # 生成样本
        samples = model.decode(z).cpu()
        samples = samples.numpy().transpose(0, 2, 3, 1)
        samples = (samples + 1) / 2  # 从[-1,1]到[0,1]
        
        # 绘制生成的样本
        fig, axes = plt.subplots(4, 4, figsize=(8, 8))
        for i, ax in enumerate(axes.flat):
            ax.imshow(samples[i])
            ax.axis('off')
        plt.show()

(4) 训练模型 50epoch,训练完成后,可视化模型生成效果,并绘制训练和测试损失变化曲线:

def main():
    # 设置设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 初始化模型
    latent_dim = 128
    model = VAE(latent_dim=latent_dim).to(device)
    
    # 定义优化器
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    
    # 训练参数
    epochs = 50
    train_losses = []
    test_losses = []
    
    # 训练循环
    for epoch in range(1, epochs + 1):
        train_loss = train(model, device, train_loader, optimizer, epoch)
        test_loss = test(model, device, test_loader)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        
        # 每5个epoch可视化一次
        if epoch % 5 == 0:
            visualize_reconstruction(model, device, test_loader)
    
    # 训练完成后可视化
    generate_samples(model, device, latent_dim)
    
    # 绘制训练和测试损失曲线
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Test Loss')
    plt.show()

main()

重建效果:

重建效果
生成结果:

生成结果

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐