输入“/”快速插入内容

P-Tuning v2实现过程

2023年12月7日修改
常见参数高效微调方法(Parameter-Efficient Fine-Tuning,PEFT)有哪些呢?主要是Prompt系列和LoRA系列。本文主要介绍P-Tuning v2微调方法。如下所示:
Prompt系列比如,Prefix Tuning(2021.01-Stanford)、Prompt Tuning(2021.09-Google)、P-Tuning(2021.03-Tsinghua)、P-Tuning v2(2022.03-Tsinghua);
LoRA系列比如,LoRA(2021.11-Microsoft)、AdaLoRA(2023.03-Microsoft)、QLoRA(2023.05-Washington)。
还有不知道如何分类的比如,BitFit、Adapter Tuning及其变体、MAM Adapter、UniPELT等。
一.P-Tuning v2工作原理
1.Hard/Soft Prompt-Tuning如何设计
  提示工程发展经过了从人工或半自动离散空间的hard prompt设计,到采用连续可微空间soft prompt设计的过程,这样的好处是可通过端到端优化学习不同任务对应的prompt参数。
2.P-Tuning工作原理和不足
  主要是将continuous prompt应用于预训练模型输入层,预训练模型后面的每一层都没有合并continuous prompt。
3.P-Tuning v2如何解决P-Tuning不足
  P-Tuning v2把continuous prompt应用于预训练模型的每一层,而不仅仅是输入层
二.P-Tuning v2实现过程
1.整体项目结构
  源码参考文献[4],源码结构如下所示:
参数解释如下所示:
(1)--model_name_or_path L:/20230713_HuggingFaceModel/20231004_BERT/bert-base-chinese:BERT模型路径
(2)--task_name qa:任务名字
(3)--dataset_name squad:数据集名字
(4)--do_train:训练过程
(5)--do_eval:验证过程
(6)--max_seq_length 128:最大序列长度
(7)--per_device_train_batch_size 2:每个设备训练批次大小
(8)--learning_rate 5e-3:学习率
(9)--num_train_epochs 10:训练epoch数量
(10)--pre_seq_len 128:前缀序列长度
(11)--output_dir checkpoints/SQuAD-bert:检查点输出目录
(12)--overwrite_output_dir:覆盖输出目录
(13)--hidden_dropout_prob 0.1:隐藏dropout概率
(14)--seed 11:种子
(15)--save_strategy no:保存策略
(16)--evaluation_strategy epoch:评估策略
(17)--prefix:P-Tuning v2方法
执行代码如下所示:
代码块
python3 run.py --model_name_or_path L:/20230713_HuggingFaceModel/20231004_BERT/bert-base-chinese --task_name qa --dataset_name squad --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 2 --learning_rate 5e-3 --num_train_epochs 10 --pre_seq_len 128 --output_dir checkpoints/SQuAD-bert --overwrite_output_dir --hidden_dropout_prob 0.1 --seed 11 --save_strategy no --evaluation_strategy epoch --prefix
2.代码执行流程
(1)P-tuning-v2/run.py
根据task_name=="qa"选择tasks.qa.get_trainer
根据get_trainer得到trainer,然后训练、评估和预测
(2)P-tuning-v2/tasks/qa/get_trainer.py
得到config、tokenizer、model、squad数据集、QuestionAnsweringTrainer对象trainer
重点关注model是如何得到的
代码块
# fix_bert表示不更新bert参数,model数据类型为BertPrefixForQuestionAnswering
model = get_model(model_args, TaskType.QUESTION_ANSWERING, config, fix_bert=True)
重点关注QuestionAnsweringTrainer具体实现
代码块
trainer = QuestionAnsweringTrainer( # 读取trainer
model=model, # 模型
args=training_args, # 训练参数
train_dataset=dataset.train_dataset if training_args.do_train else None, # 训练集
eval_dataset=dataset.eval_dataset if training_args.do_eval else None, # 验证集
eval_examples=dataset.eval_examples if training_args.do_eval else None, # 验证集
tokenizer=tokenizer, # tokenizer
data_collator=dataset.data_collator, # 用于将数据转换为batch
post_process_function=dataset.post_processing_function, # 用于将预测结果转换为最终结果
compute_metrics=dataset.compute_metrics, # 用于计算评价指标
)