10 changes: 9 additions & 1 deletion toolbench/train/train.py
@@ -273,10 +273,18 @@ def train():
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
     device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+    bnb_config = transformers.BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
-        device_map=device_map
+        device_map=device_map,
+        use_flash_attention_2=True,
+        quantization_config=bnb_config,
     )
     model.config.use_cache = False
     trainer = Trainer(
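For context, here is a self-contained sketch of the loading pattern this hunk introduces: NF4 4-bit quantization via BitsAndBytesConfig combined with FlashAttention-2. The checkpoint name and device_map="auto" are placeholders for illustration, not values from this PR, and newer transformers releases deprecate use_flash_attention_2=True in favor of attn_implementation="flash_attention_2".

# Minimal sketch of 4-bit (NF4) loading with FlashAttention-2, mirroring the diff above.
# "some-org/llama-base" is a placeholder checkpoint, not part of this PR.
import torch
import transformers

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4 bits at load time
    bnb_4bit_use_double_quant=True,         # second quantization pass on the quantization constants
    bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
    bnb_4bit_compute_dtype=torch.bfloat16,  # matmuls run in bf16
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "some-org/llama-base",          # placeholder; the scripts use model_args.model_name_or_path
    quantization_config=bnb_config,
    device_map="auto",
    use_flash_attention_2=True,     # flag used in this diff; newer versions: attn_implementation="flash_attention_2"
)
model.config.use_cache = False      # KV cache is disabled during training, as in train.py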
19 changes: 14 additions & 5 deletions toolbench/train/train_lora.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch
 from dataclasses import dataclass, field
 import logging
 import pathlib
@@ -33,11 +34,11 @@
     make_supervised_data_module,
 )
 
-from toolbench.train.llama_flash_attn_monkey_patch import (
-    replace_llama_attn_with_flash_attn,
-)
+# from toolbench.train.llama_flash_attn_monkey_patch import (
+#     replace_llama_attn_with_flash_attn,
+# )
 from toolbench.train.llama_condense_monkey_patch import replace_llama_with_condense
-replace_llama_attn_with_flash_attn()
+# replace_llama_attn_with_flash_attn()
 
 
 @dataclass
@@ -107,10 +108,18 @@ def train():
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
     device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
+    bnb_config = transformers.BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16
+    )
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
-        device_map=device_map
+        device_map=device_map,
+        use_flash_attention_2=True,
+        quantization_config=bnb_config,
     )
     lora_config = LoraConfig(
         r=lora_args.lora_r,
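The LoRA script passes the 4-bit quantized model straight into its existing LoraConfig / PEFT setup, most of which is not visible in this hunk. A self-contained QLoRA-style sketch of how such a setup typically looks: the checkpoint name, LoRA hyperparameters, target modules, and the prepare_model_for_kbit_training call are illustrative assumptions, not values taken from this diff.

# Hypothetical QLoRA-style setup; hyperparameters and target modules are
# illustrative defaults, not values from this PR.
import torch
import transformers
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "some-org/llama-base",           # placeholder checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)

# Commonly recommended when fine-tuning on top of a k-bit base model:
# casts norms/embeddings to higher precision and enables input gradients.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,                            # illustrative rank
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()   # only the LoRA adapters remain trainable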
10 changes: 5 additions & 5 deletions toolbench/train/train_mem.py
@@ -1,11 +1,11 @@
 # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
 
-# Need to call this before importing transformers.
-from toolbench.train.llama_flash_attn_monkey_patch import (
-    replace_llama_attn_with_flash_attn,
-)
+# # Need to call this before importing transformers.
+# from toolbench.train.llama_flash_attn_monkey_patch import (
+#     replace_llama_attn_with_flash_attn,
+# )
 
-replace_llama_attn_with_flash_attn()
+# replace_llama_attn_with_flash_attn()
 
 from toolbench.train.train import train
 
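The monkey patch is commented out because FlashAttention is now supplied by the use_flash_attention_2=True flag passed to from_pretrained in train.py, which requires the flash-attn package (and bitsandbytes for the 4-bit path) to be installed. A small environment check that can be run before launching training; the package list and the bf16 check are my additions, not part of this PR.

# Sanity check for the dependencies the new code path relies on.
import importlib.util
import torch

for pkg in ("flash_attn", "bitsandbytes", "accelerate"):
    found = importlib.util.find_spec(pkg) is not None
    print(f"{pkg}: {'ok' if found else 'MISSING'}")

# bnb_4bit_compute_dtype=torch.bfloat16 assumes a GPU with bf16 support.
print("bf16 supported:", torch.cuda.is_available() and torch.cuda.is_bf16_supported())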