Skip to content

Quantized model behaves differently on multiple GPUs with device_map="auto" than on a single GPU #2220

@DarkenStar

Description

@DarkenStar

I quantized Llama-2-70B to INT8 using the vLLM example.

However, when I load the model with device_map="auto" on 2 GPUs, the attention hidden states output by the second half of the layers (the ones placed on cuda:1) differ from the single-GPU case. Here is my script, modified from Spec-Bench:

from typing import Optional, Callable
import torch
import argparse
from evaluation.eval import run_eval, reorg_answer_file
import pdb
from fastchat.utils import str_to_torch_dtype
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationMixin
from evaluation.decoding_tp import _assisted_decoding
from gptqmodel import GPTQModel, QuantizeConfig


def sps_forward(inputs, model, tokenizer, max_new_tokens, do_sample=False, temperature=0.0, drafter=None):
    """Generate a continuation with the target model, using ``drafter`` as assistant.

    ``max_new_tokens`` and ``output_hidden_states=True`` are written onto
    ``model.generation_config`` (so they persist for later calls). That config
    is then passed to ``model.generate`` together with the sampling options.
    ``tokenizer`` is not used here; presumably it is kept so the function
    matches the ``forward_func`` signature expected by ``run_eval`` (confirm).

    NOTE(review): ``model.generate`` is unpacked as a 3-tuple
    ``(output_ids, idx, accept_length_list)``. Presumably this shape comes from
    the patched ``_assisted_decoding``, not stock transformers; confirm.

    Returns:
        ``(output_ids, new_token, idx + 1, accept_length_list)``, where
        ``new_token`` is the number of ids in the first output sequence beyond
        the prompt length.
    """
    gen_cfg = model.generation_config
    gen_cfg.max_new_tokens = max_new_tokens
    gen_cfg.output_hidden_states = True

    prompt_ids = inputs.input_ids
    output_ids, idx, accept_length_list = model.generate(
        **inputs,
        generation_config=gen_cfg,
        assistant_model=drafter,
        do_sample=do_sample,
        temperature=temperature,
        return_dict_in_generate=True,
        output_hidden_states=True,
    )

    # Count only the ids appended after the prompt in the first sequence.
    prompt_len = len(prompt_ids[0])
    new_token = len(output_ids[0][prompt_len:])
    return output_ids, new_token, idx + 1, accept_length_list


if __name__ == "__main__":
    # Script entry point: parse CLI options, patch transformers' assisted
    # decoding, load the quantized target and the drafter, then run the
    # benchmark evaluation and reorganize the answer file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--drafter-path",
        type=str,
        required=True,
    )
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end",
        type=int,
        help="A debug option. The end index of questions."
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=1024,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    # The two GPU-count options below are only forwarded to run_eval; they do
    # not influence how the models are placed (both use device_map="auto").
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    # Sampling temperature for speculative decoding; 0.0 selects greedy
    # decoding (see do_sample below). The help text mentions medusa but the
    # value is passed straight to run_eval.
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for medusa sampling.",
    )
    # NOTE(review): --dtype is parsed but never used. The GPTQModel.load call
    # below receives no dtype argument, so the target model's dtype is
    # whatever the loader chooses. Only --drafter-dtype is actually applied.
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float32", "float64", "float16", "bfloat16"],
        help="Override the default dtype. If not set, it will use float16 on GPU.",
    )
    parser.add_argument(
        "--drafter-dtype",
        type=str,
        default="float16",
        choices=["float32", "float64", "float16", "bfloat16"],
        help="Override the default dtype. If not set, it will use float16 on GPU.",
    )
    args = parser.parse_args()
    
    # NOTE(review): _set_backend_determinism is not defined in this view; the
    # call is disabled, so no determinism flags are set for this run.
    #_set_backend_determinism()
    
    # Replace transformers' assisted-decoding loop with the project's version.
    # sps_forward unpacks generate() as (output_ids, idx, accept_length_list),
    # presumably the patched function's return shape. This must happen before
    # any generate() call.
    GenerationMixin._assisted_decoding = _assisted_decoding
    print("[INFO] Patched GenerationMixin._assisted_decoding -> evaluation.decoding_tp._assisted_decoding")

    question_file = f"data/{args.bench_name}/question.jsonl"
    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"

    print(f"Output to {answer_file}")
 
    # Quantized target model. With device_map="auto" its layers may be split
    # across all visible GPUs; the issue report observes that the single- vs
    # multi-GPU divergence starts at the first layer placed on cuda:1.
    # NOTE(review): to isolate a multi-GPU placement/kernel problem, compare
    # against a run with CUDA_VISIBLE_DEVICES restricted to a single GPU.
    # NOTE(review): confirm whether GPTQModel.load uses this QuantizeConfig or
    # the config saved with an already-quantized checkpoint.
    # The target is loaded before the drafter, so with "auto" placement the
    # drafter gets whatever GPU memory remains; keep this order when debugging.
    model = GPTQModel.load(
        model_id_or_path=args.model_path,
        quantize_config=QuantizeConfig(bits=8, ),
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    # Drafter (assistant) model, also placed with device_map="auto"; it may be
    # split across GPUs as well.
    drafter = AutoModelForCausalLM.from_pretrained(
        args.drafter_path,
        torch_dtype=str_to_torch_dtype(args.drafter_dtype),
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    
    # The tokenizer comes from the target model's path and is shared with the
    # drafter during evaluation.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model.eval()
    drafter.eval()

    # Greedy decoding when temperature is 0, sampling otherwise.
    do_sample = args.temperature > 0.0
        
    run_eval(
        model=model,
        tokenizer=tokenizer,
        forward_func=sps_forward,
        model_id=args.model_id,
        question_file=question_file,
        question_begin=args.question_begin,
        question_end=args.question_end,
        answer_file=answer_file,
        max_new_tokens=args.max_new_tokens,
        num_choices=args.num_choices,
        num_gpus_per_model=args.num_gpus_per_model,
        num_gpus_total=args.num_gpus_total,
        drafter=drafter,
        temperature=args.temperature,
        do_sample=do_sample,
    )

    reorg_answer_file(answer_file)

I modified the `_assisted_decoding` function to save the information from a single decoding step:

# Debug hook pasted from a patched copy of `_assisted_decoding`; the enclosing
# function is not shown here. If the SAVE_SPECULATIVE_DEBUG environment
# variable names a file, the current verification step is pickled to that
# file, and the process then exits.
# NOTE(review): relies on `os`, `pickle`, `Path` (presumably pathlib.Path),
# and on locals of the enclosing function (step, cur_len, candidate_length,
# outputs, new_logits, candidate_logits, candidate_input_ids, input_ids,
# model_inputs, drafter_device, debug_data, accept_length_list, self).
# Confirm they are all in scope at the insertion point.
SAVE_DEBUG_FILE = os.environ.get("SAVE_SPECULATIVE_DEBUG", None)
# Inserted after step "2.3." of the enclosing decoding loop, presumably right
# after the target model has scored the candidate tokens. Confirm against the
# patched source.
if SAVE_DEBUG_FILE :
            # Keep the last candidate_length + 1 positions of each tensor in
            # outputs.hidden_states (presumably one tensor per layer). .cpu()
            # copies them off whichever GPU they live on, so single- and
            # multi-GPU dumps can be compared offline.
            saved_hidden_states = [
                        hs[:, -candidate_length - 1 :].cpu().clone() for hs in outputs.hidden_states
                    ]
            # Snapshot of this verification step's inputs, logits and devices.
            step_data = {
                "step": step,
                "cur_len": cur_len,
                "candidate_length": candidate_length,
                "input_ids": input_ids.cpu().clone(),
                "candidate_input_ids": candidate_input_ids.cpu().clone(),
                "candidate_logits": candidate_logits.cpu().clone() if candidate_logits is not None else None,
                "target_logits": new_logits.cpu().clone(),
                "attention_mask": model_inputs.get("attention_mask").cpu().clone() if "attention_mask" in model_inputs else None,
                "position_ids": model_inputs.get("position_ids").cpu().clone() if "position_ids" in model_inputs else None,
                "drafter_device": str(drafter_device),
                "target_device": str(self.device),
                "hidden_states": saved_hidden_states,
            }
            debug_data["steps"].append(step_data)
            # The process exits below before generation finishes, so no final
            # output ids are recorded.
            debug_data["final_output_ids"] = None  
            debug_data["accept_length_list"] = accept_length_list

            # Create the target directory if needed and overwrite the file
            # with everything collected in debug_data so far.
            Path(SAVE_DEBUG_FILE).parent.mkdir(parents=True, exist_ok=True)
            with open(SAVE_DEBUG_FILE, "wb") as f:
                pickle.dump(debug_data, f)
            # Quick console check: the drafter's proposed tokens vs. the
            # target's argmax at each verified position (the last logit row
            # is excluded).
            print(f"    drafter token: {candidate_input_ids[0, cur_len:].tolist()[:candidate_length]}")
            print(f"    target  argmax    : {new_logits[0, :-1].argmax(-1).tolist()}")
            # Terminate the whole process after the first dump, so one run
            # captures exactly one step.
            import sys; sys.exit(0)  

and compared the single-GPU and multi-GPU results using the following code:

# compare_spec_debug.py
# compare_final_all.py
import pickle
import torch
import torch.nn.functional as F
import argparse

def load_first_step(p):
    """Load one speculative-decoding debug dump and return its step-1 record.

    Args:
        p: Path to a pickle produced by the SAVE_SPECULATIVE_DEBUG hook. It must
            hold a dict whose ``"steps"`` entry is a list of per-step dicts,
            each with a ``"step"`` key.

    Returns:
        The first entry of ``data["steps"]`` whose ``"step"`` equals 1.

    Raises:
        ValueError: If the dump has no record for step 1.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only open dumps you produced yourself.
    with open(p, "rb") as f:
        data = pickle.load(f)
    record = next((s for s in data["steps"] if s["step"] == 1), None)
    if record is None:
        raise ValueError(f"{p}: no entry with step == 1 in data['steps']")
    return record


def _abs_diff(a, b):
    """Return ``|a - b|`` computed in float32.

    The dumps may hold half-precision tensors (dtype not visible here, TODO
    confirm). Upcasting first keeps the subtraction and the later mean/std
    reductions from losing precision.
    """
    return (a.float() - b.float()).abs()


def _print_diff_stats(a, b, with_std=False):
    """Print max/mean (and optionally std) of ``|a - b|``; report shape mismatches instead of crashing."""
    if a.shape != b.shape:
        print(f"   shape 不一致: single={tuple(a.shape)}, multi={tuple(b.shape)}")
        return
    diff = _abs_diff(a, b)
    print(f"   最大绝对误差 : {diff.max().item():12.6f}")
    print(f"   平均绝对误差 : {diff.mean().item():12.6f}")
    if with_std:
        print(f"   标准差       : {diff.std().item():12.6f}")


def _print_drafter_quality(label, cand, tgt, drafter_tokens):
    """Print one row of the drafter-quality table for a single run.

    ``cand`` may be None (no candidate logits were dumped). In that case only
    the acceptance count is printed, i.e. the number of positions where the
    target's argmax equals the drafter's proposed token.
    """
    accept = (tgt.argmax(-1) == drafter_tokens).sum().item()
    total = len(drafter_tokens)
    if cand is None:
        print(f"{label:20} {'(无cand)':>12} {'(无cand)':>12} {accept:5}/{total}")
        return
    err = _abs_diff(cand, tgt)
    print(f"{label:20} {err.max().item():12.6f} {err.mean().item():12.6f} {accept:5}/{total}")


def _report_hidden_states(s_hs, m_hs):
    """Print per-layer ``|single - multi|`` statistics of the saved hidden states.

    Returns:
        Index of the first layer whose values differ, counting NaN as a
        difference. Returns None when every shape-compatible layer matches
        exactly, or when no comparison was possible (missing data or
        mismatched layer counts).
    """
    if s_hs is None or m_hs is None:
        print("   单卡或多卡缺少 hidden_states 字段,跳过。")
        return None
    if len(s_hs) != len(m_hs):
        print(f"   层数不一致: single={len(s_hs)}, multi={len(m_hs)}")
        return None

    print(f"{'层':>4} {'shape一致':>10} {'max_abs':>12} {'mean_abs':>12}")
    first_diff_layer = None
    for li, (sh, mh) in enumerate(zip(s_hs, m_hs)):
        if tuple(sh.shape) != tuple(mh.shape):
            print(f"{li:4d} {'False':>10} {'-':>12} {'-':>12}")
            continue
        diff = _abs_diff(sh, mh)
        max_abs = diff.max().item()
        mean_abs = diff.mean().item()
        # `!=` (rather than `>`) also treats a NaN difference as a mismatch.
        if first_diff_layer is None and max_abs != 0.0:
            first_diff_layer = li
        print(f"{li:4d} {'True':>10} {max_abs:12.6f} {mean_abs:12.6f}")
    return first_diff_layer


def main():
    """Compare the first verification step of a single-GPU and a multi-GPU dump.

    Reads ``--single`` and ``--multi`` pickle paths and prints five sections:
    target-logit differences, candidate-logit differences, drafter quality per
    run, per-position argmax, and per-layer hidden-state differences,
    including the first layer that diverges.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--single", type=str, required=True)
    parser.add_argument("--multi", type=str, required=True)
    args = parser.parse_args()

    single = load_first_step(args.single)
    multi = load_first_step(args.multi)

    # Verification positions only: drop the last position, which is used to
    # sample the next token.
    s_tgt = single["target_logits"][:, :-1].squeeze(0)      # [L, V]
    m_tgt = multi["target_logits"][:, :-1].squeeze(0)

    s_cand = single["candidate_logits"].squeeze(0) if single["candidate_logits"] is not None else None
    m_cand = multi["candidate_logits"].squeeze(0) if multi["candidate_logits"] is not None else None

    # Proposed tokens are taken from the single-GPU run; whether both runs
    # proposed the same candidates is printed just below.
    cur_len = single["cur_len"]
    drafter_tokens = single["candidate_input_ids"][0, cur_len: cur_len + single["candidate_length"]]

    print("=" * 80)
    print("完整第一步对比报告".center(80))
    print("=" * 80)

    print(f"candidate_length : {single['candidate_length']}")
    print(f"drafter 猜测 token : {drafter_tokens.tolist()}")
    print(f"input_ids 一致    : {torch.equal(single['input_ids'], multi['input_ids'])}")
    print(f"candidate_input_ids 一致 : {torch.equal(single['candidate_input_ids'], multi['candidate_input_ids'])}")
    print()

    # 1. target_logits: single GPU vs multi GPU
    print("1. target_logits 单卡 ↔ 多卡 差异".center(70))
    _print_diff_stats(s_tgt, m_tgt)
    print()

    # 2. candidate_logits: single GPU vs multi GPU (skipped if either is missing)
    print("2. candidate_logits 单卡 ↔ 多卡 差异".center(70))
    if s_cand is not None and m_cand is not None:
        _print_diff_stats(s_cand, m_cand, with_std=True)
    print()

    # 3. Drafter quality: candidate_logits vs target_logits within each run
    print("3. drafter 质量(candidate_logits → target_logits)".center(70))
    print(f"{'':20} {'最大误差':>12} {'平均误差':>12} {'接受数':>8}")
    _print_drafter_quality("单卡", s_cand, s_tgt, drafter_tokens)
    _print_drafter_quality("多卡", m_cand, m_tgt, drafter_tokens)
    print()

    # 4. Per-position argmax of each run vs the drafter's proposal
    print("位置级 argmax 对比".center(70))
    s_top1 = s_tgt.argmax(-1)
    m_top1 = m_tgt.argmax(-1)
    for i, token in enumerate(drafter_tokens.tolist()):
        s_tok = s_top1[i].item()
        m_tok = m_top1[i].item()
        print(f"Pos {i}: drafter {token:5d} | single {s_tok:5d} {'✓' if s_tok == token else '✗'} | "
              f"multi  {m_tok:5d} {'✓' if m_tok == token else '✗'}")

    # 5. Per-layer hidden-state differences. With device_map="auto", compare
    # the reported index with the first layer placed on the second GPU.
    print("5. hidden_states 分层差异".center(70))
    first_diff_layer = _report_hidden_states(single.get("hidden_states", None), multi.get("hidden_states", None))
    if first_diff_layer is not None:
        print(f"   首个出现差异的层: {first_diff_layer}")


if __name__ == "__main__":
    main()

The results show that the Llama-2-7B drafter's logits are identical in both cases. However, the quantized Llama-2-70B target's hidden states differ dramatically from layer 41 onward, i.e. on the second GPU:

candidate_length : 1
drafter 猜测 token : [26901]
input_ids 一致    : True
candidate_input_ids 一致 : True

                     1. target_logits 单卡 ↔ 多卡 差异                      
   最大绝对误差 :    18.126953
   平均绝对误差 :     2.402046

                    2. candidate_logits 单卡 ↔ 多卡 差异                    
   最大绝对误差 :     0.000000
   平均绝对误差 :     0.000000
   标准差       :     0.000000

           3. drafter 质量(candidate_logits → target_logits)            
                             最大误差         平均误差      接受数
单卡                       9.171875     1.171219     0/1
多卡                      17.580078     2.810862     0/1

                            位置级 argmax 对比                             
Pos 0: drafter 26901 | single 18585 ✗ | multi  19259 ✗
                        5. hidden_states 分层差异                         
   层    shape一致      max_abs     mean_abs
   0       True     0.000000     0.000000
   1       True     0.000000     0.000000
   2       True     0.000000     0.000000
   3       True     0.000000     0.000000
   4       True     0.000000     0.000000
   5       True     0.000000     0.000000
   6       True     0.000000     0.000000
   7       True     0.000000     0.000000
   8       True     0.000000     0.000000
   9       True     0.000000     0.000000
  10       True     0.000000     0.000000
  11       True     0.000000     0.000000
  12       True     0.000000     0.000000
  13       True     0.000000     0.000000
  14       True     0.000000     0.000000
  15       True     0.000000     0.000000
  16       True     0.000000     0.000000
  17       True     0.000000     0.000000
  18       True     0.000000     0.000000
  19       True     0.000000     0.000000
  20       True     0.000000     0.000000
  21       True     0.000000     0.000000
  22       True     0.000000     0.000000
  23       True     0.000000     0.000000
  24       True     0.000000     0.000000
  25       True     0.000000     0.000000
  26       True     0.000000     0.000000
  27       True     0.000000     0.000000
  28       True     0.000000     0.000000
  29       True     0.000000     0.000000
  30       True     0.000000     0.000000
  31       True     0.000000     0.000000
  32       True     0.000000     0.000000
  33       True     0.000000     0.000000
  34       True     0.000000     0.000000
  35       True     0.000000     0.000000
  36       True     0.000000     0.000000
  37       True     0.000000     0.000000
  38       True     0.000000     0.000000
  39       True     0.000000     0.000000
  40       True     0.000000     0.000000
  41       True     1.765625     0.326904
  42       True     2.902344     0.520020
  43       True     5.023438     0.672363
  44       True     7.066406     0.794434
  45       True     5.703125     0.947754
  46       True     8.804688     1.071289
  47       True    11.265625     1.173828
  48       True    15.304688     1.296875
  49       True    16.218750     1.413086
  50       True    20.765625     1.499023
  51       True    25.500000     1.581055
  52       True    28.000000     1.673828
  53       True    33.781250     1.760742
  54       True    38.500000     1.835938
  55       True    40.375000     1.911133
  56       True    43.531250     1.989258
  57       True    54.968750     2.070312
  58       True    55.031250     2.150391
  59       True    60.500000     2.222656
  60       True    65.062500     2.294922
  61       True    67.750000     2.388672
  62       True    69.562500     2.455078
  63       True    75.125000     2.533203
  64       True    78.750000     2.607422
  65       True    84.062500     2.699219
  66       True    85.375000     2.775391
  67       True    90.625000     2.857422
  68       True    96.937500     2.931641
  69       True   102.562500     3.019531
  70       True   101.187500     3.117188
  71       True   107.500000     3.218750
  72       True   105.187500     3.328125
  73       True   109.562500     3.462891
  74       True   112.937500     3.609375
  75       True   118.437500     3.736328
  76       True   109.250000     3.857422
  77       True   123.187500     4.031250
  78       True   141.875000     4.218750
  79       True   133.000000     4.457031
  80       True    39.062500     1.160156

What causes this behavior, and how can I fix it? Any help is much appreciated!

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions