Skip to content
9 changes: 1 addition & 8 deletions examples/scripts/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,7 @@ def main(script_args, training_args, model_args):
from datasets import load_dataset
from transformers import GenerationConfig

from trl import (
LogCompletionsCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl import LogCompletionsCallback, get_peft_config, get_quantization_config
from trl.experimental.distillation import DistillationTrainer

################
Expand All @@ -99,7 +94,6 @@ def main(script_args, training_args, model_args):
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
Comment thread
cursor[bot] marked this conversation as resolved.
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
Expand All @@ -109,7 +103,6 @@ def main(script_args, training_args, model_args):
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
use_cache=True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
if training_args.teacher_model_init_kwargs is not None:
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/dpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -89,7 +88,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForImageTextToText.from_pretrained(
Expand Down
3 changes: 0 additions & 3 deletions examples/scripts/gkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -80,7 +79,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

training_args.model_init_kwargs = model_kwargs
Expand All @@ -93,7 +91,6 @@
)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

training_args.teacher_model_init_kwargs = teacher_model_kwargs
Expand Down
3 changes: 0 additions & 3 deletions examples/scripts/gold.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -84,7 +83,6 @@
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
Expand All @@ -96,7 +94,6 @@
attn_implementation=model_args.attn_implementation,
dtype=model_args.dtype,
use_cache=True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
if training_args.teacher_model_init_kwargs is not None:
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/grpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -93,7 +92,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

################
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/gspo.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -82,7 +81,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

################
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/gspo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -82,7 +81,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

################
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/mpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -73,7 +72,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForImageTextToText.from_pretrained(
Expand Down
10 changes: 1 addition & 9 deletions examples/scripts/nash_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig

from trl import (
LogCompletionsCallback,
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_quantization_config,
)
from trl import LogCompletionsCallback, ModelConfig, ScriptArguments, TrlParser, get_quantization_config
from trl.experimental.nash_md import NashMDConfig, NashMDTrainer


Expand All @@ -84,7 +77,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/online_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -79,7 +78,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/online_dpo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -113,7 +112,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

# Load the VLM model using correct architecture (from GRPO pattern)
Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
HfArgumentParser,
)

from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import ModelConfig, ScriptArguments, get_peft_config, get_quantization_config
from trl.experimental.ppo import PPOConfig, PPOTrainer


Expand Down Expand Up @@ -83,7 +83,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, padding_side="left")
Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
HfArgumentParser,
)

from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config
from trl import ModelConfig, ScriptArguments, get_peft_config, get_quantization_config
from trl.experimental.ppo import PPOConfig, PPOTrainer


Expand Down Expand Up @@ -90,7 +90,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, padding_side="left")
Expand Down
9 changes: 1 addition & 8 deletions examples/scripts/prm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,7 @@
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer, HfArgumentParser

from trl import (
ModelConfig,
ScriptArguments,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl import ModelConfig, ScriptArguments, get_peft_config, get_quantization_config
from trl.experimental.prm import PRMConfig, PRMTrainer


Expand All @@ -80,7 +74,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, use_fast=True)
Expand Down
11 changes: 1 addition & 10 deletions examples/scripts/reward_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,7 @@
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, HfArgumentParser

from trl import (
ModelConfig,
RewardConfig,
RewardTrainer,
ScriptArguments,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl import ModelConfig, RewardConfig, RewardTrainer, ScriptArguments, get_peft_config, get_quantization_config


logger = logging.get_logger(__name__)
Expand All @@ -85,7 +77,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForSequenceClassification.from_pretrained(
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/rloo_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
RLOOTrainer,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -93,7 +92,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

################
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/sdft.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand Down Expand Up @@ -331,7 +330,6 @@ def _run_tooluse_eval(
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

training_args.model_init_kwargs = model_kwargs
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/sdpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
ModelConfig,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand Down Expand Up @@ -298,7 +297,6 @@ def _run_accuracy_eval(
)
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
training_args.model_init_kwargs["device_map"] = get_kbit_device_map()
training_args.model_init_kwargs["quantization_config"] = quantization_config

if script_args.dataset_path is not None:
Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/sft_video_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor

from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map
from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser


def download_video(url: str, cache_dir: str) -> str:
Expand Down Expand Up @@ -195,7 +195,6 @@ class CustomScriptArguments(ScriptArguments):
model_kwargs = dict(
revision=model_args.model_revision,
dtype=dtype,
device_map=get_kbit_device_map(),
quantization_config=bnb_config,
)

Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/sft_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand All @@ -86,7 +85,6 @@
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
Expand Down
2 changes: 0 additions & 2 deletions examples/scripts/sft_vlm_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
Expand Down Expand Up @@ -152,7 +151,6 @@ def main():
quantization_config = get_quantization_config(model_args)
if quantization_config is not None:
# Passing None would not be treated the same as omitting the argument, so we include it only when valid.
model_kwargs["device_map"] = get_kbit_device_map()
model_kwargs["quantization_config"] = quantization_config

model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
Expand Down
Loading
Loading