cciwon-code-review-cli 2.0.1 → 2.0.3

This diff shows the changes between publicly released versions of the package as they appear in its public registry. It is provided for informational purposes only.
Files changed (111)
  1. package/bin/code-review.js +1 -1
  2. package/lib/chat-mode.js +7 -2
  3. package/package.json +1 -1
  4. package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
  5. package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
  6. package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
  7. package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
  8. package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
  9. package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
  10. package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
  11. package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
  12. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
  13. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
  14. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
  15. package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
  16. package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
  17. package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
  18. package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
  19. package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
  20. package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
  21. package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
  22. package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
  23. package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
  24. package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
  25. package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
  26. package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
  27. package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
  28. package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
  29. package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
  30. package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
  31. package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
  32. package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
  33. package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
  34. package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
  35. package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
  36. package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
  37. package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
  38. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
  39. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
  40. package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
  41. package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
  42. package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
  43. package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
  44. package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
  45. package/unsloth_compiled_cache/Conv1d.py +70 -0
  46. package/unsloth_compiled_cache/Conv2d.py +70 -0
  47. package/unsloth_compiled_cache/Conv3d.py +70 -0
  48. package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
  49. package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
  50. package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
  51. package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
  52. package/unsloth_compiled_cache/GroupNorm.py +70 -0
  53. package/unsloth_compiled_cache/LayerNorm.py +72 -0
  54. package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
  55. package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
  56. package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
  57. package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
  58. package/unsloth_compiled_cache/RMSNorm.py +73 -0
  59. package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
  60. package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
  61. package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
  62. package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
  63. package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
  64. package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
  65. package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
  66. package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
  67. package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
  68. package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
  69. package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
  70. package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
  71. package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
  72. package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
  73. package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
  74. package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
  75. package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
  76. package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
  77. package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
  78. package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
  79. package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
  80. package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
  81. package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
  82. package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
  83. package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
  84. package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
  85. package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
  86. package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
  87. package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
  88. package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
  89. package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
  90. package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
  91. package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
  92. package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
  93. package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  94. package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  95. package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
  96. package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  97. package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
  98. package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  99. package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  100. package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  101. package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  102. package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  103. package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  104. package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  105. package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  106. package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  107. package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
  108. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
  109. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
  110. package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
  111. package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py
@@ -0,0 +1,726 @@
+ """
+ 2025.12.6
+ 2025.12.7
+ 4.57.1
+ 0.24.0
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth auto generated code
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+
+ import os
+ import torch
+ import importlib.util
+ import math
+ if importlib.util.find_spec("unsloth_studio") is None:
+     UNSLOTH_STUDIO_ENABLED = False
+ else:
+     UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
+ pass
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+ import math
+
+ UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
+ UNSLOTH_ENABLE_CCE = os.environ.get("UNSLOTH_ENABLE_CCE", "1") == "1"
+ UNSLOTH_COMPILE_DISABLE = os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") in ("1", "partial",)
+
+ import logging
+ logger_compiler = logging.getLogger(__name__)
+ if UNSLOTH_ENABLE_LOGGING:
+     logger_compiler.setLevel(logging.DEBUG)
+
+ global INFERENCE_RUNS
+ INFERENCE_RUNS = 0
+
+ try:
+     import torch._dynamo.eval_frame as torch_dynamo_eval_frame
+     torch_dynamo_eval_frame._stance.stance
+     torch_compiler_set_stance = torch.compiler.set_stance
+ except:
+     torch_dynamo_eval_frame = None
+     torch_compiler_set_stance = None
+ pass
+
+ from unsloth_zoo import DEVICE_TYPE_TORCH, DEVICE_COUNT
+
+
+ from unsloth_zoo.loss_utils import (
+     fused_linear_cross_entropy,
+     unsloth_fused_ce_loss,
+ )
+
+ if UNSLOTH_STUDIO_ENABLED:
+     from unsloth_zoo.loss_utils import fast_linear_cross_entropy
+
+ scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+ @torch.compiler.disable(recursive = False)
+ def disable_compile_scaled_dot_product_attention(*args, **kwargs):
+     return scaled_dot_product_attention(*args, **kwargs)
+ pass
+
+
+ from transformers.modeling_flash_attention_utils import is_flash_attn_available
+
+ if is_flash_attn_available():
+     try:
+         from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask
+     except:
+         flash_attn_supports_top_left_mask = None
+     try:
+         from transformers.modeling_flash_attention_utils import _flash_attention_forward
+     except:
+         _flash_attention_forward = None
+     try:
+         from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+     except:
+         FlashAttentionKwargs = None
+     try:
+         from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
+     except:
+         flash_attn_varlen_func = None
+ else:
+     flash_attn_supports_top_left_mask = None
+     _flash_attention_forward = None
+     FlashAttentionKwargs = None
+     flash_attn_varlen_func = None
+ pass
+
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False, 'debug': False, 'dce': True, 'memory_planning': True, 'coordinate_descent_tuning': False, 'trace.graph_diagram': False, 'compile_threads': 32, 'group_fusion': True, 'disable_progress': True, 'verbose_progress': False, 'triton.multi_kernel': 0, 'triton.use_block_ptr': False, 'triton.enable_persistent_tma_matmul': True, 'triton.autotune_at_compile_time': False, 'triton.cooperative_reductions': False, 'cuda.compile_opt_level': '-O2', 'cuda.enable_cuda_lto': True, 'combo_kernels': False, 'benchmark_combo_kernel': True, 'combo_kernel_foreach_dynamic_shapes': True}
+
+ from torch.nn import CrossEntropyLoss
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def normal_cross_entropy_loss(self, hidden_states, labels):
+     logits = self.lm_head(hidden_states)
+     logits = logits.float()
+     # Shift so that tokens < n predict n
+     shift_logits = logits[..., :-1, :].contiguous()
+     shift_labels = labels[..., 1:].contiguous()
+     # Flatten the tokens
+     loss_fct = CrossEntropyLoss()
+     shift_logits = shift_logits.view(-1, self.config.vocab_size)
+     shift_labels = shift_labels.view(-1)
+     # Enable model parallelism
+     shift_labels = shift_labels.to(shift_logits.device)
+     loss = loss_fct(shift_logits, shift_labels)
+     return loss, logits
+ pass
+
+ # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie
+ # os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+ LOGITS_ERROR_STRING = \
+     "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\
+     'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\
+     "```\nimport os\n"\
+     "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\
+     "trainer.train()\n```\n"\
+     "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
+
+ def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING)
+ def return_none(*args, **kwargs): return None
+ class EmptyLogits:
+     def __init__(self): return
+     def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error
+     __getitem__ = raise_logits_error
+     __getattr__ = raise_getattr_error
+     def __repr__(self): return LOGITS_ERROR_STRING
+     def __str__ (self): return LOGITS_ERROR_STRING
+ pass
+ EMPTY_LOGITS = EmptyLogits()
+ functions = dir(torch.Tensor)
+ for j, function in enumerate(functions):
+     if function.startswith("__") and function.endswith("__"):
+         exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
+         try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
+         except: continue
+ pass
+
+
+ def mask_attention_mask_out(labels = None, attention_mask = None):
+     if labels is not None and attention_mask is not None:
+         attention_mask = attention_mask.to(device = labels.device)
+         labels[attention_mask == 0] = -100
+     return labels
+ pass
+
+
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+ from transformers.models.qwen3_moe.modeling_qwen3_moe import (__name__, F, Callable, Optional, Union, torch, nn, ACT2FN, Cache, GenerationMixin, use_kernel_forward_from_hub, FlashAttentionKwargs, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, ROPE_INIT_FUNCTIONS, dynamic_rope_update, ALL_ATTENTION_FUNCTIONS, PreTrainedModel, Unpack, TransformersKwargs, can_return_tuple, deprecate_kwarg, Qwen3MoeConfig, Qwen3MoePreTrainedModel, Qwen3MoeModel, Qwen3MoeForCausalLM)
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype = torch.float32).to(attn_weights.dtype).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ @torch.compiler.disable(recursive = False)
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def Qwen3MoeAttention_forward(
+     self,
+     hidden_states: torch.Tensor,
+     position_embeddings: tuple[torch.Tensor, torch.Tensor],
+     attention_mask: Optional[torch.Tensor],
+     past_key_values: Optional[Cache] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+     input_shape = hidden_states.shape[:-1]
+     hidden_shape = (*input_shape, -1, self.head_dim)
+
+     query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+     key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+     value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+     cos, sin = position_embeddings
+     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+     if past_key_values is not None:
+         # sin and cos are specific to RoPE models; cache_position needed for the static cache
+         cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+         key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+     attention_interface: Callable = eager_attention_forward
+     if self.config._attn_implementation != "eager":
+         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+     attn_output, attn_weights = attention_interface(
+         self,
+         query_states,
+         key_states,
+         value_states,
+         attention_mask,
+         dropout=0.0 if not self.training else self.attention_dropout,
+         scaling=self.scaling,
+         sliding_window=self.sliding_window, # diff with Llama
+         **kwargs,
+     )
+
+     attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+     attn_output = self.o_proj(attn_output)
+     return attn_output, attn_weights
+
+ class Qwen3MoeAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+         self.scaling = self.head_dim**-0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = True
+
+         self.q_proj = nn.Linear(
+             config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.k_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.v_proj = nn.Linear(
+             config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+         )
+         self.o_proj = nn.Linear(
+             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+         )
+         self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
+         self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
+         self.sliding_window = getattr(config, "sliding_window", None)
+
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: tuple[torch.Tensor, torch.Tensor],
+         attention_mask: Optional[torch.Tensor],
+         past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs: Unpack[FlashAttentionKwargs],
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         return Qwen3MoeAttention_forward(self, hidden_states, position_embeddings, attention_mask, past_key_values, cache_position, **kwargs)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def Qwen3MoeRMSNorm_forward(self, hidden_states):
+     input_dtype = hidden_states.dtype
+     hidden_states = hidden_states.to(torch.float32)
+     variance = hidden_states.pow(2).mean(-1, keepdim=True)
+     hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+     return self.weight * hidden_states.to(input_dtype)
+
+ @use_kernel_forward_from_hub("RMSNorm")
+ class Qwen3MoeRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """
+         Qwen3MoeRMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         return Qwen3MoeRMSNorm_forward(self, hidden_states)
+
+     def extra_repr(self):
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def Qwen3MoeMLP_forward(self, x):
+     down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+     return down_proj
+
+ class Qwen3MoeMLP(nn.Module):
+     def __init__(self, config, intermediate_size=None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         return Qwen3MoeMLP_forward(self, x)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def Qwen3MoeRotaryEmbedding_forward(self, x, position_ids):
+     inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+     position_ids_expanded = position_ids[:, None, :].float()
+
+     device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+     with torch.autocast(device_type=device_type, enabled=False): # Force float32
+         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         cos = emb.cos() * self.attention_scaling
+         sin = emb.sin() * self.attention_scaling
+
+     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+ class Qwen3MoeRotaryEmbedding(nn.Module):
+     inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+     def __init__(self, config: Qwen3MoeConfig, device=None):
+         super().__init__()
+         # BC: "rope_type" was originally "type"
+         if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+         else:
+             self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+
+     def forward(self, x, position_ids):
+         return Qwen3MoeRotaryEmbedding_forward(self, x, position_ids)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def load_balancing_loss_func(
+     gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+     num_experts: Optional[int] = None,
+     top_k=2,
+     attention_mask: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, int]:
+     r"""
+     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
+
+     See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
+     function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
+     experts is too unbalanced.
+
+     Args:
+         gate_logits:
+             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
+             shape [batch_size X sequence_length, num_experts].
+         num_experts:
+             Number of experts
+         top_k:
+             The number of experts to route per-token, can be also interpreted as the `top-k` routing
+             parameter.
+         attention_mask (`torch.Tensor`, *optional*):
+             The attention_mask used in forward function
+             shape [batch_size X sequence_length] if not None.
+
+     Returns:
+         The auxiliary loss.
+     """
+     if gate_logits is None or not isinstance(gate_logits, tuple):
+         return 0
+
+     if isinstance(gate_logits, tuple):
+         compute_device = gate_logits[0].device
+         concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
+
+     routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1, dtype = torch.float32).to(concatenated_gate_logits.dtype)
+
+     _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
+
+     expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
+
+     if attention_mask is None:
+         # Compute the percentage of tokens routed to each experts
+         tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
+
+         # Compute the average probability of routing to these experts
+         router_prob_per_expert = torch.mean(routing_weights, dim=0)
+     else:
+         batch_size, sequence_length = attention_mask.shape
+         num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
+
+         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
+         expert_attention_mask = (
+             attention_mask[None, :, :, None, None]
+             .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
+             .reshape(-1, top_k, num_experts)
+             .to(compute_device)
+         )
+
+         # Compute the percentage of tokens routed to each experts
+         tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
+             expert_attention_mask, dim=0
+         )
+
+         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
+         router_per_expert_attention_mask = (
+             attention_mask[None, :, :, None]
+             .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+             .reshape(-1, num_experts)
+             .to(compute_device)
+         )
+
+         # Compute the average probability of routing to these experts
+         router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
+             router_per_expert_attention_mask, dim=0
+         )
+
+     overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
+     return overall_loss * num_experts
+
+
+ @torch.compiler.disable(recursive = False)
+ @can_return_tuple
+ def Qwen3MoeForCausalLM_forward(
+     self,
+     input_ids: Optional[torch.LongTensor] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     past_key_values: Optional[Cache] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+     labels: Optional[torch.LongTensor] = None,
+     use_cache: Optional[bool] = None,
+     output_router_logits: Optional[bool] = None,
+     cache_position: Optional[torch.LongTensor] = None,
+     logits_to_keep: Union[int, torch.Tensor] = 0,
+     **kwargs: Unpack[TransformersKwargs],
+ ) -> MoeCausalLMOutputWithPast:
+     r"""
+     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+         Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+         config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+         (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+     Example:
+
+     ```python
+     >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
+
+     >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+     >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
+
+     >>> prompt = "Hey, are you conscious? Can you talk to me?"
+     >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+     >>> # Generate
+     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+     ```"""
+
+     output_router_logits = (
+         output_router_logits if output_router_logits is not None else self.config.output_router_logits
+     )
+
+     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+     outputs: MoeModelOutputWithPast = self.model(
+         input_ids=input_ids,
+         attention_mask=attention_mask,
+         position_ids=position_ids,
+         past_key_values=past_key_values,
+         inputs_embeds=inputs_embeds,
+         use_cache=use_cache,
+         output_router_logits=output_router_logits,
+         cache_position=cache_position,
+         **kwargs,
+     )
+
+     hidden_states = outputs.last_hidden_state
+     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+     logits = self.lm_head(hidden_states[:, slice_indices, :]) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS
+     loss = None
+     NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0'
+     RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1"
+
+     n_items = None
+     if (kwargs) != () and type(kwargs) is dict:
+         n_items = (kwargs).get("num_items_in_batch", None) or (kwargs).get("n_items", None)
+     if n_items is None:
+         all_locals = locals()
+         if 'loss_kwargs' in all_locals:
+             __kwargs = all_locals['loss_kwargs']
+             if type(__kwargs) is dict:
+                 n_items = __kwargs.get("num_items_in_batch", None)
+                 if n_items is None: n_items = __kwargs.get("n_items", None)
+         if n_items is None and 'kwargs' in all_locals:
+             __kwargs = all_locals['kwargs']
+             if type(__kwargs) is dict:
+                 n_items = __kwargs.get("num_items_in_batch", None)
+                 if n_items is None: n_items = __kwargs.get("n_items", None)
+         if n_items is None:
+             all_locals = all_locals.values()
+             for __kwargs in all_locals:
+                 if type(__kwargs) is dict:
+                     n_items = __kwargs.get("num_items_in_batch", None)
+                     if n_items is None: n_items = __kwargs.get("n_items", None)
+                     break
+     pass
+
+     requires_grad_ = self.lm_head.weight.requires_grad
+     requires_grad_ = requires_grad_ or self.lm_head.weight.dtype == torch.float32
+
+     if RETURN_HIDDEN_STATES:
+         logits = hidden_states[:, slice_indices, :]
+     elif labels is None:
+
+
+         # Set compiler stance to fail on recompiles for inference
+         global INFERENCE_RUNS
+         if torch_dynamo_eval_frame is not None:
+             old_stance = torch_dynamo_eval_frame._stance.stance
+         else:
+             old_stance = None
+         if old_stance is not None and INFERENCE_RUNS == 1:
+             # Skip guards and return to eager -> we still need guards!
+             torch_compiler_set_stance(stance = "eager_on_recompile", skip_guard_eval_unsafe = False)
+             if UNSLOTH_ENABLE_LOGGING:
+                 logger_compiler.info(
+                     f"Unsloth: Removing compiler guards after 1 inference run. "\
+                     f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
+                     f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
+                 )
+         elif old_stance == "eager_on_recompile":
+             pass
+         elif old_stance == "default" and INFERENCE_RUNS > 1:
+             # Reset compiler stance
+             torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
+             if UNSLOTH_ENABLE_LOGGING:
+                 logger_compiler.info(
+                     f"Unsloth: Reseting guards. "\
+                     f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
+                     f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
+                 )
+             INFERENCE_RUNS = 0
+         INFERENCE_RUNS += 1
+
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+     elif (() == () and () == ()) and (UNSLOTH_ENABLE_CCE) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_:
+         loss = fused_linear_cross_entropy(
+             hidden_states = hidden_states[:, slice_indices, :],
+             lm_weight = self.lm_head.weight,
+             labels = labels.to(self.lm_head.weight.device),
+             num_items_in_batch = n_items,
+             logit_softcapping = None if () == () else (),
+         )
+     elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None:
+         lm_head_weight = self.lm_head.weight
+         lm_head_bias = getattr(self.lm_head, "bias", None)
+
+         # ========= NEW fused =========
+         _hidden_states = hidden_states[:, slice_indices, :]
+         torch._dynamo.mark_dynamic(_hidden_states, 1)
+         torch._dynamo.mark_dynamic(labels, 1)
+         loss = unsloth_fused_ce_loss(
+             trainer = None,
+             hidden_states = _hidden_states,
+             lm_head_weight = lm_head_weight,
+             lm_head_bias = lm_head_bias,
+             labels = labels,
+             mask = None,
+             n_items = n_items,
+             scaling = getattr(self, "accelerator_scaler", None),
+             target_gb = None,
+             torch_compile = not UNSLOTH_COMPILE_DISABLE,
+             logit_scale_multiply = () if () != () else 0,
+             logit_scale_divide = () if () != () else 0,
+             logit_softcapping = () if () != () else 0,
+         )
+     else:
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+         if () != ():
+             logits = logits * ()
+         if () != ():
+             logits = logits / ()
+         if () not in (None, (),):
+             logits = logits / ()
+             logits = torch.tanh(logits)
+             logits = logits * ()
+         loss = self.loss_function(logits, labels.to(self.lm_head.weight.device), vocab_size=self.vocab_size, **kwargs)
+
+
+     aux_loss = None
+     if output_router_logits:
+         aux_loss = load_balancing_loss_func(
+             outputs.router_logits,
+             self.num_experts,
+             self.num_experts_per_tok,
+             attention_mask,
+         )
+         if labels is not None:
+             loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
+
+     return MoeCausalLMOutputWithPast(
+         loss=loss,
+         aux_loss=aux_loss,
+         logits=logits,
+         past_key_values=outputs.past_key_values,
+         hidden_states=outputs.hidden_states,
+         attentions=outputs.attentions,
+         router_logits=outputs.router_logits,
+     )
+
+ class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = Qwen3MoeModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.router_aux_loss_coef = config.router_aux_loss_coef
+         self.num_experts = config.num_experts
+         self.num_experts_per_tok = config.num_experts_per_tok
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_router_logits: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         logits_to_keep: Union[int, torch.Tensor] = 0,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> MoeCausalLMOutputWithPast:
+         return Qwen3MoeForCausalLM_forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_router_logits, cache_position, logits_to_keep, **kwargs)
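
For context on the rotary-embedding helpers that dominate the added module, the following is a minimal standalone sketch, not part of the package: it re-implements rotate_half, apply_rotary_pos_emb, and repeat_kv as shown above but without the torch.compile wrappers, and exercises them on toy tensors. The batch/head/sequence sizes and the 10000 RoPE base are arbitrary illustration values.

import torch

def rotate_half(x):
    # Split the last dimension in half and recombine with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin arrive as [batch, seq_len, head_dim]; unsqueeze_dim=1 inserts the
    # heads axis so they broadcast against [batch, heads, seq_len, head_dim].
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def repeat_kv(hidden_states, n_rep):
    # Expand KV heads so grouped-query attention keys/values match the query head count.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 16, 64
q = torch.randn(batch, q_heads, seq, head_dim)
k = torch.randn(batch, kv_heads, seq, head_dim)

# Build cos/sin the same way Qwen3MoeRotaryEmbedding_forward does, with inv_freq of size head_dim // 2.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
pos = torch.arange(seq).float()
freqs = torch.outer(pos, inv_freq)            # [seq, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)       # [seq, head_dim]
cos, sin = emb.cos()[None], emb.sin()[None]   # [1, seq, head_dim], broadcast over batch

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
k_full = repeat_kv(k_rot, q_heads // kv_heads)
print(q_rot.shape, k_full.shape)              # torch.Size([2, 8, 16, 64]) torch.Size([2, 8, 16, 64])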