4.4 Model Pruning
Learning objectives
- Understand what model pruning is.
- Master the basic operations of model pruning.

1. Definition of Model Pruning
- Large pretrained models built on deep neural networks need enormous numbers of parameters to reach SOTA results. Biological neural networks, in contrast, carry out complex cognitive activity through a large number of sparse connections.
- Imitating this biological sparsity, i.e. turning the dense connections of a large network into sparse ones while still reaching SOTA-level accuracy, is the original motivation behind model pruning.

- PyTorch's pruning support lives in the torch.nn.utils.prune module and covers the following approaches (a minimal usage sketch follows this list):
- Pruning a specific module (local pruning).
- Pruning multiple parameters (pruning multiple parameters).
- Global pruning (global pruning).
- User-defined pruning (custom pruning).
- Note: make sure your PyTorch version is 1.4.0 or later; the pruning utilities were introduced in 1.4.0.
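Before the full BERT example, the sketch below shows the basic torch.nn.utils.prune workflow on a single nn.Linear module: prune, inspect the reparametrization, then make it permanent with prune.remove. The layer sizes and the 30% ratio are arbitrary choices for illustration, not part of the project code.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(8, 4)                       # toy module; sizes are arbitrary

# Local unstructured pruning: zero out the 30% of weights with the smallest |w|.
prune.l1_unstructured(layer, name="weight", amount=0.3)

# After pruning, the module holds weight_orig (original values) and weight_mask (0/1 mask);
# the attribute `weight` is recomputed as weight_orig * weight_mask before each forward pass.
print(list(dict(layer.named_parameters()).keys()))   # ['bias', 'weight_orig']
print(list(dict(layer.named_buffers()).keys()))       # ['weight_mask']

# Make the pruning permanent: fold the mask into `weight` and drop weight_orig / weight_mask.
prune.remove(layer, "weight")
sparsity = (layer.weight == 0).float().mean().item()
print(f"sparsity of layer.weight: {sparsity:.2f}")    # roughly 0.30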
2. Code Implementation
2.1 The Config class
import torch
import os
import datetime
from transformers import BertModel, BertTokenizer, BertConfig

current_date = datetime.datetime.now().date().strftime("%Y%m%d")


class Config(object):
    def __init__(self):
        """
        Configuration class holding all model and training parameters.
        """
        self.model_name = "bert"                           # model name
        self.data_path = "../../01-data"                   # root path of the dataset
        self.train_path = self.data_path + "/train.txt"    # training set
        self.dev_path = self.data_path + "/dev3.txt"       # small validation set for quick checks
        self.test_path = self.data_path + "/test.txt"      # test set
        self.class_path = self.data_path + "/class.txt"    # class-label file
        self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")]  # class names
        # save path of the fine-tuned BERT classifier
        self.model_save_path = "./models_save/bert20250521.pt"
        # save path of the pruned model
        self.prune_model_save_path = "./models_save/prune_bertclassifer_model.pt"
        # device used for training and prediction: cuda if a GPU is available, otherwise cpu
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_classes = len(self.class_list)            # number of classes
        self.num_epochs = 2                                # number of epochs
        self.batch_size = 256                              # mini-batch size
        self.pad_size = 32                                 # sequence length per sentence (pad short, truncate long)
        self.learning_rate = 5e-5                          # learning rate
        self.bert_path = "../../04-bert/bert-base-chinese" # path of the pretrained BERT model
        self.bert_model = BertModel.from_pretrained(self.bert_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)  # BERT tokenizer
        self.bert_config = BertConfig.from_pretrained(self.bert_path)   # BERT configuration
        self.hidden_size = 768                             # hidden size of the BERT model


if __name__ == '__main__':
    conf = Config()
    print(conf.bert_config)
    input_size = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中国", "人"])
    print(input_size)
    print(conf.class_list)
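The fields of Config are consumed by the data pipeline and training code used later (build_dataloader, model2dev). As a quick illustration of how pad_size and the tokenizer work together, here is a minimal sketch; the module name config, the example sentence, and the call pattern are assumptions for illustration, and the bert-base-chinese files must exist at conf.bert_path.

from config import Config   # module name assumed; Config is the class defined above

conf = Config()
encoded = conf.tokenizer(
    "今天天气真不错",               # arbitrary example sentence
    padding="max_length",          # pad short sentences ...
    truncation=True,               # ... and truncate long ones
    max_length=conf.pad_size,      # ... so every sample has exactly pad_size tokens
    return_tensors="pt",
)
print(encoded["input_ids"].shape)  # torch.Size([1, 32])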
2.2 Model Pruning
(1) Import dependencies
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import BertModel
from bert_classifer_model import BertClassifier
from utils import build_dataloader
from train import model2dev
from config import Config  # Config class from section 2.1 (module name assumed)
from tqdm import tqdm
from itertools import islice
(2) Define the sparsity computation function
def compute_sparsity(model):
    """Compute the sparsity of the query weights across all 12 encoder layers."""
    total_params = 0
    zero_params = 0
    for i in range(12):
        weight = model.bert.encoder.layer[i].attention.self.query.weight
        total_params += weight.numel()
        zero_params += (weight == 0).sum().item()
    return zero_params / total_params if total_params > 0 else 0
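compute_sparsity only inspects the query projections because those are the only tensors pruned below. If you later prune more modules, a more general helper that walks every nn.Linear weight can be convenient; the following is an optional sketch, not part of the original project code.

import torch.nn as nn

def compute_global_sparsity(model):
    """Fraction of exactly-zero entries over all nn.Linear weight matrices in a model."""
    total, zeros = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            total += module.weight.numel()
            zeros += (module.weight == 0).sum().item()
    return zeros / total if total > 0 else 0.0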
(3) Define the weight printing helper
def print_weights(weight, name, rows=5, cols=5):
    """Print the top-left rows x cols block of a weight matrix."""
    print(f"\n{name} (first {rows}x{cols}):")
    print(weight[:rows, :cols])
(4) Main function
Global unstructured pruning of BERT: prune 30% of the query weights across all encoder layers, ranked by L1 magnitude. Because the pruning is global, the smallest-magnitude entries are selected across all listed tensors together, so individual layers may end up slightly above or below 30% sparsity.
def main():
    train_dataloader, test_dataloader, dev_dataloader = build_dataloader()
    # Load the fine-tuned classifier
    model = BertClassifier().to(conf.device)
    model.load_state_dict(torch.load(conf.model_save_path), strict=False)
    # Before pruning
    print("Model before pruning:")
    print(model.bert.encoder.layer[0].attention.self)
    print_weights(model.bert.encoder.layer[0].attention.self.query.weight,
                  "layer[0].attention.self.query.weight before pruning")
    report, f1score, accuracy, precision = model2dev(model, dev_dataloader, conf.device)
    print(f"\nAccuracy before pruning: {accuracy:.4f}, F1: {f1score:.4f}")
    # Global unstructured pruning: 30% of the query weights across all 12 encoder layers
    parameters_to_prune = [(model.bert.encoder.layer[i].attention.self.query, 'weight') for i in range(12)]
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.3)
    # Make the pruning permanent: fold each weight_mask into weight and drop the reparametrization
    for module, param in parameters_to_prune:
        prune.remove(module, param)
    # After pruning
    print("\nModel after pruning:")
    print(model.bert.encoder.layer[0].attention.self)
    print_weights(model.bert.encoder.layer[0].attention.self.query.weight,
                  "layer[0].attention.self.query.weight after pruning")
    report, f1score, accuracy, precision = model2dev(model, dev_dataloader, conf.device)
    sparsity = compute_sparsity(model)
    print(f"\nAccuracy after pruning: {accuracy:.4f}, F1: {f1score:.4f}\nSparsity: {sparsity:.4f}")
    # Save the pruned model
    torch.save(model.state_dict(), conf.prune_model_save_path)


if __name__ == '__main__':
    # 1. Load the configuration
    conf = Config()
    # 2. Run the main function
    main()
Output log:
XXXXX/bin/python XXXXX/TMFCode/06-model-compression/bert_prune/prune_bert_attention.py
Loading data: 180000it [00:00, 711615.96it/s]
Loading data: 10000it [00:00, 409776.08it/s]
Loading data: 50it [00:00, 187245.71it/s]
Model before pruning:
BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
layer[0].attention.self.query.weight before pruning (first 5x5):
tensor([[ 0.1152, -0.0104,  0.0063,  0.0414, -0.0410],
        [ 0.0050, -0.0232, -0.0065,  0.0219,  0.0891],
        [ 0.0138,  0.0019,  0.0359, -0.0140, -0.0088],
        [ 0.0024, -0.0525, -0.0323,  0.0530, -0.0187],
        [-0.0470,  0.0525,  0.0182, -0.0156,  0.0729]],
       grad_fn=<SliceBackward0>)
Bert Classifer Evaluating ......: 100%|██████████| 1/1 [00:01<00:00,  1.46s/it]
Accuracy before pruning: 0.9600, F1: 0.9407
Model after pruning:
BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
layer[0].attention.self.query.weight after pruning (first 5x5):
tensor([[ 0.1152, -0.0000,  0.0000,  0.0414, -0.0410],
        [ 0.0000, -0.0232, -0.0000,  0.0219,  0.0891],
        [ 0.0000,  0.0000,  0.0359, -0.0000, -0.0000],
        [ 0.0000, -0.0525, -0.0323,  0.0530, -0.0187],
        [-0.0470,  0.0525,  0.0182, -0.0000,  0.0729]],
       grad_fn=<SliceBackward0>)
Bert Classifer Evaluating ......: 100%|██████████| 1/1 [00:01<00:00,  1.30s/it]
Accuracy after pruning: 0.9400, F1: 0.9187
Sparsity: 0.3000
Process finished with exit code 0
Conclusion: after the global pruning operation, the model's F1 is 91.87%, roughly 2 percentage points below the unpruned score (94.07%).
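Because prune.remove() was called before saving, the checkpoint contains ordinary weight tensors (with zeroed entries), so it can be reloaded like any other state_dict. A minimal sketch, reusing the BertClassifier and Config from above (the config module name is assumed):

import torch
from bert_classifer_model import BertClassifier
from config import Config   # module name assumed

conf = Config()
pruned_model = BertClassifier().to(conf.device)
# The saved state_dict holds plain weights, so no pruning-specific code is needed to load it.
pruned_model.load_state_dict(torch.load(conf.prune_model_save_path, map_location=conf.device))
pruned_model.eval()

Note that unstructured pruning only sets individual weights to zero; the tensors keep their original shape, so memory use and latency do not shrink unless sparse storage formats or sparse kernels are applied afterwards.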
3. Summary
- This section implemented global unstructured pruning: 30% of the attention query weights across all encoder layers were zeroed, selected by L1 magnitude.