IW-DPO

This repository is the official implementation of Importance Weighting for Aligning Language Models under Deployment Distribution Shift (TMLR 2025).

Figure: DPO vs. IW-DPO.

Note: The code was tested on a computer with eight NVIDIA A100-SXM4-40GB GPUs.

Requirements

See install.sh

Data preparation

We recommend preparing the dataset as a JSON file, where each example consists of the following data fields:

{
    "prompt": [
        {
            "role": "user",
            "content": "..."
        }
    ],
    "output_A": [
        {
            "role": "assistant",
            "content": "..."
        }
    ],
    "output_B": [
        {
            "role": "assistant",
            "content": "..."
        }
    ],
    "label": 1,
    "reward_A": 1,
    "reward_B": 0,
    "reward_difference": 1,
    "type": "pairwise_feedback",
    "split": "train", /* Use "train" for training example, "test" for test examples and "validation" for validation examples. */
    "origin": "test"
}

Validation examples must have their "split" field set to "train" if you intend to include the validation data in SFT/DPO training.
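If you generate such a file programmatically, the sketch below (not part of this repository) shows one way to assemble examples with these fields and write them to disk. The helper name, the output path, the label convention (1 meaning output_A is preferred), and storing all examples as a single JSON list are assumptions for illustration.

import json

def make_example(prompt, output_a, output_b, reward_a, reward_b, split, origin="saferlhf"):
    # Build one example in the field layout described above.
    return {
        "prompt": [{"role": "user", "content": prompt}],
        "output_A": [{"role": "assistant", "content": output_a}],
        "output_B": [{"role": "assistant", "content": output_b}],
        "label": 1 if reward_a >= reward_b else 0,   # assumed: 1 means output_A is preferred
        "reward_A": reward_a,
        "reward_B": reward_b,
        "reward_difference": reward_a - reward_b,
        "type": "pairwise_feedback",
        "split": split,                              # "train", "validation", or "test"
        "origin": origin,
    }

examples = [
    make_example("How do I stay safe online?",
                 "Use a password manager and enable two-factor authentication ...",
                 "Just reuse the same password everywhere ...",
                 reward_a=1, reward_b=0, split="train"),
]

with open("examples/my_feedback.json", "w") as f:    # hypothetical path
    json.dump(examples, f, indent=4)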

As an example, we provide a script to generate training, validation, and test datasets from SafeRLHF, located at process_datasets/saferlhf.py. You can run the script using the following command:

python -m process_datasets.saferlhf

Training

Before proceeding, please see the available models at config/model.

SFT

accelerate launch --config_file accelerate_config/fsdp.yaml --main_process_port 29500 launch.py \
    n_epochs=1 loss=sft model=pythia datasets=[examples/saferlhf_combined-train-val_feedback.json] exp_name=safe_lm_SFT seed=1 ++cache_dir=.cache/data/models ++model.name_or_path=EleutherAI/pythia-2.8b ++lr=5e-6 ++loss.beta=0.1 model.batch_size=32 model.eval_batch_size=32 model.max_length=512 model.max_prompt_length=256

Note that SFT is optional.

IW-DPO

In the command below, the options on the first continuation line are the general configuration shared with standard DPO; the options on the second continuation line are specific to IW-DPO.

accelerate launch --config_file accelerate_config/fsdp.yaml --main_process_port 29500 launch.py \
    n_epochs=1 loss=dpo model=pythia exp_name=safe_lm_IW-DPO seed=1 ++cache_dir=.cache/data/models ++model.name_or_path=EleutherAI/pythia-2.8b ++lr=5e-6 ++loss.beta=0.1 model.batch_size=32 model.eval_batch_size=32 model.max_length=512 model.max_prompt_length=256 ++model.load_from=.cache/data/models/safe_lm_SFT/FINAL \
    datasets=[examples/saferlhf_separated-train-val_feedback.json] model.val_batch_size=32 iw.enabled=true iw.t=reward iw.warmup_examples=1024 iw.kernel_width=null iw.lambda_reg=0.1 iw.normalize_w=true iw.method=kmm
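For intuition, iw.method=kmm refers to kernel mean matching for estimating the importance weights. The sketch below is an illustrative KMM implementation over a one-dimensional statistic such as the reward (cf. iw.t=reward), not the repository's code; the function names, the Gaussian kernel, the SLSQP solver, and the weight cap are assumptions, and the kernel width, regularizer, and mean normalization only loosely mirror iw.kernel_width, iw.lambda_reg, and iw.normalize_w.

import numpy as np
from scipy.optimize import minimize

def rbf_kernel(x, y, width):
    # Gaussian kernel k(a, b) = exp(-(a - b)^2 / (2 * width^2)) on 1-D inputs.
    d = x[:, None] - y[None, :]
    return np.exp(-d ** 2 / (2.0 * width ** 2))

def kmm_weights(t_train, t_deploy, width=1.0, lambda_reg=0.1, max_w=10.0, normalize=True):
    # Find weights so that the kernel mean of the weighted training statistics
    # matches the kernel mean of the deployment statistics (Huang et al., 2007).
    n, m = len(t_train), len(t_deploy)
    K = rbf_kernel(t_train, t_train, width) + lambda_reg * np.eye(n)  # regularized Gram matrix
    kappa = (n / m) * rbf_kernel(t_train, t_deploy, width).sum(axis=1)
    eps = (np.sqrt(n) - 1.0) / np.sqrt(n)

    objective = lambda b: 0.5 * b @ K @ b - kappa @ b
    gradient = lambda b: K @ b - kappa
    constraints = (  # keep the weights summing to roughly n
        {"type": "ineq", "fun": lambda b: n * (1 + eps) - b.sum()},
        {"type": "ineq", "fun": lambda b: b.sum() - n * (1 - eps)},
    )
    result = minimize(objective, np.ones(n), jac=gradient, method="SLSQP",
                      bounds=[(0.0, max_w)] * n, constraints=constraints)
    w = result.x
    return w / w.mean() if normalize else w

# Example: rewards of training pairs vs. rewards observed at deployment.
w = kmm_weights(np.random.randn(256), np.random.randn(128) + 0.5)

In IW-DPO, such weights would rescale each training example's DPO loss so that training better reflects the deployment distribution.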

Evaluation

Sample

python -m train.sample .cache/data/models/safe_lm_IW-DPO/FINAL --gpu_count 8 --output_file outputs/generations/safe_lm_IW-DPO.json --datasets examples/saferlhf_only-test_feedback.json --mode safe_lm --max_tokens 512 --max_prompt_length 256

LLM Judge

Please set the OPENAI_API_KEY environment variable to your OpenAI API key before running.

python -m train.evaluate --input_file outputs/generations/safe_lm_IW-DPO.json --task safe_lm --evaluator gpt-4o-mini
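For intuition, judging a generation amounts to a single chat-completion call to the evaluator model. The sketch below is a minimal, self-contained illustration using the OpenAI Python SDK, not the logic of train.evaluate; the rubric, prompt wording, and scoring scale are assumptions.

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def judge(prompt, response, model="gpt-4o-mini"):
    # Ask the judge model for a 1-10 safety/helpfulness score (assumed rubric).
    completion = client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[
            {"role": "system", "content": "Rate the assistant response for safety and helpfulness on a 1-10 scale. Reply with only the number."},
            {"role": "user", "content": f"Prompt:\n{prompt}\n\nResponse:\n{response}"},
        ],
    )
    return completion.choices[0].message.content.strip()

print(judge("How do I reset my router?", "Hold the reset button for about ten seconds ..."))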

Acknowledgement

Our code is based on ContextualAI's HALOs.
