Skip to content

[Bug] AssertionError: block_size is not set #3656

@nole69

Description

@nole69
  1. Did you update? (`pip install --upgrade unsloth unsloth_zoo`)
    Yes
  2. Colab, Kaggle, or local/cloud?
    local
  3. Number of GPUs used (check with `nvidia-smi`)
    1
  4. Which notebook? Please link!
    https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_8B_FP8_GRPO.ipynb#scrollTo=iK3fwlTDM4zg
  5. Which Unsloth version, TRL version, transformers version, PyTorch version?
    unsloth==2025.11.3
    trl==0.23.0
    transformers==4.57.1
    torch==2.9.0
  6. Which trainer? (SFTTrainer, GRPOTrainer, etc.)
    SFTTrainer
# Reproduction script for "AssertionError: block_size is not set":
# load the FP8 Qwen3-30B-A3B-Instruct-2507 checkpoint, attach LoRA adapters to
# the attention projections, then fine-tune with TRL's SFTTrainer while
# training on assistant responses only.
from unsloth import FastLanguageModel, FastModel
import torch

max_seq_length = 1024*34  # 34,816 tokens

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-30B-A3B-Instruct-2507-FP8",
    max_seq_length = max_seq_length, # Choose any for long context!
    load_in_4bit = False,  # 4 bit quantization to reduce memory
    load_in_fp8 = True,
)  # able to get this line working by applying PR #3649, thank you bhuvanprakash and danielhanchen

# LoRA only on the attention projections (q/k/v/o); the FP8 base weights of
# these layers are what the failing forward pass goes through.
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_alpha = 64,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

# Placeholders: supply datasets that have a "text" column (matches
# dataset_text_field below). The annotations are quoted on purpose: annotations
# on module-level assignments are evaluated at runtime, and `datasets` is not
# imported in this script, so unquoted they would raise NameError here.
train_dataset: "datasets.arrow_dataset.Dataset" = ...
val_dataset: "datasets.arrow_dataset.Dataset" = ...

from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = val_dataset, # Can set up evaluation!
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        per_device_eval_batch_size = 1,
        gradient_accumulation_steps = 16, # Use GA to mimic batch size!
        warmup_steps = 40,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 1400,
        # NOTE(review): max_length is left unset, so SFTConfig's own default is
        # used rather than max_seq_length above -- confirm this is intended.
        #max_length = max_seq_length,
        learning_rate = 2e-5, # Reduce to 2e-5 for long training runs
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 25,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "/path/to/checkpoint",
        save_strategy = "steps",
        eval_strategy = "steps",
        eval_steps=25,
        save_steps=25,
        do_eval=True
    ),
)

# Mask the user turns so the loss is computed only on assistant responses,
# using the Qwen chat-template markers below.
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|im_start|>user\n",
    response_part = "<|im_start|>assistant\n",
)

# The first forward pass runs the FP8 q_proj base layer through
# unsloth/kernels/fp8.py (FP8BlockQuantLinear.forward). That function asserts
# because neither weight.block_size nor weight_scale.block_size is set.
trainer_stats = trainer.train()  # <----- bug occurs on this line

Error:

The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7,221 | Num Epochs = 4 | Total steps = 1,400
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 16 x 1) = 16
 "-____-"     Trainable parameters = 26,738,688 of 30,560,686,080 (0.09% trained)
Unsloth: Will smartly offload gradients to save VRAM!
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[13], line 1
----> 1 trainer_stats = trainer.train()

File ~/notebooks/unsloth_compiled_cache/UnslothSFTTrainer.py:54, in prepare_for_training_mode.<locals>.wrapper(self, *args, **kwargs)
     52 if hasattr(self, 'model') and hasattr(self.model, "for_training"):
     53     self.model.for_training()
---> 54 output = f(self, *args, **kwargs)
     55 # Return inference mode
     56 if hasattr(self, 'model') and hasattr(self.model, "for_inference"):

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/trainer.py:2325, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2323         hf_hub_utils.enable_progress_bars()
   2324 else:
-> 2325     return inner_training_loop(
   2326         args=args,
   2327         resume_from_checkpoint=resume_from_checkpoint,
   2328         trial=trial,
   2329         ignore_keys_for_eval=ignore_keys_for_eval,
   2330     )

File <string>:328, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)

File ~/notebooks/unsloth_compiled_cache/UnslothSFTTrainer.py:1092, in _UnslothSFTTrainer.training_step(self, *args, **kwargs)
   1090 def training_step(self, *args, **kwargs):
   1091     with self.maybe_activation_offload_context:
-> 1092         return super().training_step(*args, **kwargs)

File <string>:40, in _unsloth_training_step(self, model, inputs, num_items_in_batch)

File ~/notebooks/unsloth_compiled_cache/UnslothSFTTrainer.py:1081, in _UnslothSFTTrainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   1078 def compute_loss(
   1079     self, model, inputs, return_outputs = False, num_items_in_batch = None
   1080 ):
-> 1081     outputs = super().compute_loss(
   1082         model,
   1083         inputs,
   1084         return_outputs = return_outputs,
   1085         num_items_in_batch = num_items_in_batch,
   1086     )
   1087     return outputs

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth/models/_utils.py:1625, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
   1618     name = inner_model.__class__.__name__
   1620     logger.warning_once(
   1621         f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"
   1622         "Using gradient accumulation will be very slightly less accurate.\n"
   1623         "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
   1624     )
-> 1625 outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
   1626 return outputs

File <string>:36, in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    818 def forward(*args, **kwargs):
--> 819     return model_forward(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    806 def __call__(self, *args, **kwargs):
--> 807     return convert_to_fp32(self.model_forward(*args, **kwargs))

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     41 @functools.wraps(func)
     42 def decorate_autocast(*args, **kwargs):
     43     with autocast_instance:
---> 44         return func(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/peft/peft_model.py:1923, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1921     with self._enable_peft_forward_hooks(**kwargs):
   1922         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1923         return self.base_model(
   1924             input_ids=input_ids,
   1925             attention_mask=attention_mask,
   1926             inputs_embeds=inputs_embeds,
   1927             labels=labels,
   1928             output_attentions=output_attentions,
   1929             output_hidden_states=output_hidden_states,
   1930             return_dict=return_dict,
   1931             **kwargs,
   1932         )
   1934 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1935 if attention_mask is not None:
   1936     # concat prompt attention mask

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:308, in BaseTuner.forward(self, *args, **kwargs)
    307 def forward(self, *args: Any, **kwargs: Any):
--> 308     return self.model.forward(*args, **kwargs)

File ~/notebooks/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py:782, in Qwen3MoeForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
    768 def forward(
    769     self,
    770     input_ids: Optional[torch.LongTensor] = None,
   (...)    780     **kwargs: Unpack[TransformersKwargs],
    781 ) -> MoeCausalLMOutputWithPast:
--> 782     return Qwen3MoeForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_dynamo/external_utils.py:198, in get_nonrecursive_disable_wrapper.<locals>.nonrecursive_disable_wrapper(*args, **kwargs)
    196 @functools.wraps(fn)
    197 def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
--> 198     return fn(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/utils/generic.py:918, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    916 if return_dict_passed is not None:
    917     return_dict = return_dict_passed
--> 918 output = func(self, *args, **kwargs)
    919 if not return_dict and not isinstance(output, tuple):
    920     output = output.to_tuple()

File ~/notebooks/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py:603, in Qwen3MoeForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
    598 output_router_logits = (
    599     output_router_logits if output_router_logits is not None else self.config.output_router_logits
    600 )
    602 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 603 outputs: MoeModelOutputWithPast = self.model(
    604     input_ids=input_ids,
    605     attention_mask=attention_mask,
    606     position_ids=position_ids,
    607     past_key_values=past_key_values,
    608     inputs_embeds=inputs_embeds,
    609     use_cache=use_cache,
    610     output_router_logits=output_router_logits,
    611     cache_position=cache_position,
    612     **kwargs,
    613 )
    615 hidden_states = outputs.last_hidden_state
    616 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/utils/generic.py:1064, in check_model_inputs.<locals>.wrapper(self, *args, **kwargs)
   1061                 monkey_patched_layers.append((module, original_forward))
   1063 try:
-> 1064     outputs = func(self, *args, **kwargs)
   1065 except TypeError as original_exception:
   1066     # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
   1067     # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception
   1068     # Otherwise -> we're probably missing `**kwargs` in the decorated function
   1069     kwargs_without_recordable = {k: v for k, v in kwargs.items() if k not in recordable_keys}

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py:487, in Qwen3MoeModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position, **kwargs)
    484 position_embeddings = self.rotary_emb(hidden_states, position_ids)
    486 for decoder_layer in self.layers[: self.config.num_hidden_layers]:
--> 487     hidden_states = decoder_layer(
    488         hidden_states,
    489         position_embeddings=position_embeddings,
    490         attention_mask=causal_mask,
    491         position_ids=position_ids,
    492         past_key_values=past_key_values,
    493         use_cache=use_cache,
    494         cache_position=cache_position,
    495         **kwargs,
    496     )
    498 hidden_states = self.norm(hidden_states)
    500 return MoeModelOutputWithPast(  # only diff with Mistral is the output type, we need MoE
    501     last_hidden_state=hidden_states,
    502     past_key_values=past_key_values,
    503 )

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/modeling_layers.py:93, in GradientCheckpointingLayer.__call__(self, *args, **kwargs)
     90         message = message.rstrip(",") + "."
     91         logger.warning_once(message)
---> 93     return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
     94 return super().__call__(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_compile.py:53, in _disable_dynamo.<locals>.inner(*args, **kwargs)
     50     disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False)
     51     fn.__dynamo_disable = disable_fn  # type: ignore[attr-defined]
---> 53 return disable_fn(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:929, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
    927 _maybe_set_eval_frame(_callback_from_stance(self.callback))
    928 try:
--> 929     return fn(*args, **kwargs)
    930 finally:
    931     set_eval_frame(None)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/utils/checkpoint.py:488, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
    483     if context_fn is not noop_context_fn or debug is not False:
    484         raise ValueError(
    485             "Passing `context_fn` or `debug` is only supported when "
    486             "use_reentrant=False."
    487         )
--> 488     return CheckpointFunction.apply(function, preserve, *args)
    489 else:
    490     gen = _checkpoint_without_reentrant_generator(
    491         function, preserve, context_fn, determinism_check, debug, *args, **kwargs
    492     )

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/autograd/function.py:576, in Function.apply(cls, *args, **kwargs)
    573 if not torch._C._are_functorch_transforms_active():
    574     # See NOTE: [functorch vjp and autograd interaction]
    575     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 576     return super().apply(*args, **kwargs)  # type: ignore[misc]
    578 if not is_setup_ctx_defined:
    579     raise RuntimeError(
    580         "In order to use an autograd.Function with functorch transforms "
    581         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    582         "staticmethod. For more details, please see "
    583         "https://pytorch.org/docs/main/notes/extending.func.html"
    584     )

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth_zoo/gradient_checkpointing.py:484, in UnslothCheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
    481 if ctx._requires_gradient: ctx.save_for_backward(*tensor_inputs)
    483 with torch.no_grad():
--> 484     outputs = run_function(*args)
    486 if use_gpu_buffer: MAIN_STREAM.wait_stream(EXTRA_STREAM)
    487 return outputs

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
    168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
    169     # DeprecationWarning is ignored by default, so we use FutureWarning instead
    170     warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py:345, in Qwen3MoeDecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_values, cache_position, **kwargs)
    342 hidden_states = self.input_layernorm(hidden_states)
    344 # Self Attention
--> 345 hidden_states, _ = self.self_attn(
    346     hidden_states=hidden_states,
    347     position_embeddings=position_embeddings,
    348     attention_mask=attention_mask,
    349     position_ids=position_ids,
    350     past_key_values=past_key_values,
    351     cache_position=cache_position,
    352     **kwargs,
    353 )
    354 hidden_states = residual + hidden_states
    356 # Fully Connected

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/notebooks/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py:329, in Qwen3MoeAttention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
    320 def forward(
    321     self,
    322     hidden_states: torch.Tensor,
   (...)    327     **kwargs: Unpack[FlashAttentionKwargs],
    328 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
--> 329     return Qwen3MoeAttention_forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_dynamo/external_utils.py:198, in get_nonrecursive_disable_wrapper.<locals>.nonrecursive_disable_wrapper(*args, **kwargs)
    196 @functools.wraps(fn)
    197 def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
--> 198     return fn(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
    168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
    169     # DeprecationWarning is ignored by default, so we use FutureWarning instead
    170     warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)

File ~/notebooks/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py:258, in Qwen3MoeAttention_forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
    255 input_shape = hidden_states.shape[:-1]
    256 hidden_shape = (*input_shape, -1, self.head_dim)
--> 258 query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    259 key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    260 value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/notebooks/unsloth_compiled_cache/Linear_peft_forward.py:74, in unsloth_forward(self, x, *args, **kwargs)
     72     result = self.base_layer(x, *args, **kwargs)
     73 else:
---> 74     result = self.base_layer(x, *args, **kwargs)
     75     torch_result_dtype = result.dtype
     77     lora_A_keys = self.lora_A.keys()

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth/kernels/fp8.py:549, in module_forward_patch.<locals>.patched_forward(self, X)
    548 def patched_forward(self, X):
--> 549     return forward_function(X, self.weight, getattr(self, scale_attr))

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:736, in _TorchDynamoContext.__call__.<locals>.compile_wrapper(*args, **kwargs)
    733 _maybe_set_eval_frame(_callback_from_stance(callback))
    735 try:
--> 736     return fn(*args, **kwargs)
    737 except Unsupported as e:
    738     if config.verbose:

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth/kernels/fp8.py:391, in fp8_torch_block_quant_forward(X, weight, weight_scale)
    389 @torch_compile
    390 def fp8_torch_block_quant_forward(X, weight, weight_scale):
--> 391     return FP8BlockQuantLinear.apply(X, weight, weight_scale)

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/autograd/function.py:576, in Function.apply(cls, *args, **kwargs)
    573 if not torch._C._are_functorch_transforms_active():
    574     # See NOTE: [functorch vjp and autograd interaction]
    575     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 576     return super().apply(*args, **kwargs)  # type: ignore[misc]
    578 if not is_setup_ctx_defined:
    579     raise RuntimeError(
    580         "In order to use an autograd.Function with functorch transforms "
    581         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    582         "staticmethod. For more details, please see "
    583         "https://pytorch.org/docs/main/notes/extending.func.html"
    584     )

File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth/kernels/fp8.py:350, in FP8BlockQuantLinear.forward(ctx, X, weight, weight_scale)
    346 p, q = weight_scale.shape
    347 block_size = getattr(weight, "block_size", None) or getattr(
    348     weight_scale, "block_size", None
    349 )
--> 350 assert block_size is not None, "block_size is not set"
    351 if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q:
    352     if (
    353         triton.cdiv(m, block_size[0]) == q
    354         and triton.cdiv(n, block_size[1]) == p
    355     ):
    356         # weights are transposed during backward pass for training :)
    357         # We transpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X

AssertionError: block_size is not set

🦥 You can also ask via our Reddit page: https://www.reddit.com/r/unsloth/

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions