This repository is the official implementation of Importance Weighting for Aligning Language Models under Deployment Distribution Shift.
Note: The code was tested on a computer with eight NVIDIA A100-SXM4-40GB GPUs.
See install.sh for installation instructions.
We recommend preparing the dataset as a JSON file, where each example consists of the following data fields:
{
  "prompt": [
    {
      "role": "user",
      "content": "..."
    }
  ],
  "output_A": [
    {
      "role": "assistant",
      "content": "..."
    }
  ],
  "output_B": [
    {
      "role": "assistant",
      "content": "..."
    }
  ],
  "label": 1,
  "reward_A": 1,
  "reward_B": 0,
  "reward_difference": 1,
  "type": "pairwise_feedback",
  "split": "train",
  "origin": "test"
}
Set "split" to "train" for training examples, "validation" for validation examples, and "test" for test examples. Validation examples must be labeled "train" if you intend to include the validation data during SFT/DPO training.
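As an illustration of this schema, here is a minimal, hypothetical sketch that writes such a JSON file with Python's standard library. The example contents, the file name, and the choice of a single JSON list (rather than, say, JSON Lines) are placeholders and assumptions, not part of the repository:

```python
import json

# Hypothetical examples following the schema above (contents are placeholders).
examples = [
    {
        "prompt": [{"role": "user", "content": "How do I stay safe online?"}],
        "output_A": [{"role": "assistant", "content": "Use strong, unique passwords ..."}],
        "output_B": [{"role": "assistant", "content": "Just click any link you receive ..."}],
        "label": 1,                      # preference label, mirroring the example above
        "reward_A": 1,
        "reward_B": 0,
        "reward_difference": 1,
        "type": "pairwise_feedback",
        "split": "train",                # "train", "validation", or "test"
        "origin": "test",
    },
]

# Validation examples must use "split": "train" if they should be included
# during SFT/DPO training (see the note above).
with open("examples/my_dataset_feedback.json", "w") as f:
    json.dump(examples, f, indent=2)
```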
As an example, we provide a script to generate training, validation, and test datasets from SafeRLHF, located at process_datasets/saferlhf.py. You can run the script using the following command:
python -m process_datasets.saferlhf
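To sanity-check a generated dataset, you can count how many examples fall into each split. The sketch below assumes the script wrote one of the JSON files referenced in the commands further down and that the file contains a JSON list of example objects in the schema above; adjust the path if your setup differs:

```python
import json
from collections import Counter

# Path taken from the SFT command below; change it if your script writes elsewhere.
with open("examples/saferlhf_combined-train-val_feedback.json") as f:
    examples = json.load(f)

print(len(examples), "examples")
print(Counter(ex["split"] for ex in examples))
```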
Before proceeding, please
- log in to Hugging Face (e.g., via huggingface-cli login) or provide your Hugging Face token,
- make sure you have access to the Hugging Face models you want to use, such as Llama-3.1-8B-Instruct, since some models are gated and require requesting access,
- edit accelerate_config/fsdp.yaml according to your needs and hardware configuration,
- set up your Weights & Biases (wandb) account and log in (e.g., via wandb login).
See config/model for the available model configurations.
accelerate launch --config_file accelerate_config/fsdp.yaml --main_process_port 29500 launch.py \
n_epochs=1 loss=sft model=pythia datasets=[examples/saferlhf_combined-train-val_feedback.json] exp_name=safe_lm_SFT seed=1 ++cache_dir=.cache/data/models ++model.name_or_path=EleutherAI/pythia-2.8b ++lr=5e-6 ++loss.beta=0.1 model.batch_size=32 model.eval_batch_size=32 model.max_length=512 model.max_prompt_length=256
Note that SFT is optional. The IW-DPO command below loads the resulting SFT checkpoint via ++model.load_from=.cache/data/models/safe_lm_SFT/FINAL; if you skip SFT, adjust or remove that override accordingly.
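Before launching IW-DPO from an SFT checkpoint, you can quickly confirm that the checkpoint directory referenced by ++model.load_from exists. This is only a convenience sketch; the path follows from the exp_name and ++cache_dir values used above:

```python
from pathlib import Path

# Path implied by exp_name=safe_lm_SFT and ++cache_dir=.cache/data/models above.
ckpt = Path(".cache/data/models/safe_lm_SFT/FINAL")
print(ckpt, "exists" if ckpt.exists() else "is missing -- run SFT first or adjust ++model.load_from")
```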
In the command below, the first continuation line carries the general configuration and the final line carries the IW-DPO-specific options.
accelerate launch --config_file accelerate_config/fsdp.yaml --main_process_port 29500 launch.py \
n_epochs=1 loss=dpo model=pythia exp_name=safe_lm_IW-DPO seed=1 ++cache_dir=.cache/data/models ++model.name_or_path=EleutherAI/pythia-2.8b ++lr=5e-6 ++loss.beta=0.1 model.batch_size=32 model.eval_batch_size=32 model.max_length=512 model.max_prompt_length=256 ++model.load_from=.cache/data/models/safe_lm_SFT/FINAL \
datasets=[examples/saferlhf_separated-train-val_feedback.json] model.val_batch_size=32 iw.enabled=true iw.t=reward iw.warmup_examples=1024 iw.kernel_width=null iw.lambda_reg=0.1 iw.normalize_w=true iw.method=kmm
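For intuition about the iw.* options (iw.kernel_width, iw.lambda_reg, iw.normalize_w, iw.method=kmm), below is a minimal, unconstrained kernel mean matching sketch with a Gaussian kernel and ridge regularization. It is illustrative only and is not the estimator implemented in this repository: the function names are our own, and the feature extraction, constraints, normalization, and defaults used by the actual code may differ.

```python
import numpy as np

def gaussian_kernel(X, Y, width):
    # k(x, y) = exp(-||x - y||^2 / (2 * width^2))
    d2 = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2.0 * width ** 2))

def kmm_weights(X_train, X_target, width=1.0, lambda_reg=0.1, normalize=True):
    """Unconstrained, ridge-regularized kernel mean matching.

    Solves (K + lambda * I) w = kappa, where K compares training points with each
    other and kappa matches training points to the target (deployment) distribution.
    The full KMM formulation additionally imposes non-negativity and box constraints.
    """
    n_tr, n_te = len(X_train), len(X_target)
    K = gaussian_kernel(X_train, X_train, width)
    kappa = (n_tr / n_te) * gaussian_kernel(X_train, X_target, width).sum(axis=1)
    w = np.linalg.solve(K + lambda_reg * np.eye(n_tr), kappa)
    w = np.clip(w, 0.0, None)            # keep weights non-negative
    if normalize:
        w = w * n_tr / w.sum()           # rescale to mean 1 (the repo's normalize_w may differ)
    return w

# Toy usage: 1-D features, with the target distribution shifted relative to training.
rng = np.random.default_rng(0)
X_train = rng.normal(0.0, 1.0, size=(256, 1))
X_target = rng.normal(1.0, 1.0, size=(256, 1))
w = kmm_weights(X_train, X_target, width=1.0, lambda_reg=0.1)
print(w.mean(), w[X_train[:, 0] > 1].mean())  # weights grow where the target is denser
```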
python -m train.sample .cache/data/models/safe_lm_IW-DPO/FINAL --gpu_count 8 --output_file outputs/generations/safe_lm_IW-DPO.json --datasets examples/saferlhf_only-test_feedback.json --mode safe_lm --max_tokens 512 --max_prompt_length 256
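The schema of the generations file is determined by train.sample and is not documented here; the sketch below simply loads the JSON and reports its size and top-level structure without assuming a particular format:

```python
import json

# Hypothetical quick check of the sampling output produced by the command above.
with open("outputs/generations/safe_lm_IW-DPO.json") as f:
    generations = json.load(f)

print(type(generations).__name__, "with", len(generations), "entries")
if isinstance(generations, list) and generations and isinstance(generations[0], dict):
    print("fields of the first record:", sorted(generations[0]))
```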
Please set the OPENAI_API_KEY environment variable to your OpenAI API key before running the evaluation.
python -m train.evaluate --input_file outputs/generations/safe_lm_IW-DPO.json --task safe_lm --evaluator gpt-4o-mini
Our code is based on ContextualAI's HALOs.
