LLM - LoRA 模型合并与保存_merge_and_unload-程序员宅基地

技术标签: LoRA  merge  LLM  

目录

一.引言

二.LoRA 

1.LoRA 简介

2.LoRA 参数

3.LoRA 合并

4.LoRA 保存

三.总结


一.引言

LLM 使用过程中最常用方法之一就是通过 LoRA 基于自己的数据对大模型进行微调,本文简单介绍 LoRA 原理以及如何合并多个 LoRA 模型并保存。

peft==0.4.0
transformers==4.29.1

二.LoRA 

1.LoRA 简介

LoRA 是一种消耗较少内存同时加速大型模型微调的技术。

为了使得微调更有效,LoRA 的方法是通过低秩分解两个较小的矩阵 [称为更新矩阵] 用于参数更新。在维持较少更新参数的基础上,训练并适应新的微调数据。原始参数 W 保持冻结,不会进行更新训练。该方法有如下优点:

    ◆  LoRA 通过大幅减少可训练参数数量,使微调更加效率。

    ◆  基于 LoRA 的轻便型,可以构建多个轻量级 LoRA 模型用于不同下游任务 

    ◆  LoRA 与许多其他参数有效方法正交,并且可以与其结合。

      使用 LoRA 进行微调的模型的性能与完全微调模型的性能相当。

    ◆  LoRA 几乎不添加任何推理延迟,因为适配器权重可以与基本模型合并。

2.LoRA 参数

原则上,LoRA 可以应用于神经网络中权重矩阵的任何子集,以减少可训练参数的数量。然而,为了简化和进一步提高参数效率,在 Transformer 模型中,LoRA 通常仅应用于注意力块。LoRA 模型中可训练参数的结果数量取决于低秩更新矩阵的大小,其主要由秩 r 和原始权重矩阵的形状确定。实际使用过程中,通过选择不同的 lora_target 决定训练的参数量,以 LLama 为例:

--lora_target q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj \

可以通过 numel 方法获取张量中参数的数量和 requires_grad 方法获取模型张量是否进行梯度计算来计算可训练参数的比例:

# 打印可训练参数
def print_trainable_params(model: torch.nn.Module) -> None:
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
        trainable_params, all_param, 100 * trainable_params / all_param))

3.LoRA 合并

在训练过程中,较小的权重矩阵 A、B 是分开的。但一旦训练完成,权重实际上可以合并到一个相同的新权重矩阵中。虽然 LoRA 明显更小,训练速度更快,但由于分别加载基本模型和 LoRA 模型,可能会在推理过程中遇到延迟问题。为了消除延迟,可以使用 merge_and_unload 函数将适配器权重与基本模型合并,这样可以有效地将新合并的模型用作独立模型。同时也可以合并多个 LoRA 模型,使得 Base Model 同时具备多个任务处理能力。

训练时原始参数 W 保持不动,更新矩阵 A/B。合并时 W-merged = W + BA 代替原有的 W。 

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, GenerationConfig
from peft import PeftModel

    # 载入预训练模型
    tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True, padding_side="left", **config_kwargs)
    print("Tokenizer Load Success!")
    config = AutoConfig.from_pretrained(base_model, **config_kwargs)
    # Load and prepare pretrained models (without valuehead).
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        config=config,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        revision='main'
    )
    print('origin config =', model.config)
    # 模型合并
    ckpt_list = ["checkpoint-1000", "checkpoint-2000", "checkpoint-3000"]
    for checkpoint in ckpt_list:
        print('Merge checkpoint: {}'.format(checkpoint))
        model = PeftModel.from_pretrained(model, os.path.join(lora_model, checkpoint))
        model = model.merge_and_unload()
    print('merge config =', model.config)

通过 merge_and_unlaod 方法可以合并多个 Lora 模型,这里博主尝试将同一个模型的 3 个  CKPT 合并至原始模型中。通过合并不同类型任务的 CKPT,原始模型可以同时具备多种下游任务的能力且推理效率不会受影响。

4.LoRA 保存

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer


def apply_lora(model_name_or_path, output_path, lora_path):
    print(f"Loading the base model from {model_name_or_path}")
    base = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    base_tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)

    print(f"Loading the LoRA adapter from {lora_path}")

    lora_model = PeftModel.from_pretrained(
        base,
        lora_path,
        torch_dtype=torch.float16,
    )

    print("Applying the LoRA")
    model = lora_model.merge_and_unload()

    print(f"Saving the target model to {output_path}")
    model.save_pretrained(output_path)
    base_tokenizer.save_pretrained(output_path)

使用 merge_and_unload 方法进行参数合并,再调用 save_pretrained 方法进行模型保存。可以看到保存前后模型大小几乎没有变化。 

 

三.总结

合并多个 LoRA 模型也可能对模型的原始能力造成影响,大家可以根据需求测试与尝试。

参考:

- LoRA https://huggingface.co/docs/peft/conceptual_guides/lora

- PEFT Models 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/BIT_666/article/details/132065177

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签