
4.4 Model Pruning

Learning Objectives

  • Understand what model pruning is.
  • Master the basic operations of model pruning.



1. What Is Model Pruning

  • Large pre-trained models built on deep neural networks rely on enormous numbers of parameters to reach SOTA performance. Biological neural networks, by contrast, carry out complex cognitive activity through a great many sparse connections.

  • The motivation for model pruning is to imitate these sparse biological networks: turn the dense connections of a large network into sparse ones while still reaching SOTA performance.


  • PyTorch's support for model pruning lives in the torch.nn.utils.prune module and covers the following approaches (see the sketch after this list):
  • Pruning a specific network module (local pruning).
  • Pruning multiple parameters at once.
  • Global pruning.
  • User-defined (custom) pruning.

  • Note: make sure your PyTorch version is 1.4.0 or later; the pruning API is not available in earlier versions.
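To make these modes concrete, here is a minimal, self-contained sketch (not part of the course code) that exercises the first three modes on a toy two-layer network; custom pruning, which subclasses prune.BasePruningMethod, is left to the PyTorch documentation.

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

toy = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))  # toy network, just to exercise the API

# (1) Local pruning of a single module: zero out 50% of the first layer's
#     weights with the smallest L1 magnitude
prune.l1_unstructured(toy[0], name="weight", amount=0.5)
print([name for name, _ in toy[0].named_buffers()])     # contains 'weight_mask'
print([name for name, _ in toy[0].named_parameters()])  # contains 'weight_orig' (and 'bias')
prune.remove(toy[0], "weight")  # make it permanent: fold the mask into 'weight'

# (2) Pruning multiple parameters: loop the same local method over several modules
for layer in toy:
    prune.random_unstructured(layer, name="weight", amount=0.2)
    prune.remove(layer, "weight")

# (3) Global pruning: rank the weights of both layers together
#     and zero out the smallest 30% across all of them
parameters_to_prune = [(toy[0], "weight"), (toy[1], "weight")]
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.3)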

2. Code Implementation

2.1 Configuration File: Config

import torch
import os
import datetime
from transformers import BertModel, BertTokenizer, BertConfig
current_date = datetime.datetime.now().date().strftime("%Y%m%d")

class Config(object):
    def __init__(self):
        """
        配置类,包含模型和训练所需的各种参数。
        """
        self.model_name = "bert" # 模型名称
        self.data_path = "../../01-data"  #数据集的根路径
        self.train_path = self.data_path + "\\train.txt"  # 训练集
        self.dev_path = self.data_path + "\\dev3.txt"  # 少量验证集,快速验证
        self.test_path = self.data_path + "\\test.txt"  # 测试集

        self.class_path=self.data_path + "\\class.txt" #类别文件

        self.class_list = [line.strip() for line in open(self.class_path, encoding="utf-8")]  # 类别名单

        # BERT模型训练结果保存路径
        self.model_save_path = "./models_save/bert20250521.pt"
        # 剪枝模型训练结果保存路径
        self.prune_model_save_path = "./models_save/prune_bertclassifer_model.pt"

        # 模型训练+预测的时候
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 训练设备,如果GPU可用,则为cuda,否则为cpu

        self.num_classes = len(self.class_list)  # 类别数
        self.num_epochs = 2  # epoch数
        self.batch_size = 256  # mini-batch大小
        self.pad_size = 32  # 每句话处理成的长度(短填长切)
        self.learning_rate = 5e-5  # 学习率
        self.bert_path = "../../04-bert/bert-base-chinese"  # 预训练BERT模型的路径
        self.bert_model=BertModel.from_pretrained(self.bert_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) # BERT模型的分词器
        self.bert_config = BertConfig.from_pretrained(self.bert_path) # BERT模型的配置
        self.hidden_size = 768 # BERT模型的隐藏层大小

if __name__ == '__main__':
    conf = Config()
    print(conf.bert_config)
    input_size = conf.tokenizer.convert_tokens_to_ids(["你", "好", "中国", "人"])
    print(input_size)
    print(conf.class_list)

2.2 Model Pruning

(1) Import dependencies

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import BertModel
from bert_classifer_model import BertClassifier
from utils import build_dataloader
from train import model2dev
from tqdm import tqdm
from itertools import islice
from config import Config  # NOTE: assumes the Config class from section 2.1 is saved as config.py

(2) Define a sparsity-computation function

def compute_sparsity(model):
    """计算所有 encoder 层 query 权重的稀疏度"""
    total_params = 0
    zero_params = 0
    for i in range(12):
        weight = model.bert.encoder.layer[i].attention.self.query.weight
        total_params += weight.numel()
        zero_params += (weight == 0).sum().item()
    return zero_params / total_params if total_params > 0 else 0

(3) Define a weight-printing function

def print_weights(weight, name, rows=5, cols=5):
    """Print the top-left rows x cols block of a weight matrix."""
    print(f"\n{name} (top-left {rows}x{cols}):")
    print(weight[:rows, :cols])

(4) Main function

BERT global unstructured pruning: prune 30% of the attention query weights across all encoder layers, ranked by L1 norm.

def main():
    train_dataloader, test_dataloader, dev_dataloader = build_dataloader()

    # Load the fine-tuned model
    model = BertClassifier().to(conf.device)
    model.load_state_dict(torch.load(conf.model_save_path), strict=False)

    # Before pruning
    print("Model before pruning:")
    print(model.bert.encoder.layer[0].attention.self)
    print_weights(model.bert.encoder.layer[0].attention.self.query.weight,
                  "layer[0].attention.self.query.weight before pruning")
    report, f1score, accuracy, precision = model2dev(model, dev_dataloader, conf.device)
    print(f"\nAccuracy before pruning: {accuracy:.4f}, F1: {f1score:.4f}")

    # Global unstructured pruning: 30% of the query weights across all encoder layers
    parameters_to_prune = [(model.bert.encoder.layer[i].attention.self.query, 'weight') for i in range(12)]
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.3)

    # Make the pruning permanent: fold each mask into its weight and drop weight_orig/weight_mask
    for module, param in parameters_to_prune:
        prune.remove(module, param)

    # After pruning
    print("\nModel after pruning:")
    print(model.bert.encoder.layer[0].attention.self)
    print_weights(model.bert.encoder.layer[0].attention.self.query.weight,
                  "layer[0].attention.self.query.weight after pruning")
    report, f1score, accuracy, precision = model2dev(model, dev_dataloader, conf.device)
    sparsity = compute_sparsity(model)
    print(f"\nAccuracy after pruning: {accuracy:.4f}, F1: {f1score:.4f}\nSparsity: {sparsity:.4f}")

    # Save the pruned model
    torch.save(model.state_dict(), conf.prune_model_save_path)


if __name__ == '__main__':
    # 1. Load the configuration
    conf = Config()
    # 2. Run the main function
    main()

Output log:

XXXXX/bin/python XXXXX/TMFCode/06-model-compression/bert_prune/prune_bert_attention.py 
Loading data: 180000it [00:00, 711615.96it/s]
Loading data: 10000it [00:00, 409776.08it/s]
Loading data: 50it [00:00, 187245.71it/s]

Model before pruning:
BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

layer[0].attention.self.query.weight before pruning (top-left 5x5):
tensor([[ 0.1152, -0.0104,  0.0063,  0.0414, -0.0410],
        [ 0.0050, -0.0232, -0.0065,  0.0219,  0.0891],
        [ 0.0138,  0.0019,  0.0359, -0.0140, -0.0088],
        [ 0.0024, -0.0525, -0.0323,  0.0530, -0.0187],
        [-0.0470,  0.0525,  0.0182, -0.0156,  0.0729]],
       grad_fn=<SliceBackward0>)
Bert Classifer Evaluating ......: 100%|██████████| 1/1 [00:01<00:00,  1.46s/it]

Accuracy before pruning: 0.9600, F1: 0.9407

Model after pruning:
BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

layer[0].attention.self.query.weight after pruning (top-left 5x5):
tensor([[ 0.1152, -0.0000,  0.0000,  0.0414, -0.0410],
        [ 0.0000, -0.0232, -0.0000,  0.0219,  0.0891],
        [ 0.0000,  0.0000,  0.0359, -0.0000, -0.0000],
        [ 0.0000, -0.0525, -0.0323,  0.0530, -0.0187],
        [-0.0470,  0.0525,  0.0182, -0.0000,  0.0729]],
       grad_fn=<SliceBackward0>)
Bert Classifer Evaluating ......: 100%|██████████| 1/1 [00:01<00:00,  1.30s/it]

Accuracy after pruning: 0.9400, F1: 0.9187
Sparsity: 0.3000

Process finished with exit code 0

Conclusion: after global pruning, the model's F1 is 91.87%, roughly 2 percentage points below the best (unpruned) score.
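Because prune.remove() has already folded the masks into the weights, the file saved above is an ordinary dense state_dict in which the pruned entries are stored as explicit zeros (so the file itself is not smaller on disk). A minimal sketch, assuming it runs in the same script where conf, BertClassifier, and compute_sparsity are already defined, to confirm that the sparsity survives reloading:

# Reload the saved pruned model and re-check the query-weight sparsity
reloaded = BertClassifier().to(conf.device)
reloaded.load_state_dict(torch.load(conf.prune_model_save_path))
print(f"Sparsity after reloading: {compute_sparsity(reloaded):.4f}")  # expected ~0.3000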

3. Section Summary

  • This section carried out global unstructured pruning: 30% of the attention query weights across all encoder layers were pruned by L1 norm.