cciwon-code-review-cli 2.0.1 → 2.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/code-review.js +1 -1
- package/lib/chat-mode.js +7 -2
- package/package.json +1 -1
- package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
- package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
- package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
- package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
- package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
- package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
- package/unsloth_compiled_cache/Conv1d.py +70 -0
- package/unsloth_compiled_cache/Conv2d.py +70 -0
- package/unsloth_compiled_cache/Conv3d.py +70 -0
- package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
- package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
- package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
- package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
- package/unsloth_compiled_cache/GroupNorm.py +70 -0
- package/unsloth_compiled_cache/LayerNorm.py +72 -0
- package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
- package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
- package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
- package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
- package/unsloth_compiled_cache/RMSNorm.py +73 -0
- package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
- package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
- package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
- package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
- package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
- package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
- package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
- package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
- package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
- package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
- package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
- package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
- package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
- package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
- package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
- package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
- package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
|
@@ -0,0 +1,726 @@
|
|
|
1
|
+
"""
|
|
2
|
+
2025.12.6
|
|
3
|
+
2025.12.7
|
|
4
|
+
4.57.1
|
|
5
|
+
0.24.0
|
|
6
|
+
__UNSLOTH_VERSIONING__
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
# Unsloth auto generated code
|
|
10
|
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
|
11
|
+
#
|
|
12
|
+
# This program is free software: you can redistribute it and/or modify
|
|
13
|
+
# it under the terms of the GNU Lesser General Public License as published by
|
|
14
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
15
|
+
# (at your option) any later version.
|
|
16
|
+
#
|
|
17
|
+
# This program is distributed in the hope that it will be useful,
|
|
18
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
19
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
20
|
+
# GNU General Public License for more details.
|
|
21
|
+
#
|
|
22
|
+
# You should have received a copy of the GNU Lesser General Public License
|
|
23
|
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
import os
import torch
import importlib.util
import math

# Unsloth Studio counts as enabled only when the `unsloth_studio` package is
# importable AND UNSLOTH_STUDIO_DISABLED is unset or "0" (any other value
# disables it).
if importlib.util.find_spec("unsloth_studio") is None:
    UNSLOTH_STUDIO_ENABLED = False
else:
    UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"

from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
import math

# Feature flags, read once from the environment at import time.
UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
UNSLOTH_ENABLE_CCE = os.environ.get("UNSLOTH_ENABLE_CCE", "1") == "1"
# Both "1" and "partial" count as "compile disabled" here.
UNSLOTH_COMPILE_DISABLE = os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") in ("1", "partial",)

import logging
logger_compiler = logging.getLogger(__name__)
if UNSLOTH_ENABLE_LOGGING:
    logger_compiler.setLevel(logging.DEBUG)

# Initialised to 0 at import (presumably updated elsewhere — confirm at use
# sites). A `global` statement at module scope is a no-op, so none is used.
INFERENCE_RUNS = 0

# Probe for torch.compiler stance support. Both names fall back to None when
# the running torch build lacks the private `_stance` attribute or the public
# `torch.compiler.set_stance` API.
try:
    import torch._dynamo.eval_frame as torch_dynamo_eval_frame
    torch_dynamo_eval_frame._stance.stance
    torch_compiler_set_stance = torch.compiler.set_stance
except Exception:
    # Best-effort probe: treat any failure as "unsupported", but do not
    # swallow KeyboardInterrupt / SystemExit as the former bare `except:` did.
    torch_dynamo_eval_frame = None
    torch_compiler_set_stance = None
|
|
58
|
+
|
|
59
|
+
from unsloth_zoo import DEVICE_TYPE_TORCH, DEVICE_COUNT
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
from unsloth_zoo.loss_utils import (
|
|
63
|
+
fused_linear_cross_entropy,
|
|
64
|
+
unsloth_fused_ce_loss,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
if UNSLOTH_STUDIO_ENABLED:
|
|
68
|
+
from unsloth_zoo.loss_utils import fast_linear_cross_entropy
|
|
69
|
+
|
|
70
|
+
scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
|
71
|
+
@torch.compiler.disable(recursive = False)
|
|
72
|
+
def disable_compile_scaled_dot_product_attention(*args, **kwargs):
|
|
73
|
+
return scaled_dot_product_attention(*args, **kwargs)
|
|
74
|
+
pass
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
from transformers.modeling_flash_attention_utils import is_flash_attn_available

# Optional flash-attention helpers. Each name may be absent depending on the
# installed transformers / flash-attn versions, so each import falls back to
# None independently. `except Exception` keeps the best-effort behaviour but,
# unlike the previous bare `except:`, does not swallow KeyboardInterrupt or
# SystemExit.
if is_flash_attn_available():
    try:
        from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask
    except Exception:
        flash_attn_supports_top_left_mask = None
    try:
        from transformers.modeling_flash_attention_utils import _flash_attention_forward
    except Exception:
        _flash_attention_forward = None
    try:
        from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
    except Exception:
        FlashAttentionKwargs = None
    try:
        from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
    except Exception:
        flash_attn_varlen_func = None
else:
    flash_attn_supports_top_left_mask = None
    _flash_attention_forward = None
    FlashAttentionKwargs = None
    flash_attn_varlen_func = None
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# Inductor options passed as `options=` to this module's @torch.compile
# decorators. Key order matches the original single-line literal.
torch_compile_options = {
    'epilogue_fusion': True,
    'max_autotune': False,
    'shape_padding': True,
    'trace.enabled': False,
    'triton.cudagraphs': False,
    'debug': False,
    'dce': True,
    'memory_planning': True,
    'coordinate_descent_tuning': False,
    'trace.graph_diagram': False,
    'compile_threads': 32,
    'group_fusion': True,
    'disable_progress': True,
    'verbose_progress': False,
    'triton.multi_kernel': 0,
    'triton.use_block_ptr': False,
    'triton.enable_persistent_tma_matmul': True,
    'triton.autotune_at_compile_time': False,
    'triton.cooperative_reductions': False,
    'cuda.compile_opt_level': '-O2',
    'cuda.enable_cuda_lto': True,
    'combo_kernels': False,
    'benchmark_combo_kernel': True,
    'combo_kernel_foreach_dynamic_shapes': True,
}
|
|
105
|
+
|
|
106
|
+
from torch.nn import CrossEntropyLoss
|
|
107
|
+
|
|
108
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def normal_cross_entropy_loss(self, hidden_states, labels):
    """Project hidden states through `self.lm_head` and compute next-token CE loss.

    Args:
        self: model object exposing `lm_head` and `config.vocab_size`.
        hidden_states: input passed to `self.lm_head`.
        labels: target token ids; shifted left by one so position t is
            scored against label t + 1.

    Returns:
        `(loss, logits)`, where `logits` is the full, unshifted lm_head
        output converted with `.float()`.
    """
    logits = self.lm_head(hidden_states).float()
    vocab_size = self.config.vocab_size
    # Drop the final prediction and the first label so tokens < n predict n,
    # then flatten to (tokens, vocab) / (tokens,).
    predictions = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    targets = labels[..., 1:].contiguous().view(-1)
    # Labels may sit on a different device under model parallelism.
    targets = targets.to(predictions.device)
    criterion = CrossEntropyLoss()
    return criterion(predictions, targets), logits
|
|
124
|
+
|
|
125
|
+
# We need an empty logits flag to warn people that logits will not be returned
# anymore unless explicitly requested, i.e.
# os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
#
# Message text used by raise_logits_error and by EmptyLogits' repr/str.
# The inline-code span around "1" was previously missing its closing backtick,
# which garbled the markdown-style formatting of the whole message.
LOGITS_ERROR_STRING = \
    "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\
    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1"` BEFORE starting to train ie before `trainer.train()`. For example:\n'\
    "```\nimport os\n"\
    "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\
    "trainer.train()\n```\n"\
    "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
|
|
134
|
+
|
|
135
|
+
def raise_logits_error(*args, **kwargs):
    """Ignore every argument and raise NotImplementedError(LOGITS_ERROR_STRING)."""
    raise NotImplementedError(LOGITS_ERROR_STRING)
|
|
136
|
+
def return_none(*args, **kwargs):
    """Accept any arguments and return None."""
    del args, kwargs  # deliberately unused
    return None
|
|
137
|
+
class EmptyLogits:
    """Placeholder object whose use reports that logits were not returned.

    - Looking up a missing attribute named "to" yields `return_none`, so
      `.to(...)` returns None.
    - Looking up any other missing attribute yields `raise_logits_error`, so
      calling it raises NotImplementedError.
    - Indexing (`obj[...]`) raises NotImplementedError directly.
    - `repr()` / `str()` return LOGITS_ERROR_STRING.
    """

    def __init__(self):
        return

    def raise_getattr_error(self, attr):
        # Installed below as __getattr__, i.e. only consulted for attributes
        # that normal lookup does not find.
        if attr == "to":
            return return_none
        return raise_logits_error

    __getitem__ = raise_logits_error
    __getattr__ = raise_getattr_error

    def __repr__(self):
        return LOGITS_ERROR_STRING

    def __str__(self):
        return LOGITS_ERROR_STRING
|
|
145
|
+
# Shared EmptyLogits instance (presumably handed out in place of real logits
# elsewhere — confirm at call sites).
EMPTY_LOGITS = EmptyLogits()
# Every attribute name on torch.Tensor, dunder names included.
functions = dir(torch.Tensor)
for j, function in enumerate(functions):
    # Only dunder names such as "__add__" or "__len__" are patched.
    if function.startswith("__") and function.endswith("__"):
        # Define a function named `raise_{j}`. This loop runs at module scope,
        # where globals() and locals() are the same dict, so the def lands in
        # the module namespace. Despite its name, the generated function only
        # prints the dunder name and returns None; it does not raise.
        exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
        # Attach it as an *instance* attribute of EMPTY_LOGITS. Presumably some
        # names refuse instance assignment (hence the try/except); those are
        # skipped.
        # NOTE(review): Python resolves special methods on the type, not the
        # instance, so these assignments likely do not change how operators
        # (`+`, `len()`, ...) behave on EMPTY_LOGITS. They only affect explicit
        # calls such as `EMPTY_LOGITS.__len__()`. The bare `except` also catches
        # BaseException; consider `except Exception`.
        try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
        except: continue
pass
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def mask_attention_mask_out(labels = None, attention_mask = None):
    """Set `labels` to -100 wherever `attention_mask` is 0.

    `labels` is modified in place and also returned. If either argument is
    None, `labels` is returned untouched. The mask is first moved to
    `labels.device`.
    """
    if labels is None or attention_mask is None:
        return labels
    padded = attention_mask.to(device = labels.device) == 0
    labels[padded] = -100
    return labels
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
from torch import Tensor
|
|
164
|
+
import torch
|
|
165
|
+
import torch.nn as nn
|
|
166
|
+
from torch.nn import functional as F
|
|
167
|
+
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
|
168
|
+
from transformers.models.qwen3_moe.modeling_qwen3_moe import (__name__, F, Callable, Optional, Union, torch, nn, ACT2FN, Cache, GenerationMixin, use_kernel_forward_from_hub, FlashAttentionKwargs, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, ROPE_INIT_FUNCTIONS, dynamic_rope_update, ALL_ATTENTION_FUNCTIONS, PreTrainedModel, Unpack, TransformersKwargs, can_return_tuple, deprecate_kwarg, Qwen3MoeConfig, Qwen3MoePreTrainedModel, Qwen3MoeModel, Qwen3MoeForCausalLM)
|
|
169
|
+
|
|
170
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def rotate_half(x):
    """Rotate half the hidden dims: split the last dim into (x1, x2) and return cat(-x2, x1)."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension inserted into `cos` and `sin` so they broadcast against
            `q` and `k`. With cos/sin shaped [batch_size, seq_len, head_dim],
            use 1 when q/k are [batch_size, heads, seq_len, head_dim] and 2
            when they are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: `(q_embed, k_embed)`, each computed as
        `t * cos + rotate_half(t) * sin`.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    q_rotated = (q * cos_b) + (rotate_half(q) * sin_b)
    k_rotated = (k * cos_b) + (rotate_half(k) * sin_b)
    return q_rotated, k_rotated
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along dim 1.

    Same result as torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1):
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input is returned unchanged when n_rep == 1.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then merge it
    # into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference attention: softmax(query @ key^T * scaling + mask) @ value.

    Key/value heads are expanded with `repeat_kv` by
    `module.num_key_value_groups`. The additive mask, when given, is sliced
    to the key length. Softmax runs in float32, then the result is cast to
    the score dtype and to `query.dtype`. Dropout is active only while
    `module.training` is true.

    Returns:
        `(attn_output, attn_weights)`. `attn_output` has dims 1 and 2
        swapped relative to the weights @ values product and is contiguous.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    probs = nn.functional.softmax(scores, dim=-1, dtype = torch.float32)
    probs = probs.to(scores.dtype).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    output = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return output, probs
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
@torch.compiler.disable(recursive = False)
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def Qwen3MoeAttention_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Attention forward pass used by Qwen3MoeAttention.forward.

    Steps:
    1. Project `hidden_states` to per-head q/k/v; q and k also pass through
       `q_norm` / `k_norm`.
    2. Apply rotary embeddings from `position_embeddings`.
    3. Update `past_key_values` when a cache is supplied.
    4. Dispatch to the backend named by `config._attn_implementation`
       (`eager_attention_forward` for "eager").
    5. Project the result with `o_proj`.

    This frame is excluded from torch.compile (`recursive=False`), and
    `past_key_value` is still accepted as a deprecated alias of
    `past_key_values`.

    Returns:
        `(attn_output, attn_weights)` as returned by the backend, with
        `attn_output` projected through `o_proj`.
    """
    input_shape = hidden_states.shape[:-1]
    per_head_shape = (*input_shape, -1, self.head_dim)

    # Call order q_proj/q_norm, k_proj/k_norm, v_proj is preserved.
    q = self.q_norm(self.q_proj(hidden_states).view(per_head_shape))
    k = self.k_norm(self.k_proj(hidden_states).view(per_head_shape))
    v = self.v_proj(hidden_states).view(per_head_shape)
    # Move the head axis in front of the sequence axis.
    query_states, key_states, value_states = (t.transpose(1, 2) for t in (q, k, v))

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_values is not None:
        # sin and cos are specific to RoPE models; cache_position is needed for the static cache.
        key_states, value_states = past_key_values.update(
            key_states,
            value_states,
            self.layer_idx,
            {"sin": sin, "cos": cos, "cache_position": cache_position},
        )

    impl = self.config._attn_implementation
    attention_interface: Callable = (
        eager_attention_forward if impl == "eager" else ALL_ATTENTION_FUNCTIONS[impl]
    )

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=self.attention_dropout if self.training else 0.0,
        scaling=self.scaling,
        sliding_window=self.sliding_window,  # diff with Llama
        **kwargs,
    )

    attn_output = self.o_proj(attn_output.reshape(*input_shape, -1).contiguous())
    return attn_output, attn_weights
|
|
291
|
+
|
|
292
|
+
class Qwen3MoeAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper.

    Qwen3-MoE variant: grouped key/value heads, RMSNorm applied per head to
    queries and keys, and an optional sliding window read from the config.
    """

    def __init__(self, config: Qwen3MoeConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // n_heads)
        self.num_key_value_groups = n_heads // n_kv_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        use_bias = config.attention_bias
        q_width = n_heads * self.head_dim
        kv_width = n_kv_heads * self.head_dim
        # Submodule registration order (q, k, v, o, q_norm, k_norm) is kept;
        # it determines named_modules / state_dict ordering.
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=use_bias)
        # Unlike OLMo, the norms cover only head_dim, so no reshape is needed after them.
        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = getattr(config, "sliding_window", None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Delegate to the module-level Qwen3MoeAttention_forward."""
        return Qwen3MoeAttention_forward(
            self, hidden_states, position_embeddings, attention_mask,
            past_key_values, cache_position, **kwargs,
        )
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def Qwen3MoeRMSNorm_forward(self, hidden_states):
    """RMS-normalise over the last dim in float32, cast back, then scale by `self.weight`."""
    original_dtype = hidden_states.dtype
    x32 = hidden_states.to(torch.float32)
    mean_square = x32.pow(2).mean(-1, keepdim=True)
    normalised = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
    return self.weight * normalised.to(original_dtype)
|
|
341
|
+
|
|
342
|
+
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3MoeRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen3MoeRMSNorm is equivalent to T5LayerNorm: a root-mean-square norm
        with a learned per-feature `weight` (initialised to ones) and no bias.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        return Qwen3MoeRMSNorm_forward(self, hidden_states)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def Qwen3MoeMLP_forward(self, x):
    """Gated MLP: down_proj(act_fn(gate_proj(x)) * up_proj(x))."""
    gated = self.act_fn(self.gate_proj(x))
    lifted = self.up_proj(x)
    return self.down_proj(gated * lifted)
|
|
363
|
+
|
|
364
|
+
class Qwen3MoeMLP(nn.Module):
    """Gated feed-forward block built from bias-free gate/up/down projections."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        if intermediate_size is None:
            intermediate_size = config.intermediate_size
        self.intermediate_size = intermediate_size
        d_model, d_ff = self.hidden_size, self.intermediate_size
        # Registration order (gate, up, down, act_fn) is preserved.
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return Qwen3MoeMLP_forward(self, x)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
@torch.no_grad()
@dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
def Qwen3MoeRotaryEmbedding_forward(self, x, position_ids):
    """Compute the RoPE (cos, sin) tables for `position_ids`, cast to `x.dtype`.

    Frequencies are computed in float32 with autocast disabled, and both
    tables are multiplied by `self.attention_scaling`.
    """
    batch = position_ids.shape[0]
    inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
    positions = position_ids[:, None, :].float()

    # autocast needs a device type; "mps" (or a non-string type) falls back to "cpu".
    dev_type = x.device.type
    if not isinstance(dev_type, str) or dev_type == "mps":
        dev_type = "cpu"
    with torch.autocast(device_type=dev_type, enabled=False):  # Force float32
        freqs = (inv_freq.float() @ positions.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling

    return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
394
|
+
|
|
395
|
+
class Qwen3MoeRotaryEmbedding(nn.Module):
    """Rotary position embedding; stores `inv_freq` as a non-persistent buffer."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3MoeConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Reference to the initial frequencies, kept as a plain attribute
        # (no register_buffer call).
        self.original_inv_freq = self.inv_freq

    def forward(self, x, position_ids):
        return Qwen3MoeRotaryEmbedding_forward(self, x, position_ids)
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts. When `None`, it is inferred from the last dimension of the
            concatenated gate logits.
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss as a tensor, or the integer `0` when `gate_logits` is `None`,
        not a tuple, or an empty tuple.
    """
    # Only a non-empty per-layer tuple can be scored. Anything else (None, a bare tensor, ()) contributes nothing.
    if not isinstance(gate_logits, tuple) or len(gate_logits) == 0:
        return 0

    # Move every layer's logits onto the first layer's device and stack them along dim 0.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    # Previously a `None` default reached `one_hot(..., None)` and crashed. Infer the count from the logits instead.
    if num_experts is None:
        num_experts = concatenated_gate_logits.shape[-1]

    # Softmax in float32 for stability, then cast back to the logits' dtype.
    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1, dtype = torch.float32).to(concatenated_gate_logits.dtype)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        # Rows of the concatenated logits are (layers x batch x seq). Recover the layer count from that.
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        # NOTE(review): an all-zero attention_mask makes this a 0/0 division (NaN) -- unchanged from before.
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    # tokens_per_expert has a leading top_k axis, so broadcast the per-expert probabilities over it.
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
@torch.compiler.disable(recursive = False)
@can_return_tuple
def Qwen3MoeForCausalLM_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[TransformersKwargs],
) -> MoeCausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Loss / logits selection (first matching branch wins):
        * `UNSLOTH_RETURN_HIDDEN_STATES=1`: `logits` holds the sliced hidden states and no loss is computed.
        * `labels is None`: logits are computed from `lm_head`, and the torch compiler stance is adjusted
          according to the module-level `INFERENCE_RUNS` counter.
        * `UNSLOTH_ENABLE_CCE`, logits not requested, a `*ForCausalLMLoss` loss function, and an `lm_head`
          weight that neither requires grad nor is float32: `fused_linear_cross_entropy`.
        * A `*ForCausalLMLoss` loss function: `unsloth_fused_ce_loss` without materialising logits.
        * Otherwise: logits are materialised and passed to `self.loss_function`.
    When router logits are requested, the load-balancing auxiliary loss is returned as `aux_loss` and,
    if a loss tensor was computed, added to it scaled by `router_aux_loss_coef`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM

    >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.output_router_logits
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    # Logits are computed eagerly only when UNSLOTH_RETURN_LOGITS=1. Otherwise the EMPTY_LOGITS placeholder
    # is kept unless one of the branches below overwrites `logits`.
    logits = self.lm_head(hidden_states[:, slice_indices, :]) if os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '1' else EMPTY_LOGITS
    loss = None
    NOT_RETURN_LOGITS = os.environ.get('UNSLOTH_RETURN_LOGITS', '0') == '0'
    RETURN_HIDDEN_STATES = os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1"

    # Find the normaliser for the loss ("num_items_in_batch" or "n_items"). Search order: kwargs,
    # a `loss_kwargs` local, `kwargs` again, then the first plain dict among this frame's locals.
    n_items = None
    if (kwargs) != () and type(kwargs) is dict:
        n_items = (kwargs).get("num_items_in_batch", None) or (kwargs).get("n_items", None)
    if n_items is None:
        all_locals = locals()
        if 'loss_kwargs' in all_locals:
            __kwargs = all_locals['loss_kwargs']
            if type(__kwargs) is dict:
                n_items = __kwargs.get("num_items_in_batch", None)
                if n_items is None: n_items = __kwargs.get("n_items", None)
        if n_items is None and 'kwargs' in all_locals:
            __kwargs = all_locals['kwargs']
            if type(__kwargs) is dict:
                n_items = __kwargs.get("num_items_in_batch", None)
                if n_items is None: n_items = __kwargs.get("n_items", None)
        if n_items is None:
            all_locals = all_locals.values()
            for __kwargs in all_locals:
                if type(__kwargs) is dict:
                    n_items = __kwargs.get("num_items_in_batch", None)
                    if n_items is None: n_items = __kwargs.get("n_items", None)
                    break
    pass

    # The CCE path below is taken only when the lm_head weight neither requires grad nor is float32.
    requires_grad_ = self.lm_head.weight.requires_grad
    requires_grad_ = requires_grad_ or self.lm_head.weight.dtype == torch.float32

    if RETURN_HIDDEN_STATES:
        logits = hidden_states[:, slice_indices, :]
    elif labels is None:


        # Set compiler stance to fail on recompiles for inference
        global INFERENCE_RUNS
        if torch_dynamo_eval_frame is not None:
            old_stance = torch_dynamo_eval_frame._stance.stance
        else:
            old_stance = None
        if old_stance is not None and INFERENCE_RUNS == 1:
            # Skip guards and return to eager -> we still need guards!
            torch_compiler_set_stance(stance = "eager_on_recompile", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Removing compiler guards after 1 inference run. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
        elif old_stance == "eager_on_recompile":
            pass
        elif old_stance == "default" and INFERENCE_RUNS > 1:
            # Reset compiler stance
            torch_compiler_set_stance(stance = "default", skip_guard_eval_unsafe = False)
            if UNSLOTH_ENABLE_LOGGING:
                logger_compiler.info(
                    f"Unsloth: Reseting guards. "\
                    f"DYNAMO_STANCE.stance = {torch_dynamo_eval_frame._stance.stance} "\
                    f"DYNAMO_STANCE.skip_guard_eval_unsafe = {torch_dynamo_eval_frame._stance.skip_guard_eval_unsafe}"
                )
            INFERENCE_RUNS = 0
        INFERENCE_RUNS += 1

        logits = self.lm_head(hidden_states[:, slice_indices, :])
    # The generated `(() == () and () == ())` prefix of this condition was always True and has been folded away.
    elif (UNSLOTH_ENABLE_CCE) and NOT_RETURN_LOGITS and self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None and not requires_grad_:
        loss = fused_linear_cross_entropy(
            hidden_states = hidden_states[:, slice_indices, :],
            lm_weight = self.lm_head.weight,
            labels = labels.to(self.lm_head.weight.device),
            num_items_in_batch = n_items,
            logit_softcapping = None,  # was `None if () == () else ()`, which is always None
        )
    elif self.loss_function.__name__.endswith("ForCausalLMLoss") and labels is not None:
        lm_head_weight = self.lm_head.weight
        lm_head_bias = getattr(self.lm_head, "bias", None)

        # ========= NEW fused =========
        _hidden_states = hidden_states[:, slice_indices, :]
        # Mark the sequence axis dynamic so varying lengths do not force recompiles.
        torch._dynamo.mark_dynamic(_hidden_states, 1)
        torch._dynamo.mark_dynamic(labels, 1)
        loss = unsloth_fused_ce_loss(
            trainer = None,
            hidden_states = _hidden_states,
            lm_head_weight = lm_head_weight,
            lm_head_bias = lm_head_bias,
            labels = labels,
            mask = None,
            n_items = n_items,
            scaling = getattr(self, "accelerator_scaler", None),
            target_gb = None,
            torch_compile = not UNSLOTH_COMPILE_DISABLE,
            # The generated `() if () != () else 0` expressions always evaluated to 0.
            logit_scale_multiply = 0,
            logit_scale_divide = 0,
            logit_softcapping = 0,
        )
    else:
        # The generated `if () != ():` / `if () not in (None, (),):` scaling branches were unreachable and are removed.
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = self.loss_function(logits, labels.to(self.lm_head.weight.device), vocab_size=self.vocab_size, **kwargs)


    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        # `loss` is still None on the UNSLOTH_RETURN_HIDDEN_STATES path, and load_balancing_loss_func
        # returns the int 0 for missing or non-tuple router logits. Adding in either case used to raise.
        if labels is not None and loss is not None and isinstance(aux_loss, torch.Tensor):
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
    )
|
|
693
|
+
|
|
694
|
+
class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):
    """Qwen3-MoE decoder with a bias-free linear LM head on top.

    Holds the MoE routing hyper-parameters (`num_experts`, `num_experts_per_tok`,
    `router_aux_loss_coef`) read by the forward pass. The forward pass itself
    lives in the module-level `Qwen3MoeForCausalLM_forward`.
    """

    # Class-level metadata. Presumably read by the transformers base classes
    # (tied weights, tensor/pipeline-parallel plans) -- confirm against PreTrainedModel.
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    _tp_plan = {"lm_head": "colwise_rep"}
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        vocab = config.vocab_size

        # Sub-modules are created in the same order as before: decoder stack, then output projection.
        self.model = Qwen3MoeModel(config)
        self.vocab_size = vocab
        self.lm_head = nn.Linear(config.hidden_size, vocab, bias=False)

        # MoE routing settings forwarded to the load-balancing auxiliary loss.
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts, self.num_experts_per_tok = config.num_experts, config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        """Delegate to the module-level `Qwen3MoeForCausalLM_forward` implementation."""
        return Qwen3MoeForCausalLM_forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
|