- Did you update? pip install --upgrade unsloth unsloth_zoo
  Yes
- Colab, Kaggle, or local / cloud?
  Local
- Number of GPUs used (check with nvidia-smi)?
  1
- Which notebook? Please link!
  https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_8B_FP8_GRPO.ipynb#scrollTo=iK3fwlTDM4zg
- Which Unsloth version, TRL version, transformers version, PyTorch version?
  unsloth==2025.11.3
  trl==0.23.0
  transformers==4.57.1
  torch==2.9.0
- Which trainer? SFTTrainer, GRPOTrainer, etc.
  SFTTrainer
from unsloth import FastLanguageModel, FastModel
import torch
max_seq_length = 1024*34
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/Qwen3-30B-A3B-Instruct-2507-FP8",
max_seq_length = max_seq_length, # Choose any for long context!
load_in_4bit = False, # 4 bit quantization to reduce memory
load_in_fp8 = True,
) # was able to get this line working by applying PR #3649, thank you bhuvanprakash and danielhanchen
model = FastLanguageModel.get_peft_model(
model,
r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
lora_alpha = 64,
lora_dropout = 0, # Supports any, but = 0 is optimized
bias = "none", # Supports any, but = "none" is optimized
# [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
random_state = 3407,
use_rslora = False, # We support rank stabilized LoRA
loftq_config = None, # And LoftQ
)
train_dataset: datasets.arrow_dataset.Dataset = ...
val_dataset: datasets.arrow_dataset.Dataset = ...
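(The actual datasets are elided above. As a purely hypothetical stand-in, a dataset with a single "text" column, matching dataset_text_field = "text" further down, could be built like this:)
import datasets
# Hypothetical placeholder data; the real training/validation sets are not shown in this report.
train_dataset = datasets.Dataset.from_list(
    [{"text": "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>"}] * 32
)
val_dataset = datasets.Dataset.from_list(
    [{"text": "<|im_start|>user\nBye<|im_end|>\n<|im_start|>assistant\nGoodbye!<|im_end|>"}] * 8
)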
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = train_dataset,
eval_dataset = val_dataset, # Can set up evaluation!
args = SFTConfig(
dataset_text_field = "text",
per_device_train_batch_size = 1,
per_device_eval_batch_size = 1,
gradient_accumulation_steps = 16, # Use GA to mimic batch size!
warmup_steps = 40,
# num_train_epochs = 1, # Set this for 1 full training run.
max_steps = 1400,
#max_length = max_seq_length,
learning_rate = 2e-5, # Reduce to 2e-5 for long training runs
fp16 = not torch.cuda.is_bf16_supported(),
bf16 = torch.cuda.is_bf16_supported(),
logging_steps = 25,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "/path/to/checkpoint",
save_strategy = "steps",
eval_strategy = "steps",
eval_steps=25,
save_steps=25,
do_eval=True
),
)
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
trainer,
instruction_part = "<|im_start|>user\n",
response_part = "<|im_start|>assistant\n",
)
trainer_stats = trainer.train() # <----- bug occurs on this line

Error:
The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1
\\ /| Num examples = 7,221 | Num Epochs = 4 | Total steps = 1,400
O^O/ \_/ \ Batch size per device = 1 | Gradient accumulation steps = 16
\ / Data Parallel GPUs = 1 | Total batch size (1 x 16 x 1) = 16
"-____-" Trainable parameters = 26,738,688 of 30,560,686,080 (0.09% trained)
Unsloth: Will smartly offload gradients to save VRAM!
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[13], line 1
----> 1 trainer_stats = trainer.train()
File ~/notebooks/unsloth_compiled_cache/UnslothSFTTrainer.py:54, in prepare_for_training_mode.<locals>.wrapper(self, *args, **kwargs)
52 if hasattr(self, 'model') and hasattr(self.model, "for_training"):
53 self.model.for_training()
---> 54 output = f(self, *args, **kwargs)
55 # Return inference mode
56 if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/trainer.py:2325, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2323 hf_hub_utils.enable_progress_bars()
2324 else:
-> 2325 return inner_training_loop(
2326 args=args,
2327 resume_from_checkpoint=resume_from_checkpoint,
2328 trial=trial,
2329 ignore_keys_for_eval=ignore_keys_for_eval,
2330 )
File <string>:328, in _fast_inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
File ~/notebooks/unsloth_compiled_cache/UnslothSFTTrainer.py:1092, in _UnslothSFTTrainer.training_step(self, *args, **kwargs)
1090 def training_step(self, *args, **kwargs):
1091 with self.maybe_activation_offload_context:
-> 1092 return super().training_step(*args, **kwargs)
File <string>:40, in _unsloth_training_step(self, model, inputs, num_items_in_batch)
File ~/notebooks/unsloth_compiled_cache/UnslothSFTTrainer.py:1081, in _UnslothSFTTrainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
1078 def compute_loss(
1079 self, model, inputs, return_outputs = False, num_items_in_batch = None
1080 ):
-> 1081 outputs = super().compute_loss(
1082 model,
1083 inputs,
1084 return_outputs = return_outputs,
1085 num_items_in_batch = num_items_in_batch,
1086 )
1087 return outputs
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth/models/_utils.py:1625, in _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs)
1618 name = inner_model.__class__.__name__
1620 logger.warning_once(
1621 f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"
1622 "Using gradient accumulation will be very slightly less accurate.\n"
1623 "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
1624 )
-> 1625 outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
1626 return outputs
File <string>:36, in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
818 def forward(*args, **kwargs):
--> 819 return model_forward(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
806 def __call__(self, *args, **kwargs):
--> 807 return convert_to_fp32(self.model_forward(*args, **kwargs))
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
41 @functools.wraps(func)
42 def decorate_autocast(*args, **kwargs):
43 with autocast_instance:
---> 44 return func(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/peft/peft_model.py:1923, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1921 with self._enable_peft_forward_hooks(**kwargs):
1922 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1923 return self.base_model(
1924 input_ids=input_ids,
1925 attention_mask=attention_mask,
1926 inputs_embeds=inputs_embeds,
1927 labels=labels,
1928 output_attentions=output_attentions,
1929 output_hidden_states=output_hidden_states,
1930 return_dict=return_dict,
1931 **kwargs,
1932 )
1934 batch_size = _get_batch_size(input_ids, inputs_embeds)
1935 if attention_mask is not None:
1936 # concat prompt attention mask
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:308, in BaseTuner.forward(self, *args, **kwargs)
307 def forward(self, *args: Any, **kwargs: Any):
--> 308 return self.model.forward(*args, **kwargs)
File ~/notebooks/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py:782, in Qwen3MoeForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
768 def forward(
769 self,
770 input_ids: Optional[torch.LongTensor] = None,
(...) 780 **kwargs: Unpack[TransformersKwargs],
781 ) -> MoeCausalLMOutputWithPast:
--> 782 return Qwen3MoeForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_dynamo/external_utils.py:198, in get_nonrecursive_disable_wrapper.<locals>.nonrecursive_disable_wrapper(*args, **kwargs)
196 @functools.wraps(fn)
197 def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
--> 198 return fn(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/utils/generic.py:918, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
916 if return_dict_passed is not None:
917 return_dict = return_dict_passed
--> 918 output = func(self, *args, **kwargs)
919 if not return_dict and not isinstance(output, tuple):
920 output = output.to_tuple()
File ~/notebooks/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py:603, in Qwen3MoeForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
598 output_router_logits = (
599 output_router_logits if output_router_logits is not None else self.config.output_router_logits
600 )
602 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 603 outputs: MoeModelOutputWithPast = self.model(
604 input_ids=input_ids,
605 attention_mask=attention_mask,
606 position_ids=position_ids,
607 past_key_values=past_key_values,
608 inputs_embeds=inputs_embeds,
609 use_cache=use_cache,
610 output_router_logits=output_router_logits,
611 cache_position=cache_position,
612 **kwargs,
613 )
615 hidden_states = outputs.last_hidden_state
616 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/utils/generic.py:1064, in check_model_inputs.<locals>.wrapper(self, *args, **kwargs)
1061 monkey_patched_layers.append((module, original_forward))
1063 try:
-> 1064 outputs = func(self, *args, **kwargs)
1065 except TypeError as original_exception:
1066 # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly.
1067 # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception
1068 # Otherwise -> we're probably missing `**kwargs` in the decorated function
1069 kwargs_without_recordable = {k: v for k, v in kwargs.items() if k not in recordable_keys}
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py:487, in Qwen3MoeModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, cache_position, **kwargs)
484 position_embeddings = self.rotary_emb(hidden_states, position_ids)
486 for decoder_layer in self.layers[: self.config.num_hidden_layers]:
--> 487 hidden_states = decoder_layer(
488 hidden_states,
489 position_embeddings=position_embeddings,
490 attention_mask=causal_mask,
491 position_ids=position_ids,
492 past_key_values=past_key_values,
493 use_cache=use_cache,
494 cache_position=cache_position,
495 **kwargs,
496 )
498 hidden_states = self.norm(hidden_states)
500 return MoeModelOutputWithPast( # only diff with Mistral is the output type, we need MoE
501 last_hidden_state=hidden_states,
502 past_key_values=past_key_values,
503 )
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/modeling_layers.py:93, in GradientCheckpointingLayer.__call__(self, *args, **kwargs)
90 message = message.rstrip(",") + "."
91 logger.warning_once(message)
---> 93 return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
94 return super().__call__(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_compile.py:53, in _disable_dynamo.<locals>.inner(*args, **kwargs)
50 disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False)
51 fn.__dynamo_disable = disable_fn # type: ignore[attr-defined]
---> 53 return disable_fn(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:929, in DisableContext.__call__.<locals>._fn(*args, **kwargs)
927 _maybe_set_eval_frame(_callback_from_stance(self.callback))
928 try:
--> 929 return fn(*args, **kwargs)
930 finally:
931 set_eval_frame(None)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/utils/checkpoint.py:488, in checkpoint(function, use_reentrant, context_fn, determinism_check, debug, *args, **kwargs)
483 if context_fn is not noop_context_fn or debug is not False:
484 raise ValueError(
485 "Passing `context_fn` or `debug` is only supported when "
486 "use_reentrant=False."
487 )
--> 488 return CheckpointFunction.apply(function, preserve, *args)
489 else:
490 gen = _checkpoint_without_reentrant_generator(
491 function, preserve, context_fn, determinism_check, debug, *args, **kwargs
492 )
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/autograd/function.py:576, in Function.apply(cls, *args, **kwargs)
573 if not torch._C._are_functorch_transforms_active():
574 # See NOTE: [functorch vjp and autograd interaction]
575 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 576 return super().apply(*args, **kwargs) # type: ignore[misc]
578 if not is_setup_ctx_defined:
579 raise RuntimeError(
580 "In order to use an autograd.Function with functorch transforms "
581 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
582 "staticmethod. For more details, please see "
583 "https://pytorch.org/docs/main/notes/extending.func.html"
584 )
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth_zoo/gradient_checkpointing.py:484, in UnslothCheckpointFunction.forward(ctx, run_function, preserve_rng_state, *args)
481 if ctx._requires_gradient: ctx.save_for_backward(*tensor_inputs)
483 with torch.no_grad():
--> 484 outputs = run_function(*args)
486 if use_gpu_buffer: MAIN_STREAM.wait_stream(EXTRA_STREAM)
487 return outputs
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
169 # DeprecationWarning is ignored by default, so we use FutureWarning instead
170 warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/models/qwen3_moe/modeling_qwen3_moe.py:345, in Qwen3MoeDecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_values, cache_position, **kwargs)
342 hidden_states = self.input_layernorm(hidden_states)
344 # Self Attention
--> 345 hidden_states, _ = self.self_attn(
346 hidden_states=hidden_states,
347 position_embeddings=position_embeddings,
348 attention_mask=attention_mask,
349 position_ids=position_ids,
350 past_key_values=past_key_values,
351 cache_position=cache_position,
352 **kwargs,
353 )
354 hidden_states = residual + hidden_states
356 # Fully Connected
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/notebooks/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py:329, in Qwen3MoeAttention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
320 def forward(
321 self,
322 hidden_states: torch.Tensor,
(...) 327 **kwargs: Unpack[FlashAttentionKwargs],
328 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
--> 329 return Qwen3MoeAttention_forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_dynamo/external_utils.py:198, in get_nonrecursive_disable_wrapper.<locals>.nonrecursive_disable_wrapper(*args, **kwargs)
196 @functools.wraps(fn)
197 def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
--> 198 return fn(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/transformers/utils/deprecation.py:172, in deprecate_kwarg.<locals>.wrapper.<locals>.wrapped_func(*args, **kwargs)
168 elif minimum_action in (Action.NOTIFY, Action.NOTIFY_ALWAYS) and not is_torchdynamo_compiling():
169 # DeprecationWarning is ignored by default, so we use FutureWarning instead
170 warnings.warn(message, FutureWarning, stacklevel=2)
--> 172 return func(*args, **kwargs)
File ~/notebooks/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py:258, in Qwen3MoeAttention_forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
255 input_shape = hidden_states.shape[:-1]
256 hidden_shape = (*input_shape, -1, self.head_dim)
--> 258 query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
259 key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
260 value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/notebooks/unsloth_compiled_cache/Linear_peft_forward.py:74, in unsloth_forward(self, x, *args, **kwargs)
72 result = self.base_layer(x, *args, **kwargs)
73 else:
---> 74 result = self.base_layer(x, *args, **kwargs)
75 torch_result_dtype = result.dtype
77 lora_A_keys = self.lora_A.keys()
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth/kernels/fp8.py:549, in module_forward_patch.<locals>.patched_forward(self, X)
548 def patched_forward(self, X):
--> 549 return forward_function(X, self.weight, getattr(self, scale_attr))
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:736, in _TorchDynamoContext.__call__.<locals>.compile_wrapper(*args, **kwargs)
733 _maybe_set_eval_frame(_callback_from_stance(callback))
735 try:
--> 736 return fn(*args, **kwargs)
737 except Unsupported as e:
738 if config.verbose:
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth/kernels/fp8.py:391, in fp8_torch_block_quant_forward(X, weight, weight_scale)
389 @torch_compile
390 def fp8_torch_block_quant_forward(X, weight, weight_scale):
--> 391 return FP8BlockQuantLinear.apply(X, weight, weight_scale)
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/torch/autograd/function.py:576, in Function.apply(cls, *args, **kwargs)
573 if not torch._C._are_functorch_transforms_active():
574 # See NOTE: [functorch vjp and autograd interaction]
575 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 576 return super().apply(*args, **kwargs) # type: ignore[misc]
578 if not is_setup_ctx_defined:
579 raise RuntimeError(
580 "In order to use an autograd.Function with functorch transforms "
581 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
582 "staticmethod. For more details, please see "
583 "https://pytorch.org/docs/main/notes/extending.func.html"
584 )
File ~/venv312_unsloth_3c658a6/lib/python3.12/site-packages/unsloth/kernels/fp8.py:350, in FP8BlockQuantLinear.forward(ctx, X, weight, weight_scale)
346 p, q = weight_scale.shape
347 block_size = getattr(weight, "block_size", None) or getattr(
348 weight_scale, "block_size", None
349 )
--> 350 assert block_size is not None, "block_size is not set"
351 if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q:
352 if (
353 triton.cdiv(m, block_size[0]) == q
354 and triton.cdiv(n, block_size[1]) == p
355 ):
356 # weights are transposed during backward pass for training :)
357 # We transpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
AssertionError: block_size is not set
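As a quick check (a hypothetical diagnostic, not part of the notebook): the assertion in unsloth/kernels/fp8.py expects a block_size attribute on either the FP8 weight or its scale, so inspecting the first attention projection of the loaded model shows whether that attribute was ever attached. The exact attribute path is an assumption based on the PEFT-wrapped model above.

# Hypothetical diagnostic: check whether the FP8 block-quantized weights expose
# the `block_size` attribute that FP8BlockQuantLinear.forward asserts on.
layer = model.model.model.layers[0].self_attn.q_proj   # LoRA-wrapped linear (path may vary)
base = layer.base_layer                                 # underlying FP8 linear
weight = base.weight
scale = getattr(base, "weight_scale_inv", None)         # attribute name is an assumption
if scale is None:
    scale = getattr(base, "weight_scale", None)

print("weight dtype:", weight.dtype, "shape:", tuple(weight.shape))
print("weight.block_size:", getattr(weight, "block_size", None))
if scale is not None:
    print("scale shape:", tuple(scale.shape))
    print("scale.block_size:", getattr(scale, "block_size", None))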
🦥 You can also ask via our Reddit page: https://www.reddit.com/r/unsloth/