目录

1.冗余性与显著性筛选

2.量化参数的“贡献度”与剪枝阈值

固定阈值法

分层阈值法

统计分布阈值法

3.稀疏模型重构

参数级重构——构建稀疏权重矩阵

通道级重构——删除冗余通道与卷积核

微调优化——修复精度损失

4.网络剪枝的MATLAB实现


        在深度学习模型性能持续提升的同时,模型参数量与计算复杂度呈指数级增长,给边缘设备部署(如手机、嵌入式芯片)带来巨大挑战。神经网络稀疏化通过去除模型中 “冗余” 的参数或计算单元,在保证精度损失可控的前提下,实现模型压缩、加速与能耗降低。网络剪枝是稀疏化最核心、最成熟的技术之一,其本质是从 “过参数化” 模型中筛选出对任务贡献显著的核心结构,剔除无效或低效的组件,构建 “精简高效” 的稀疏网络。

1.冗余性与显著性筛选

       深度学习模型的“过参数化”特性是剪枝的前提 —— 训练完成的模型中,大量参数(如卷积核权重、全连接层权重)的绝对值接近0,对模型输出的贡献微乎其微,这些参数被称为“冗余参数”。剪枝的核心逻辑可概括为两点:

冗余性识别:通过量化指标(如权重绝对值、梯度贡献、信息熵)评估每个参数 / 结构对模型性能的 “重要性”;

显著性筛选:保留重要性高于阈值的核心组件,删除冗余组件,同时通过 “微调(Fine-tuning)” 修复剪枝导致的精度损失,确保稀疏模型与原模型性能对齐。

从模型结构维度,剪枝可作用于不同粒度的单元,按“从细到粗”可分为:

参数级剪枝:直接删除单个冗余权重(如全连接层中接近0的权重);

通道级剪枝:删除卷积层中贡献微弱的特征通道(及对应卷积核);

层级剪枝:删除整个对任务无必要的网络层(如浅层冗余卷积层、过渡层);

结构级剪枝:删除更大粒度的模块(如ResNet的残差块、Transformer的注意力头)。

       剪掉神经元节点之间的不重要的连接。相当于把权重矩阵中的单个权重值设置为0,一般的,会对权重矩阵中所有的数值按照大小排序,把排在后面的一定比例的值设为0即可。

无论何种粒度,剪枝的核心目标均是在“稀疏度(Sparsity)”与“精度损失(Accuracy Drop)” 之间寻找最优平衡。稀疏度定义为“被剪枝的参数数量占原模型总参数数量的比例”,公式如下:

       例如,一个100M参数的模型若剪枝后剩余10M参数,其稀疏度为90%。工业界常用的稀疏度范围为50%~99%,极端情况下(如边缘设备)可实现99.9%稀疏度(如MobileNet系列通过通道剪枝将稀疏度控制在80%左右,精度损失<1%)。

2.量化参数的“贡献度”与剪枝阈值

        重要性度量是剪枝的“标尺”,决定了哪些参数应被保留,不同剪枝粒度对应不同的度量指标。重要性度量完成后,需设定剪枝阈值(Pruning Threshold) ,将重要性低于阈值的参数 / 结构标记为“待剪枝”。阈值的计算需结合目标稀疏度,常用方法包括:

固定阈值法

       根据目标稀疏度S,对所有参数的重要性评分排序后,取“前(1−S)%参数的最小重要性”作为阈值。例如,目标稀疏度为90%时,取重要性排名前10%的参数中最小的那个值作为阈值T,公式为:

分层阈值法

       不同网络层的参数分布差异较大(如浅层卷积层权重方差大,深层权重方差小),固定全局阈值会导致部分层过度剪枝(精度损失严重)或剪枝不足(压缩效果差)。分层阈值法为每个层单独计算阈值,公式为:

统计分布阈值法

       假设参数重要性服从某一统计分布(如正态分布、拉普拉斯分布),通过分布参数确定阈值。例如,若重要性服从正态分布N(μ,σ2),可设置阈值为μ−kσ(k为超参数,通常取1~2),公式为:

3.稀疏模型重构

       剪枝阈值确定后,需对原模型进行“结构重构”,删除冗余参数并调整网络连接,确保稀疏模型的前向传播逻辑正确。不同剪枝粒度的重构方式不同,其数学表达如下:

参数级重构——构建稀疏权重矩阵

通道级重构——删除冗余通道与卷积核

微调优化——修复精度损失

       剪枝会破坏原模型的参数分布与特征表达能力,导致精度下降。微调(Fine-tuning) 通过在训练集或验证集上继续优化稀疏模型的参数,修复精度损失,其数学本质是 “在稀疏约束下的损失最小化”。

       微调通常采用小学习率(如原训练学习率的1/10~1/100) 和少量迭代次数(如原训练轮数的 1/5) ,避免参数剧烈波动导致精度进一步下降。常用的优化器为SGD、Adam,损失函数与原模型一致(如分类任务用交叉熵损失,回归任务用MSE损失)。

4.网络剪枝的MATLAB实现

% 神经网络剪枝示例程序:基于权重绝对值的参数级剪枝
% 以LeNet-5网络在MNIST数据集上的应用为例

clear; clc; close all;

%% 1. 准备数据
% 加载MNIST数据集
[XTrain, YTrain, XTest, YTest] = loadMNISTData();

% 数据预处理:归一化到[0,1]
XTrain = double(XTrain) / 255;
XTest = double(XTest) / 255;

% 转换标签为独热编码
YTrain = onehotencode(YTrain, 1);
YTest = onehotencode(YTest, 1);

%% 2. 定义并训练原始LeNet-5网络
% 定义LeNet-5网络架构
lenet = [
    imageInputLayer([28 28 1], 'Name', 'input')
    
    convolution2dLayer(5, 20, 'Padding', 0, 'Name', 'conv1')
    reluLayer('Name', 'relu1')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool1')
    
    convolution2dLayer(5, 50, 'Padding', 0, 'Name', 'conv2')
    reluLayer('Name', 'relu2')
    maxPooling2dLayer(2, 'Stride', 2, 'Name', 'pool2')
    
    fullyConnectedLayer(500, 'Name', 'fc1')
    reluLayer('Name', 'relu3')
    
    fullyConnectedLayer(10, 'Name', 'fc2')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'classification')
];

% 训练选项
options = trainingOptions('sgdm', ...
    'MaxEpochs', 10, ...
    'InitialLearnRate', 0.01, ...
    'MiniBatchSize', 256, ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 30, ...
    'Verbose', true, ...
    'Plots', 'training-progress');

% 训练原始网络
disp('训练原始网络...');
originalNet = trainNetwork(XTrain, YTrain, lenet, options);

% 评估原始网络性能
[originalPreds, originalScores] = classify(originalNet, XTest);
originalAccuracy = mean(originalPreds == vec2ind(YTest'));
fprintf('原始网络测试准确率: %.4f\n', originalAccuracy);

% 计算原始网络参数数量
originalParams = countParameters(originalNet);
fprintf('原始网络参数数量: %d\n', originalParams);

%% 3. 网络剪枝
% 剪枝参数设置
pruningRatio = 0.5;  % 剪枝比例(保留50%的参数)
layersToPrune = {'conv1', 'conv2', 'fc1', 'fc2'};  % 需要剪枝的层

% 创建剪枝后的网络
prunedNet = originalNet;

% 记录剪枝前后的参数信息
paramInfo = struct();

% 对每一层进行剪枝
for i = 1:length(layersToPrune)
    layerName = layersToPrune{i};
    layerIdx = find(strcmp({prunedNet.Layers.Name}, layerName));
    
    % 获取层权重
    weights = prunedNet.Layers(layerIdx).Weights;
    biases = prunedNet.Layers(layerIdx).Bias;
    
    % 记录原始参数数量
    [rows, cols, channels, filters] = size(weights);
    originalLayerParams = numel(weights);
    paramInfo(i).layer = layerName;
    paramInfo(i).original = originalLayerParams;
    
    % 将权重转换为向量以便排序
    weightsVec = weights(:);
    
    % 计算剪枝阈值(基于权重绝对值)
    absWeights = abs(weightsVec);
    sortedWeights = sort(absWeights);
    thresholdIdx = floor((1 - pruningRatio) * length(sortedWeights));
    threshold = sortedWeights(thresholdIdx);
    
    % 执行剪枝:将低于阈值的权重设为0
    mask = absWeights >= threshold;
    prunedWeightsVec = weightsVec .* mask;
    
    % 恢复为原始权重形状
    prunedWeights = reshape(prunedWeightsVec, size(weights));
    
    % 更新网络权重
    prunedNet.Layers(layerIdx).Weights = prunedWeights;
    
    % 记录剪枝后的参数数量
    prunedLayerParams = sum(mask);
    paramInfo(i).pruned = prunedLayerParams;
    paramInfo(i).ratio = 1 - prunedLayerParams / originalLayerParams;
    
    fprintf('剪枝层 %s: 原始参数=%d, 剪枝后参数=%d, 剪枝比例=%.2f%%\n', ...
        layerName, originalLayerParams, prunedLayerParams, paramInfo(i).ratio*100);
end

% 计算剪枝后网络的总参数数量
prunedParams = 0;
for i = 1:length(paramInfo)
    prunedParams = prunedParams + paramInfo(i).pruned;
end
% 加上偏置参数(这里不剪枝偏置)
prunedParams = prunedParams + sum(cellfun(@(x) numel(x.Bias), {prunedNet.Layers(strcmp({prunedNet.Layers.Name}, layersToPrune))}));

fprintf('剪枝后网络总参数数量: %d (减少了 %.2f%%)\n', ...
    prunedParams, (1 - prunedParams / originalParams) * 100);

% 评估剪枝后未微调的网络性能
[prunedPreds, ~] = classify(prunedNet, XTest);
prunedAccuracyBeforeFT = mean(prunedPreds == vec2ind(YTest'));
fprintf('剪枝后未微调的测试准确率: %.4f\n', prunedAccuracyBeforeFT);

%% 4. 剪枝后网络的微调
% 微调选项:使用较小的学习率
fineTuneOptions = trainingOptions('sgdm', ...
    'MaxEpochs', 5, ...
    'InitialLearnRate', 0.001,  % 学习率比原始训练小
    'MiniBatchSize', 256, ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 30, ...
    'Verbose', true, ...
    'Plots', 'training-progress');

% 执行微调
disp('剪枝后网络微调...');
fineTunedNet = trainNetwork(XTrain, YTrain, prunedNet, fineTuneOptions);

% 评估微调后的网络性能
[ftPreds, ~] = classify(fineTunedNet, XTest);
fineTunedAccuracy = mean(ftPreds == vec2ind(YTest'));
fprintf('剪枝后微调的测试准确率: %.4f\n', fineTunedAccuracy);

%% 5. 结果可视化
% 绘制不同阶段的准确率对比
figure;
bar([originalAccuracy, prunedAccuracyBeforeFT, fineTunedAccuracy]);
xticklabels({'原始网络', '剪枝后未微调', '剪枝后微调'});
ylabel('测试准确率');
title('不同阶段网络准确率对比');
ylim([0.9 1.0]);
grid on;

% 绘制参数数量对比
figure;
bar([originalParams, prunedParams]);
xticklabels({'原始网络', '剪枝后网络'});
ylabel('参数数量');
title('网络参数数量对比');
grid on;

% 可视化部分层的权重稀疏性
layerToVisualize = 'fc1';
layerIdx = find(strcmp({originalNet.Layers.Name}, layerToVisualize));

% 原始权重
originalWeights = originalNet.Layers(layerIdx).Weights;
% 剪枝后权重
prunedWeights = prunedNet.Layers(layerIdx).Weights;

figure;
subplot(1,2,1);
imagesc(originalWeights(:, 1:100));  % 显示部分权重
title(['原始网络 ', layerToVisualize, ' 层权重']);
colorbar;

subplot(1,2,2);
imagesc(prunedWeights(:, 1:100));  % 显示部分权重
title(['剪枝后 ', layerToVisualize, ' 层权重(黑色为0)']);
colorbar;

%% 辅助函数:加载MNIST数据集
function [XTrain, YTrain, XTest, YTest] = loadMNISTData()
    % 从MAT文件加载MNIST数据(假设数据已下载并保存为mnist.mat)
    % mnist.mat应包含XTrain, YTrain, XTest, YTest
    if exist('mnist.mat', 'file')
        load('mnist.mat');
    else
        % 如果没有本地文件,从网络下载(需要网络连接)
        disp('从网络下载MNIST数据集...');
        url = 'https://www.cs.toronto.edu/~kriz/mnist.mat';
        filename = 'mnist.mat';
        websave(filename, url);
        load(filename);
        
        % 转换数据格式
        XTrain = permute(train_images, [2, 1, 3]);
        XTrain = reshape(XTrain, 28, 28, 1, []);
        YTrain = train_labels + 1;  % 标签从1开始
        
        XTest = permute(test_images, [2, 1, 3]);
        XTest = reshape(XTest, 28, 28, 1, []);
        YTest = test_labels + 1;    % 标签从1开始
        
        % 保存数据供下次使用
        save('mnist.mat', 'XTrain', 'YTrain', 'XTest', 'YTest');
    end
end

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐