大模型之BloomLLAMA----SFT(模型微调)

news/2024/8/24 8:20:33 标签: llama

0. 简介

随着chatgpt的爆火,最近也有很多大模型在不断地出现,比如说Bloom系列以及以LLAMA为基础的ziya和baichuan。这些模型相较于chatglm来说,更加具有发展前景,因为其是完全可商用,并可以不断迭代更新的。最近作者在跟着hiyouga大佬的LLaMA-Efficient-Tuning进行学习,相较于其他的项目来说,该项目是非常适合跟着学习并入门的。

1. 什么是SFT

SFT(Scalable Fine-Tuning)是一种用于自然语言处理的技术,它通过对预训练的语言模型进行微调,使其适应特定任务。在大模型SFT中,使用的是大型的预训练语言模型,例如LLAMA、GPT等,这些模型具有数十亿甚至数百亿个参数,可以处理大量的文本数据。

SFT的主要思想是在一个大型的预训练模型的基础上,针对特定的任务对模型进行微调。在微调过程中,模型会根据任务的特点调整模型的参数和结构,以提高模型在该任务上的表现。在微调过程中,可以使用不同的技术,例如数据增强、正则化、优化算法等。

SFT的优点是可以快速地针对不同的任务进行微调,而无需重新训练整个模型。此外,由于使用的是大型的预训练模型,可以利用海量的文本数据进行训练,从而获得更好的性能。不过,SFT也有一些缺点,例如需要大量的计算资源和时间进行微调,以及可能会出现过拟合等问题。

目前常用的SFT方法有P-Tuning v2、LORA、QLoRA、冻结(Freeze)、全参数(full-parameter)等方法。我们先来看一看在LLaMA-Efficient-Tuning中是如何写SFT的


2. 代码阅读–train_sft.py

下面是sft对应大模型的脚本,主要包括模型和数据的准备,数据集的划分,训练和评估等步骤。

首先,代码导入了一些必要的模块和函数。这包括一些用于数据处理、训练、加载预训练模型和绘制损失图的工具函数。(这部分和pt中一样)

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")# 用于准备各种参数,包括模型参数、数据参数、训练参数和微调参数。
    dataset = prepare_data(model_args, data_args)# 用于准备数据集
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")# 用于加载sft微调的模型和分词器。
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")# 用于预处理数据,例如将文本转换为模型可以理解的格式。
    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)# 动态地对数据进行填充,使得每个batch中的数据长度一致。

下面的代码是用于Seq2SeqTrainer的解码参数进行覆盖

   # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
                training_args.generation_max_length is not None else data_args.max_target_length# 设置训练参数(training_args)中的生成最大长度
    training_args.generation_num_beams = data_args.eval_num_beams if \
                data_args.eval_num_beams is not None else training_args.generation_num_beams # 设置训练参数中的生成束搜索数(generation_num_beams)

然后,根据是否进行训练,对数据集进行划分。如果进行训练,且开发集的比例大于0,那么数据集会被划分为训练集和开发集;否则,全部数据用于训练。如果不进行训练,那么全部数据用于评估或预测。

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

接着,初始化Seq2SeqPeftTrainer对象,传入微调参数、模型、训练参数、分词器、数据处理器、回调函数和计算度量等参数(都是继承自Seq2SeqTrainer),以及前面划分的数据集。这个我们下一节将会仔细阅读里面的操作

…详情请参照古月居


http://www.niftyadmin.cn/n/4950735.html

相关文章

算法|Day46 动态规划14

LeetCode 1143- 最长公共子序列 题目链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 题目描述:给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 &#xff…

HCIP---路由策略

文章目录 前言一、pandas是什么?二、使用步骤 1.引入库2.读入数据总结 前言 前文我们初步了解了重发布技术的技术的工作流程及配置方法,在解决路由回馈问题的同时,路由回馈,选路不佳问题仍然没有得到有效解决,接下来通…

第二章:25+ Python 数据操作教程(第十三节NUMPY 教程与练习)

NumPy(“Numerical Python”或“Numeric Python”的缩写)是 Python 中对数组和矩阵进行快速数学计算的最基本的软件包之一。在处理多维数据时它也非常有用。集成C、C++和FORTRAN工具是一件幸事。它还提供了许多傅里叶变换 (FT) 和线性代数函数。 为什么使用 NumPy 而不是列…

【stable-diffusion使用扩展+插件和模型资源(上】

个人网站: 文章目录 前言一、插件推荐1.qrcode-monster2.sd-webui-openpose-editor3.sd-webui-depth-lib4.roop(换脸插件)5.sd-webui-qrcode-toolkit(艺术二维码)5.光源控制6.二次元转真人7.动态视频转场(l…

前端-Sass和Less区别

Less和Sass都是CSS预处理器,它们提供了更强大、更灵活的方式来编写CSS样式。以下是Less和Sass之间的一些区别: 语法:Less使用类似于CSS的语法,而Sass使用类似于Ruby的语法。Less使用大括号 {} 和分号 ; 来表示代码块和语句&#x…

星际争霸之小霸王之小蜜蜂(三)--重构模块

目录 前言 一、为什么要重构模块 二、创建game_functions 三、创建update_screen() 四、修改alien_invasion模块 五、课后思考 总结 前言 前两天我们已经成功创建了窗口,并将小蜜蜂放在窗口的最下方中间位置,本来以为今天将学习控制小蜜蜂,结…

GM65二维码识别模块+命令控制

简介 MG65 条码识读模块,一款性能优良的扫描引擎,不仅能够轻松读取各类一维条码,而且可以高速读取二维条码,对线性条形码具有非常高的扫描速率,针对纸质条码及显示屏上的条码,也都能轻松扫描。 一、模块参…

通讯协议044——全网独有的OPC HDA知识一之聚合(十二)持续坏值时间

本文简单介绍OPC HDA规范的基本概念,更多通信资源请登录网信智汇(wangxinzhihui.com)。 本节旨在详细说明HDA聚合的要求和性能。其目的是使HDA聚合标准化,以便HDA客户端能够可靠地预测聚合计算的结果并理解其含义。如果用户需要聚合中的自定义功能&…