ChatGLM-6B的P-Tuning微调详细步骤及结果验证

05-14 阅读 0评论

文章目录

    • 1. ChatGLM-6B
      • 1.1 P-Tuning v2简介
      • 2. 运行环境
        • 2.1 项目准备
        • 3.数据准备
        • 4.使用P-Tuning v2对ChatGLM-6B微调
        • 5. 模型评估
        • 6. 利用微调后的模型进行验证
          • 6.1 微调后的模型
          • 6.2 原始ChatGLM-6B模型
          • 6.3 结果对比

            1. ChatGLM-6B

            ChatGLM-6B仓库地址:https://github.com/THUDM/ChatGLM-6B

            ChatGLM-6B/P-Tuning仓库地址:https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning

            1.1 P-Tuning v2简介

            P-Tuning是一种较新的模型微调方法,它采用了参数剪枝的技术,可以将微调的参数量减少到原来的0.1%。具体来说,P-Tuning v2是基于P-Tuning v1的升级版,主要的改进在于采用了更加高效的剪枝方法,可以进一步减少模型微调的参数量。

            P-Tuning v2的原理是通过对已训练好的大型语言模型进行参数剪枝,得到一个更加小巧、效率更高的轻量级模型。具体地,P-Tuning v2首先使用一种自适应的剪枝策略,对大型语言模型中的参数进行裁剪,去除其中不必要的冗余参数。然后,对于被剪枝的参数,P-Tuning v2使用了一种特殊的压缩方法,能够更加有效地压缩参数大小,并显著减少模型微调的总参数量。

            总的来说,P-Tuning v2的核心思想是让模型变得更加轻便、更加高效,同时尽可能地保持模型的性能不受影响。这不仅可以加快模型的训练和推理速度,还可以减少模型在使用过程中的内存和计算资源消耗,让模型更适用于各种实际应用场景中。

            2. 运行环境

            本项目租借autoDL GPU机器,具体配置如下:

            ChatGLM-6B的P-Tuning微调详细步骤及结果验证

            ChatGLM-6B的P-Tuning微调详细步骤及结果验证

            2.1 项目准备

            1.创建conda环境

            conda create -n tuning-chatglm python=3.8
            conda activate tuning-chatglm
            

            2.拉取ChatGLM-6B项目代码

            # 拉取代码
            git clone https://github.com/THUDM/ChatGLM-6B.git
            # 安装依赖库
            pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
            

            3.进入ptuning目录

            运行微调需要4.27.1版本的transformers。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖

            cd ptuning
            # 再次安装依赖,ptuning文档里有说明
            pip install rouge_chinese nltk jieba datasets  -i https://pypi.tuna.tsinghua.edu.cn/simple/
            

            4.补充

            对于需要pip安装失败的依赖,可以采用源码安装的方式,具体步骤如下

            git clone https://github.com/huggingface/peft.git
            cd peft
            pip install -e .
            

            3.数据准备

            官方微调样例是以 ADGEN (广告生成) 数据集为例来介绍微调的具体使用。

            ADGEN 数据集为根据输入(content)生成一段广告词(summary),具体格式如下所示:

            {
                "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
                "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
            }
            

            请从官网下载 ADGEN 数据集,放到ptuning目录下并将其解压到 AdvertiseGen 目录。

            下载地址:https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view

            tar -zxvf AdvertiseGen.tar.gz
            

            ChatGLM-6B的P-Tuning微调详细步骤及结果验证

            查看数据集大小:

            > wc -l AdvertiseGen/*
            > 1070 AdvertiseGen/dev.json
            > 114599 AdvertiseGen/train.json
            > 115669 total
            

            4.使用P-Tuning v2对ChatGLM-6B微调

            对于 ChatGLM-6B 模型基于 P-Tuning v2 进行微调。可将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

            进入到ptuning目录,首先,修改train.sh脚本,主要是修改其中的train_file、validation_file、model_name_or_path、output_dir参数:

            • train_file:训练数据文件位置
            • validation_file:验证数据文件位置
            • model_name_or_path:原始ChatGLM-6B模型文件路径
            • output_dir:输出模型文件路径
              PRE_SEQ_LEN=128
              LR=2e-2
              CUDA_VISIBLE_DEVICES=0 python3 main.py \
                  --do_train \
                  --train_file AdvertiseGen/train.json \
                  --validation_file AdvertiseGen/dev.json \
                  --prompt_column content \
                  --response_column summary \
                  --overwrite_cache \
                  --model_name_or_path model/chatglm-6b \
                  --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
                  --overwrite_output_dir \
                  --max_source_length 64 \
                  --max_target_length 64 \
                  --per_device_train_batch_size 1 \
                  --per_device_eval_batch_size 1 \
                  --gradient_accumulation_steps 16 \
                  --predict_with_generate \
                  --max_steps 3000 \
                  --logging_steps 10 \
                  --save_steps 1000 \
                  --learning_rate $LR \
                  --pre_seq_len $PRE_SEQ_LEN \
                  --quantization_bit 4
              

              执行bash train.sh脚本,运行过程如下:

                0%|                  | 0/3000 [00:00

免责声明
本网站所收集的部分公开资料来源于AI生成和互联网,转载的目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,也不构成任何其他建议。
文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。

发表评论

快捷回复: 表情:
评论列表 (暂无评论,人围观)

还没有评论,来说两句吧...

目录[+]