I quantized Llama-2-70B to INT8 using the vLLM example.
But I found that if I load the model with device_map="auto" on 2 GPUs, the attention hidden states output by the second half of the layers (on cuda:1) differ from the single-GPU case. Here is my script, modified from Spec-Bench:
from typing import Optional, Callable
import torch
import argparse
from evaluation.eval import run_eval, reorg_answer_file
import pdb
from fastchat.utils import str_to_torch_dtype
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationMixin
from evaluation.decoding_tp import _assisted_decoding
from gptqmodel import GPTQModel, QuantizeConfig
def sps_forward(inputs, model, tokenizer, max_new_tokens, do_sample=False, temperature=0.0, drafter=None):
    input_ids = inputs.input_ids
    model.generation_config.max_new_tokens = max_new_tokens
    model.generation_config.output_hidden_states = True
    output_ids, idx, accept_length_list = model.generate(
        **inputs,
        generation_config=model.generation_config,
        assistant_model=drafter,
        do_sample=do_sample,
        temperature=temperature,
        return_dict_in_generate=True,
        output_hidden_states=True,
    )
    new_token = len(output_ids[0][len(input_ids[0]):])
    return output_ids, new_token, idx + 1, accept_length_list
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--drafter-path",
        type=str,
        required=True,
    )
    parser.add_argument("--model-id", type=str, required=True)
    parser.add_argument(
        "--bench-name",
        type=str,
        default="mt_bench",
        help="The name of the benchmark question set.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end",
        type=int,
        help="A debug option. The end index of questions.",
    )
    parser.add_argument("--answer-file", type=str, help="The output answer file.")
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=1024,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--num-gpus-per-model",
        type=int,
        default=1,
        help="The number of GPUs per model.",
    )
    parser.add_argument(
        "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="The temperature for medusa sampling.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float32", "float64", "float16", "bfloat16"],
        help="Override the default dtype. If not set, it will use float16 on GPU.",
    )
    parser.add_argument(
        "--drafter-dtype",
        type=str,
        default="float16",
        choices=["float32", "float64", "float16", "bfloat16"],
        help="Override the default dtype. If not set, it will use float16 on GPU.",
    )
    args = parser.parse_args()

    # _set_backend_determinism()
    GenerationMixin._assisted_decoding = _assisted_decoding
    print("[INFO] Patched GenerationMixin._assisted_decoding -> evaluation.decoding_tp._assisted_decoding")

    question_file = f"data/{args.bench_name}/question.jsonl"
    if args.answer_file:
        answer_file = args.answer_file
    else:
        answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"
    print(f"Output to {answer_file}")

    model = GPTQModel.load(
        model_id_or_path=args.model_path,
        quantize_config=QuantizeConfig(bits=8),
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    drafter = AutoModelForCausalLM.from_pretrained(
        args.drafter_path,
        torch_dtype=str_to_torch_dtype(args.drafter_dtype),
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model.eval()
    drafter.eval()

    do_sample = args.temperature > 0.0
    run_eval(
        model=model,
        tokenizer=tokenizer,
        forward_func=sps_forward,
        model_id=args.model_id,
        question_file=question_file,
        question_begin=args.question_begin,
        question_end=args.question_end,
        answer_file=answer_file,
        max_new_tokens=args.max_new_tokens,
        num_choices=args.num_choices,
        num_gpus_per_model=args.num_gpus_per_model,
        num_gpus_total=args.num_gpus_total,
        drafter=drafter,
        temperature=args.temperature,
        do_sample=do_sample,
    )
    reorg_answer_file(answer_file)
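Before patching the decoding loop, it is worth confirming how device_map="auto" actually split the model. A minimal sketch for dumping the dispatch map (it assumes the object returned by GPTQModel.load exposes the underlying transformers model, the .model attribute here is an assumption, and that accelerate populated hf_device_map during dispatch):

# Sketch: print which modules/layers landed on which device.
# Assumes accelerate set `hf_device_map`; falling back through a
# `.model` attribute on the GPTQModel wrapper is an assumption.
hf_model = getattr(model, "model", model)
device_map = getattr(hf_model, "hf_device_map", None)
if device_map is None:
    print("no hf_device_map; model does not appear to be dispatched")
else:
    for name, dev in sorted(device_map.items()):
        print(f"{name:45s} -> {dev}")

This is how I confirmed that the second half of the decoder layers sits on cuda:1.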
I modified the _assisted_decoding function (the patched version in evaluation/decoding_tp.py) to save one step's information:

# fragment inside the patched _assisted_decoding; it also needs
# `import os, pickle` and `from pathlib import Path` at module top
SAVE_DEBUG_FILE = os.environ.get("SAVE_SPECULATIVE_DEBUG", None)
# after 2.3.
if SAVE_DEBUG_FILE:
    saved_hidden_states = [
        hs[:, -candidate_length - 1:].cpu().clone() for hs in outputs.hidden_states
    ]
    step_data = {
        "step": step,
        "cur_len": cur_len,
        "candidate_length": candidate_length,
        "input_ids": input_ids.cpu().clone(),
        "candidate_input_ids": candidate_input_ids.cpu().clone(),
        "candidate_logits": candidate_logits.cpu().clone() if candidate_logits is not None else None,
        "target_logits": new_logits.cpu().clone(),
        "attention_mask": model_inputs.get("attention_mask").cpu().clone() if "attention_mask" in model_inputs else None,
        "position_ids": model_inputs.get("position_ids").cpu().clone() if "position_ids" in model_inputs else None,
        "drafter_device": str(drafter_device),
        "target_device": str(self.device),
        "hidden_states": saved_hidden_states,
    }
    debug_data["steps"].append(step_data)
    debug_data["final_output_ids"] = None
    debug_data["accept_length_list"] = accept_length_list
    Path(SAVE_DEBUG_FILE).parent.mkdir(parents=True, exist_ok=True)
    with open(SAVE_DEBUG_FILE, "wb") as f:
        pickle.dump(debug_data, f)
    print(f" drafter token: {candidate_input_ids[0, cur_len:].tolist()[:candidate_length]}")
    print(f" target argmax : {new_logits[0, :-1].argmax(-1).tolist()}")
    import sys; sys.exit(0)

I then compared the results between the single-GPU and multi-GPU runs using the following script:
# compare_spec_debug.py
# compare_final_all.py
import pickle
import torch
import torch.nn.functional as F
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--single", type=str, required=True)
parser.add_argument("--multi", type=str, required=True)
args = parser.parse_args()
def load_first_step(p):
    with open(p, "rb") as f:
        data = pickle.load(f)
    return [s for s in data["steps"] if s["step"] == 1][0]
single = load_first_step(args.single)
multi = load_first_step(args.multi)
# Extract the verified positions (drop the last one, which is used for sampling)
s_tgt = single["target_logits"][:, :-1].squeeze(0)  # [L, V]
m_tgt = multi["target_logits"][:, :-1].squeeze(0)
s_cand = single["candidate_logits"].squeeze(0) if single["candidate_logits"] is not None else None
m_cand = multi["candidate_logits"].squeeze(0) if multi["candidate_logits"] is not None else None
drafter_tokens = single["candidate_input_ids"][0, single["cur_len"]: single["cur_len"] + single["candidate_length"]]

print("=" * 80)
print("Full first-step comparison report".center(80))
print("=" * 80)
print(f"candidate_length : {single['candidate_length']}")
print(f"drafter guessed tokens : {drafter_tokens.tolist()}")
print(f"input_ids identical : {torch.equal(single['input_ids'], multi['input_ids'])}")
print(f"candidate_input_ids identical : {torch.equal(single['candidate_input_ids'], multi['candidate_input_ids'])}")
print()

# 1. target_logits: single GPU vs multi GPU
print("1. target_logits single-GPU vs multi-GPU difference".center(70))
diff_tgt = torch.abs(s_tgt - m_tgt)
print(f" max abs error : {diff_tgt.max().item():12.6f}")
print(f" mean abs error : {diff_tgt.mean().item():12.6f}")
print()

# 2. candidate_logits: single GPU vs multi GPU (the comparison specifically requested)
print("2. candidate_logits single-GPU vs multi-GPU difference".center(70))
if s_cand is not None and m_cand is not None:
    diff_cand = torch.abs(s_cand - m_cand)
    print(f" max abs error : {diff_cand.max().item():12.6f}")
    print(f" mean abs error : {diff_cand.mean().item():12.6f}")
    print(f" std dev : {diff_cand.std().item():12.6f}")
print()

# 3. drafter quality (candidate_logits -> target_logits)
print("3. drafter quality (candidate_logits -> target_logits)".center(70))
print(f"{'':20} {'max err':>12} {'mean err':>12} {'accepted':>8}")
if s_cand is not None:
    err_s = torch.abs(s_cand - s_tgt)
    accept_s = (s_tgt.argmax(-1) == drafter_tokens).sum().item()
    print(f"{'single':20} {err_s.max().item():12.6f} {err_s.mean().item():12.6f} {accept_s:5}/{len(drafter_tokens)}")
else:
    accept_s = (s_tgt.argmax(-1) == drafter_tokens).sum().item()
    print(f"{'single':20} {'(no cand)':>12} {'(no cand)':>12} {accept_s:5}/{len(drafter_tokens)}")
if m_cand is not None:
    err_m = torch.abs(m_cand - m_tgt)
    accept_m = (m_tgt.argmax(-1) == drafter_tokens).sum().item()
    print(f"{'multi':20} {err_m.max().item():12.6f} {err_m.mean().item():12.6f} {accept_m:5}/{len(drafter_tokens)}")
else:
    accept_m = (m_tgt.argmax(-1) == drafter_tokens).sum().item()
    print(f"{'multi':20} {'(no cand)':>12} {'(no cand)':>12} {accept_m:5}/{len(drafter_tokens)}")
print()

# 4. per-position argmax
print("Per-position argmax comparison".center(70))
s_top1 = s_tgt.argmax(-1)
m_top1 = m_tgt.argmax(-1)
for i, token in enumerate(drafter_tokens.tolist()):
    print(f"Pos {i}: drafter {token:5d} | single {s_top1[i].item():5d} {'✓' if s_top1[i].item() == token else '✗'} | "
          f"multi {m_top1[i].item():5d} {'✓' if m_top1[i].item() == token else '✗'}")

# 5. per-layer hidden_states differences
print("5. hidden_states per-layer difference".center(70))
s_hs = single.get("hidden_states", None)
m_hs = multi.get("hidden_states", None)
if s_hs is None or m_hs is None:
    print(" hidden_states missing from the single- or multi-GPU dump; skipping.")
else:
    if len(s_hs) != len(m_hs):
        print(f" layer count mismatch: single={len(s_hs)}, multi={len(m_hs)}")
    else:
        print(f"{'layer':>6} {'shape_match':>12} {'max_abs':>12} {'mean_abs':>12}")
        first_diff_layer = None
        for li, (sh, mh) in enumerate(zip(s_hs, m_hs)):
            same_shape = tuple(sh.shape) == tuple(mh.shape)
            if not same_shape:
                print(f"{li:6d} {'False':>12} {'-':>12} {'-':>12}")
                continue
            diff = (sh - mh).abs()
            max_abs = diff.max().item()
            mean_abs = diff.mean().item()
            # record the first layer that differs
            if first_diff_layer is None and (max_abs > 0.0 or mean_abs > 0.0):
                first_diff_layer = li
            print(f"{li:6d} {str(same_shape):>12} {max_abs:12.6f} {mean_abs:12.6f}")
The results show that the Llama-2-7B drafter's logits are identical in both cases, but the quantized Llama-2-70B target's hidden states diverge dramatically starting at the layers on the second GPU:
candidate_length : 1
drafter guessed tokens : [26901]
input_ids identical : True
candidate_input_ids identical : True
1. target_logits single-GPU vs multi-GPU difference
max abs error : 18.126953
mean abs error : 2.402046
2. candidate_logits single-GPU vs multi-GPU difference
max abs error : 0.000000
mean abs error : 0.000000
std dev : 0.000000
3. drafter quality (candidate_logits -> target_logits)
max err mean err accepted
single 9.171875 1.171219 0/1
multi 17.580078 2.810862 0/1
Per-position argmax comparison
Pos 0: drafter 26901 | single 18585 ✗ | multi 19259 ✗
5. hidden_states per-layer difference
layer shape_match max_abs mean_abs
0 True 0.000000 0.000000
1 True 0.000000 0.000000
2 True 0.000000 0.000000
3 True 0.000000 0.000000
4 True 0.000000 0.000000
5 True 0.000000 0.000000
6 True 0.000000 0.000000
7 True 0.000000 0.000000
8 True 0.000000 0.000000
9 True 0.000000 0.000000
10 True 0.000000 0.000000
11 True 0.000000 0.000000
12 True 0.000000 0.000000
13 True 0.000000 0.000000
14 True 0.000000 0.000000
15 True 0.000000 0.000000
16 True 0.000000 0.000000
17 True 0.000000 0.000000
18 True 0.000000 0.000000
19 True 0.000000 0.000000
20 True 0.000000 0.000000
21 True 0.000000 0.000000
22 True 0.000000 0.000000
23 True 0.000000 0.000000
24 True 0.000000 0.000000
25 True 0.000000 0.000000
26 True 0.000000 0.000000
27 True 0.000000 0.000000
28 True 0.000000 0.000000
29 True 0.000000 0.000000
30 True 0.000000 0.000000
31 True 0.000000 0.000000
32 True 0.000000 0.000000
33 True 0.000000 0.000000
34 True 0.000000 0.000000
35 True 0.000000 0.000000
36 True 0.000000 0.000000
37 True 0.000000 0.000000
38 True 0.000000 0.000000
39 True 0.000000 0.000000
40 True 0.000000 0.000000
41 True 1.765625 0.326904
42 True 2.902344 0.520020
43 True 5.023438 0.672363
44 True 7.066406 0.794434
45 True 5.703125 0.947754
46 True 8.804688 1.071289
47 True 11.265625 1.173828
48 True 15.304688 1.296875
49 True 16.218750 1.413086
50 True 20.765625 1.499023
51 True 25.500000 1.581055
52 True 28.000000 1.673828
53 True 33.781250 1.760742
54 True 38.500000 1.835938
55 True 40.375000 1.911133
56 True 43.531250 1.989258
57 True 54.968750 2.070312
58 True 55.031250 2.150391
59 True 60.500000 2.222656
60 True 65.062500 2.294922
61 True 67.750000 2.388672
62 True 69.562500 2.455078
63 True 75.125000 2.533203
64 True 78.750000 2.607422
65 True 84.062500 2.699219
66 True 85.375000 2.775391
67 True 90.625000 2.857422
68 True 96.937500 2.931641
69 True 102.562500 3.019531
70 True 101.187500 3.117188
71 True 107.500000 3.218750
72 True 105.187500 3.328125
73 True 109.562500 3.462891
74 True 112.937500 3.609375
75 True 118.437500 3.736328
76 True 109.250000 3.857422
77 True 123.187500 4.031250
78 True 141.875000 4.218750
79 True 133.000000 4.457031
80 True 39.062500 1.160156
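Note that the break at layer 41 of 80 lines up with the point where the model crosses over to the second GPU. To take the assisted-decoding patch out of the equation entirely, the same per-layer diff can be produced from one plain forward pass through the quantized 70B, run once per GPU configuration. A sketch reusing the GPTQModel.load call from above (the prompt and path are placeholders, and it assumes the wrapper forwards a direct call with output_hidden_states to the underlying transformers model):

# Sketch: dump per-layer hidden states from a single plain forward
# pass, so the 1-GPU and 2-GPU runs can be diffed layer by layer.
# Assumes the GPTQModel wrapper forwards the call and the
# output_hidden_states kwarg to the underlying transformers model.
import pickle
import torch
from transformers import AutoTokenizer
from gptqmodel import GPTQModel, QuantizeConfig

path = "path/to/llama2-70b-int8"  # placeholder
model = GPTQModel.load(
    model_id_or_path=path,
    quantize_config=QuantizeConfig(bits=8),
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(path)
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to("cuda:0")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)
hidden = [h.float().cpu() for h in out.hidden_states]
with open("hs_plain.pkl", "wb") as f:  # use a different name per run
    pickle.dump(hidden, f)

If the divergence at layer 41 shows up here as well, the problem is in the dispatched quantized forward itself rather than in the speculative-decoding path.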
What causes this, and how can I fix it?