TRL - Transformer Reinforcement Learning
A full-stack library to post-train large language models.
What is it?
TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).
The library is built on top of 🤗 Transformers and is compatible with any model architecture available there.
Highlights
- Efficient and scalable:
  - 🤗 Accelerate is the backbone of TRL and allows model training to scale from a single GPU to a large multi-node cluster with methods such as DDP and DeepSpeed.
  - PEFT is fully integrated and allows training even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA (see the sketch after this list).
  - Unsloth is also integrated and significantly speeds up training with dedicated kernels.
- CLI: With the CLI you can fine-tune and chat with LLMs without writing any code, using a single command and a flexible config system.
- Trainers: The trainer classes are an abstraction to apply many fine-tuning methods with ease, such as the SFTTrainer, DPOTrainer, RewardTrainer, PPOTrainer, and ORPOTrainer.
- AutoModels: The AutoModelForCausalLMWithValueHead and AutoModelForSeq2SeqLMWithValueHead classes add an additional value head to the model, which allows training them with RL algorithms such as PPO.
- Examples: Fine-tune Llama for chat applications or apply full RLHF using adapters, etc., following the examples.
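As a sketch of the PEFT integration mentioned above: a LoRA configuration can be passed to the SFTTrainer through its peft_config argument so that only the adapter weights are trained. The hyperparameter values below are illustrative assumptions, not recommendations from this README:
# Minimal sketch: supervised fine-tuning with a LoRA adapter via the PEFT integration
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("trl-lib/Capybara", split="train")
# Illustrative LoRA hyperparameters (assumptions, not tuned recommendations)
peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT-LoRA")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,  # only the LoRA adapter weights are updated
)
trainer.train()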
Installation
Python package
Install the library with pip:
pip install trl
From source
If you want to use the latest features before an official release, you can install from source:
pip install git+https://github.com/huggingface/trl.git
Repository
If you want to use the examples, you can clone the repository with the following command:
git clone https://github.com/huggingface/trl.git
Command Line Interface (CLI)
You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:
SFT:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name trl-lib/Capybara --output_dir Qwen2.5-0.5B-SFT
DPO:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct --dataset_name argilla/Capybara-Preferences --output_dir Qwen2.5-0.5B-DPO
Chat:
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
Read more about the CLI in the relevant documentation section or use --help for more details.
How to use
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
SFTTrainer
Here is a basic example of how to use the SFTTrainer:
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
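After training, the checkpoint can be loaded like any 🤗 Transformers model. The following is a minimal sketch, not part of the original example, assuming the final model has been saved to the output directory (for example with trainer.save_model()):
# Minimal sketch (assumption): load the saved SFT checkpoint and generate text
from transformers import pipeline
generator = pipeline("text-generation", model="Qwen2.5-0.5B-SFT")
output = generator("What is reinforcement learning?", max_new_tokens=64)
print(output[0]["generated_text"])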
RewardTrainer
Here is a basic example of how to use the RewardTrainer:
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
args=training_args,
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)
trainer.train()
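Once trained, the reward model is a standard sequence-classification model with a single scalar output, so scoring text is an ordinary forward pass. The following is a minimal sketch that assumes the trained model was saved to the Qwen2.5-0.5B-Reward output directory:
# Minimal sketch (assumption): score a conversation with the trained reward model
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Reward")
reward_model = AutoModelForSequenceClassification.from_pretrained("Qwen2.5-0.5B-Reward", num_labels=1)
text = "User: What is the capital of France?\nAssistant: The capital of France is Paris."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    reward = reward_model(**inputs).logits[0, 0].item()  # higher means more preferred
print(reward)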
RLOOTrainer
RLOOTrainer implements a REINFORCE-style optimization for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the RLOOTrainer:
from trl import RLOOConfig, RLOOTrainer, apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback-prompt")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")
training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
trainer = RLOOTrainer(
config=training_args,
tokenizer=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
trainer.train()
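The "leave-one-out" in RLOO refers to its baseline: k completions are sampled per prompt, and the advantage of each completion is its reward minus the mean reward of the other k - 1 completions, so no separate value network is needed. The snippet below is a simplified illustration of that advantage computation, not TRL's internal implementation:
# Simplified illustration of RLOO (leave-one-out) advantages
import torch
def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    # rewards has shape (num_prompts, k): k sampled completions per prompt
    k = rewards.shape[1]
    # the baseline for completion i is the mean reward of the other k - 1 completions
    baseline = (rewards.sum(dim=1, keepdim=True) - rewards) / (k - 1)
    return rewards - baseline
rewards = torch.tensor([[1.0, 0.2, 0.7, 0.1]])
print(rloo_advantages(rewards))  # positive for above-average completions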
DPOTrainer
DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:
from trl import DPOConfig, DPOTrainer, maybe_extract_prompt, maybe_apply_chat_template
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
dataset = dataset.map(maybe_extract_prompt)
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
args=training_args,
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)
trainer.train()
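For reference, the DPO objective maximizes the log-sigmoid of the margin between the implicit rewards of the chosen and rejected responses, where each implicit reward is beta times the policy-versus-reference log-probability ratio of that response. The snippet below is a simplified illustration of the loss on precomputed sequence log-probabilities, not the DPOTrainer implementation:
# Simplified illustration of the DPO loss (not the DPOTrainer implementation)
import torch
import torch.nn.functional as F
def dpo_loss(policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # implicit rewards are beta-scaled log-ratios against the reference model
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # logistic loss on the chosen-vs-rejected reward margin
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
loss = dpo_loss(torch.tensor([-12.0]), torch.tensor([-15.0]), torch.tensor([-13.0]), torch.tensor([-14.5]))
print(loss)  # lower loss when the chosen response has the larger implicit reward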
Development
If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and make a dev install:
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
Citation
@misc{vonwerra2022trl,
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
title = {TRL: Transformer Reinforcement Learning},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/huggingface/trl}}
}