Skip to content

Remove redundant get_kbit_device_map()#6158

Open
qgallouedec wants to merge 8 commits into
mainfrom
drop-redundant-kbit-device-map
Open

Remove redundant get_kbit_device_map()#6158
qgallouedec wants to merge 8 commits into
mainfrom
drop-redundant-kbit-device-map

Conversation

@qgallouedec

@qgallouedec qgallouedec commented Jun 24, 2026

Copy link
Copy Markdown
Member

QLoRA scripts set model_init_kwargs["device_map"] = get_kbit_device_map() (which returns {"": local_process_index}) before loading the model. This is redundant: it never changes where the model lands. This PR removes the call everywhere and deletes the helper itself.

Scope

  • examples/scripts/ (24 files) and trl/scripts/ (5 files): drop the device_map line + import.
  • tests/test_grpo_trainer.py: drop the device_map=get_kbit_device_map() kwarg + import.
  • trl/trainer/utils.py: delete get_kbit_device_map.
  • trl/__init__.py, trl/trainer/__init__.py: remove the public export (so from trl import get_kbit_device_map no longer resolves — minor breaking change for a thin internal helper).

Background

A bitsandbytes-quantized model is placed at load time and can't be moved afterward, so QLoRA needed an explicit device_map. Born in #373 (2023) as inline device_map = {"": 0} on a quantized reward model, generalized to {"": local_process_index} for DDP in #725, then extracted into get_kbit_device_map() in #1176. A real 2023-era placement problem — which the trainer path now handles on its own (see below).

Why it's redundant

Two cases, exhaustively:

  • Distributed (MULTI_GPU / DEEPSPEED) — the trainer overrides model_init_kwargs["device_map"] = None before from_pretrained, so the value from get_kbit_device_map() is discarded regardless.
  • Single process — recent transformers auto-places quantized weights on the current CUDA device, which is exactly the {"": 0} the helper would have set.

How it was demonstrated

Setup: 8×H100, Qwen/Qwen2.5-0.5B, --load_in_4bit, bitsandbytes 0.49.2. Every check was run on both transformers==4.56.2 (the minimum TRL supports) and 5.13 (current), with identical outcomes.

  1. Single-GPU load, no device_map : from_pretrained(quantization_config=BitsAndBytesConfig(load_in_4bit=True)) with no device_map: parameters land on cuda:0 (not CPU) and a forward pass succeeds. Rules out the "without device_map it falls back to CPU and breaks" assumption.

  2. DDP placement, no device_map: 2-process run where the trainer forces device_map=None: each rank still loads onto its own local GPU (rank 0 → cuda:0, rank 1 → cuda:1) and training steps run. Per-rank placement does not depend on get_kbit_device_map().

  3. Real script, with vs without the line: trl/scripts/sft.py on 2 GPUs (accelerate launch, --load_in_4bit --use_peft), baseline vs. a copy with the line removed. It runs in both cases and the losses are identical step-for-step:

    step with line without line
    1 2.107 2.107
    2 2.756 2.756
    3 1.775 1.775
    final train_loss 2.213 2.213

Note

Most edited scripts are trainer-loaded or pass the dict through model_init_kwargs (the path covered by checks 2–3). A few load directly via from_pretrained (e.g. sft_video_llm.py, the GRPO VLM test); these are covered by check 1 for single-GPU and are reasoned-safe under DDP (the per-rank device is set before the model loads).


Note

Medium Risk
Minor public API removal and reliance on transformers/trainer device placement for quantized loads; behavior was validated in the PR but QLoRA + multi-GPU remains sensitive.

Overview
Removes the get_kbit_device_map() helper and stops setting device_map from QLoRA example/CLI scripts when quantization_config is present. quantization_config is still passed unchanged; placement is left to transformers on single process and to trainers that force device_map=None under MULTI_GPU / DEEPSPEED.

from trl import get_kbit_device_map is no longer exported (minor breaking change for external callers).

DistillationTrainer and GOLDTrainer now explicitly set device_map=None when loading student/teacher from a path in distributed runs, matching the pattern already used in core trainers like DPO / GRPO.

Reviewed by Cursor Bugbot for commit 627c405. Bugbot is set up for automated code reviews on this repo. Configure here.

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit 53275a2. Configure here.

Comment thread examples/scripts/distillation.py
@bot-ci-comment

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant