cciwon-code-review-cli 2.0.1 → 2.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/code-review.js +1 -1
- package/lib/chat-mode.js +7 -2
- package/package.json +1 -1
- package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
- package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
- package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
- package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
- package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
- package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
- package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
- package/unsloth_compiled_cache/Conv1d.py +70 -0
- package/unsloth_compiled_cache/Conv2d.py +70 -0
- package/unsloth_compiled_cache/Conv3d.py +70 -0
- package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
- package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
- package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
- package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
- package/unsloth_compiled_cache/GroupNorm.py +70 -0
- package/unsloth_compiled_cache/LayerNorm.py +72 -0
- package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
- package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
- package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
- package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
- package/unsloth_compiled_cache/RMSNorm.py +73 -0
- package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
- package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
- package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
- package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
- package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
- package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
- package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
- package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
- package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
- package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
- package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
- package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
- package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
- package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
- package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
- package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
- package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
- package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
|
@@ -0,0 +1,534 @@
|
|
|
1
|
+
"""
|
|
2
|
+
2025.12.6
|
|
3
|
+
2025.12.7
|
|
4
|
+
4.57.1
|
|
5
|
+
0.24.0
|
|
6
|
+
__UNSLOTH_VERSIONING__
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
# Unsloth auto generated code
|
|
10
|
+
# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
|
|
11
|
+
#
|
|
12
|
+
# This program is free software: you can redistribute it and/or modify
|
|
13
|
+
# it under the terms of the GNU Lesser General Public License as published by
|
|
14
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
15
|
+
# (at your option) any later version.
|
|
16
|
+
#
|
|
17
|
+
# This program is distributed in the hope that it will be useful,
|
|
18
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
19
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
20
|
+
# GNU General Public License for more details.
|
|
21
|
+
#
|
|
22
|
+
# You should have received a copy of the GNU Lesser General Public License
|
|
23
|
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
import os
|
|
27
|
+
import torch
|
|
28
|
+
import importlib.util
|
|
29
|
+
import math
|
|
30
|
+
if importlib.util.find_spec("unsloth_studio") is None:
|
|
31
|
+
UNSLOTH_STUDIO_ENABLED = False
|
|
32
|
+
else:
|
|
33
|
+
UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
|
|
34
|
+
pass
|
|
35
|
+
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
|
36
|
+
import math
|
|
37
|
+
|
|
38
|
+
UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
|
|
39
|
+
UNSLOTH_ENABLE_CCE = os.environ.get("UNSLOTH_ENABLE_CCE", "1") == "1"
|
|
40
|
+
UNSLOTH_COMPILE_DISABLE = os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") in ("1", "partial",)
|
|
41
|
+
|
|
42
|
+
import logging
|
|
43
|
+
logger_compiler = logging.getLogger(__name__)
|
|
44
|
+
if UNSLOTH_ENABLE_LOGGING:
|
|
45
|
+
logger_compiler.setLevel(logging.DEBUG)
|
|
46
|
+
|
|
47
|
+
global INFERENCE_RUNS
|
|
48
|
+
INFERENCE_RUNS = 0
|
|
49
|
+
|
|
50
|
+
try:
|
|
51
|
+
import torch._dynamo.eval_frame as torch_dynamo_eval_frame
|
|
52
|
+
torch_dynamo_eval_frame._stance.stance
|
|
53
|
+
torch_compiler_set_stance = torch.compiler.set_stance
|
|
54
|
+
except:
|
|
55
|
+
torch_dynamo_eval_frame = None
|
|
56
|
+
torch_compiler_set_stance = None
|
|
57
|
+
pass
|
|
58
|
+
|
|
59
|
+
from unsloth_zoo import DEVICE_TYPE_TORCH, DEVICE_COUNT
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
from unsloth_zoo.loss_utils import (
|
|
63
|
+
fused_linear_cross_entropy,
|
|
64
|
+
unsloth_fused_ce_loss,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
if UNSLOTH_STUDIO_ENABLED:
|
|
68
|
+
from unsloth_zoo.loss_utils import fast_linear_cross_entropy
|
|
69
|
+
|
|
70
|
+
scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
|
71
|
+
@torch.compiler.disable(recursive = False)
def disable_compile_scaled_dot_product_attention(*args, **kwargs):
    """Forward every argument to the module-level `scaled_dot_product_attention`.

    The wrapper itself is decorated with `torch.compiler.disable(recursive = False)`;
    positional and keyword arguments are passed through unchanged and the
    callee's result is returned as-is.
    """
    attention_output = scaled_dot_product_attention(*args, **kwargs)
    return attention_output
|
|
74
|
+
pass
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
from transformers.modeling_flash_attention_utils import is_flash_attn_available
|
|
78
|
+
|
|
79
|
+
if is_flash_attn_available():
|
|
80
|
+
try:
|
|
81
|
+
from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask
|
|
82
|
+
except:
|
|
83
|
+
flash_attn_supports_top_left_mask = None
|
|
84
|
+
try:
|
|
85
|
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
|
86
|
+
except:
|
|
87
|
+
_flash_attention_forward = None
|
|
88
|
+
try:
|
|
89
|
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
|
90
|
+
except:
|
|
91
|
+
FlashAttentionKwargs = None
|
|
92
|
+
try:
|
|
93
|
+
from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
|
|
94
|
+
except:
|
|
95
|
+
flash_attn_varlen_func = None
|
|
96
|
+
else:
|
|
97
|
+
flash_attn_supports_top_left_mask = None
|
|
98
|
+
_flash_attention_forward = None
|
|
99
|
+
FlashAttentionKwargs = None
|
|
100
|
+
flash_attn_varlen_func = None
|
|
101
|
+
pass
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False, 'debug': False, 'dce': True, 'memory_planning': True, 'coordinate_descent_tuning': False, 'trace.graph_diagram': False, 'compile_threads': 32, 'group_fusion': True, 'disable_progress': True, 'verbose_progress': False, 'triton.multi_kernel': 0, 'triton.use_block_ptr': False, 'triton.enable_persistent_tma_matmul': True, 'triton.autotune_at_compile_time': False, 'triton.cooperative_reductions': False, 'cuda.compile_opt_level': '-O2', 'cuda.enable_cuda_lto': True, 'combo_kernels': False, 'benchmark_combo_kernel': True, 'combo_kernel_foreach_dynamic_shapes': True}
|
|
105
|
+
|
|
106
|
+
from torch.nn import CrossEntropyLoss
|
|
107
|
+
|
|
108
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def normal_cross_entropy_loss(self, hidden_states, labels):
    """Compute the shifted next-token cross-entropy loss from hidden states.

    Args:
        self: object exposing `lm_head` (callable) and `config.vocab_size`.
        hidden_states: input passed to `self.lm_head`.
        labels: target ids; the first position along the last axis is dropped.

    Returns:
        `(loss, logits)` where `logits` is the full, unshifted output of
        `self.lm_head` converted with `.float()`.
    """
    logits = self.lm_head(hidden_states).float()

    # Shift so that tokens < n predict n: drop the last logit and the first label.
    flat_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
    flat_labels = labels[..., 1:].contiguous().view(-1)

    # Enable model parallelism: labels must live on the logits' device.
    flat_labels = flat_labels.to(flat_logits.device)

    criterion = CrossEntropyLoss()
    loss = criterion(flat_logits, flat_labels)
    return loss, logits
|
|
123
|
+
pass
|
|
124
|
+
|
|
125
|
+
# We need an empty logits flag to warn people logits will not be returned anymore unless asked ie
|
|
126
|
+
# os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
|
127
|
+
LOGITS_ERROR_STRING = \
|
|
128
|
+
"Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\
|
|
129
|
+
'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\
|
|
130
|
+
"```\nimport os\n"\
|
|
131
|
+
"os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\
|
|
132
|
+
"trainer.train()\n```\n"\
|
|
133
|
+
"No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
|
|
134
|
+
|
|
135
|
+
def raise_logits_error(*args, **kwargs):
    """Ignore all arguments and raise `NotImplementedError(LOGITS_ERROR_STRING)`."""
    raise NotImplementedError(LOGITS_ERROR_STRING)
|
|
136
|
+
def return_none(*args, **kwargs):
    """Accept any arguments, ignore them, and return `None`."""
    return None
|
|
137
|
+
class EmptyLogits:
    """Stand-in object used where logits are not returned.

    Indexing calls `raise_logits_error`, which raises `NotImplementedError`.
    Attribute lookups that fall through to `__getattr__` return
    `raise_logits_error` (so calling the attribute raises), except for
    `"to"`, which yields `return_none`. `repr()` and `str()` both return
    `LOGITS_ERROR_STRING`.
    """

    def __init__(self):
        return

    def raise_getattr_error(self, attr):
        # `.to(...)` is tolerated and evaluates to None; everything else errors when called.
        if attr == "to":
            return return_none
        return raise_logits_error

    __getitem__ = raise_logits_error
    __getattr__ = raise_getattr_error

    def __repr__(self):
        return LOGITS_ERROR_STRING

    def __str__(self):
        return LOGITS_ERROR_STRING
|
|
144
|
+
pass
|
|
145
|
+
# Shared module-level EmptyLogits instance.
EMPTY_LOGITS = EmptyLogits()
# Every attribute name exposed by torch.Tensor; only the dunder names are used below.
functions = dir(torch.Tensor)
for j, function in enumerate(functions):
    if function.startswith("__") and function.endswith("__"):
        # Defines a function named raise_<j> that prints the dunder name when called
        # (despite the "raise_" prefix it does not raise).
        exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
        # Try to attach that function to the EMPTY_LOGITS instance under the dunder
        # name; any assignment that fails is skipped.
        try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
        except: continue
|
|
152
|
+
pass
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def mask_attention_mask_out(labels = None, attention_mask = None):
    """Set labels to -100 wherever the attention mask equals 0.

    When either argument is `None`, `labels` is returned untouched.
    Otherwise the mask is moved to `labels.device` and `labels` is modified
    in place.

    Args:
        labels: tensor of label ids, or `None`.
        attention_mask: tensor indexable against `labels`, or `None`.

    Returns:
        The same `labels` object that was passed in (possibly `None`).
    """
    if labels is None or attention_mask is None:
        return labels

    mask_on_label_device = attention_mask.to(device = labels.device)
    labels[mask_on_label_device == 0] = -100
    return labels
|
|
160
|
+
pass
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
from torch import Tensor
|
|
164
|
+
import torch
|
|
165
|
+
import torch.nn as nn
|
|
166
|
+
from torch.nn import functional as F
|
|
167
|
+
from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
|
|
168
|
+
from transformers.models.siglip.modeling_siglip import (math, warnings, Callable, Optional, np, torch, nn, _calculate_fan_in_and_fan_out, ACT2FN, ALL_ATTENTION_FUNCTIONS, torch_int, SiglipTextConfig, SiglipVisionConfig)
|
|
169
|
+
|
|
170
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def _trunc_normal_(tensor, mean, std, a, b):
    """Fill `tensor` in place with truncated-normal samples clamped to [a, b].

    Cut & paste from PyTorch official master until it's in a few official releases - RW
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor that is overwritten in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cutoff.
        b: upper cutoff.

    Nothing is returned; a warning is issued when `mean` lies more than
    two `std` outside `[a, b]`.
    """
    sqrt_two = math.sqrt(2.0)

    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # CDF values at the lower and upper cutoffs.
    cdf_low = norm_cdf((a - mean) / std)
    cdf_high = norm_cdf((b - mean) / std)

    # Sample uniformly from [2*cdf_low - 1, 2*cdf_high - 1] ...
    tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
    # ... and push through the inverse error function to get a truncated standard normal.
    tensor.erfinv_()

    # Rescale to the requested std and shift to the requested mean.
    tensor.mul_(std * sqrt_two)
    tensor.add_(mean)

    # Clamp to ensure the values are in the proper range.
    tensor.clamp_(min=a, max=b)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    """Fill `tensor` in place with truncated-normal values, then scale and shift.

    NOTE: this 'tf' variant behaves closer to the Tensorflow / JAX implementation:
    the bounds [a, b] are applied while sampling with mean=0, std=1.0, and the
    result is subsequently scaled by `std` and shifted by `mean`.

    Args:
        tensor: an n-dimensional `torch.Tensor`, overwritten in place
        mean: the mean applied after sampling
        std: the standard deviation applied after sampling
        a: the minimum cutoff value used while sampling
        b: the maximum cutoff value used while sampling

    The function has no `return` statement, so it evaluates to `None`
    regardless of the annotated return type.
    """
    with torch.no_grad():
        # Sample from the standard normal truncated to [a, b] ...
        _trunc_normal_(tensor, 0, 1.0, a, b)
        # ... then apply the requested scale and offset in place.
        tensor.mul_(std)
        tensor.add_(mean)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialise `tensor` in place with variance `scale / denom`.

    Args:
        tensor: tensor to initialise; its fans come from `_calculate_fan_in_and_fan_out`.
        scale: numerator of the target variance.
        mode: which fan is the denominator: "fan_in", "fan_out" or "fan_avg"
            (the mean of both).
        distribution: "truncated_normal", "normal" or "uniform".

    Raises:
        ValueError: if `mode` or `distribution` is not one of the values above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode left `denom` unbound and failed later with
        # an UnboundLocalError; fail early with a clear message instead.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def lecun_normal_(tensor):
    """Initialise `tensor` in place via `variance_scaling_` using fan-in and a truncated normal."""
    variance_scaling_(tensor, distribution="truncated_normal", mode="fan_in")
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def default_flax_embed_init(tensor):
    """Initialise `tensor` in place via `variance_scaling_` using fan-in and a normal distribution."""
    variance_scaling_(tensor, distribution="normal", mode="fan_in")
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipVisionEmbeddings_forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
    """Embed image patches and add position embeddings.

    Args:
        self: module providing `patch_embedding`, `position_embedding`,
            `position_ids` and `interpolate_pos_encoding`.
        pixel_values: 4-D input; its last two dimensions are read as height and width.
        interpolate_pos_encoding: when truthy, position embeddings come from
            `self.interpolate_pos_encoding`; otherwise from the stored `position_ids`.

    Returns:
        Patch embeddings with the position term added.
    """
    _, _, height, width = pixel_values.shape

    # Cast the input to the convolution weight's dtype before projecting.
    weight_dtype = self.patch_embedding.weight.dtype
    patches = self.patch_embedding(pixel_values.to(dtype=weight_dtype))  # shape = [*, width, grid, grid]
    tokens = patches.flatten(2).transpose(1, 2)

    if interpolate_pos_encoding:
        positional = self.interpolate_pos_encoding(tokens, height, width)
    else:
        positional = self.position_embedding(self.position_ids)
    return tokens + positional
|
|
282
|
+
|
|
283
|
+
class SiglipVisionEmbeddings(nn.Module):
    """Patch + position embedding module; `forward` delegates to `SiglipVisionEmbeddings_forward`."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side ** 2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        # Non-persistent buffer holding position ids 0..num_positions-1 with a leading axis of size 1.
        position_ids = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer("position_ids", position_ids, persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Interpolate the pre-trained position encodings so the model can be used on higher resolution
        images. Also adapted to support torch.jit tracing and no class embeddings.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[0]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        dim = embeddings.shape[-1]
        target_size = (height // self.patch_size, width // self.patch_size)

        # Lay the learned embeddings out as a square grid, channels first, for interpolation.
        grid_side = torch_int(num_positions**0.5)
        pos_grid = self.position_embedding.weight.unsqueeze(0)
        pos_grid = pos_grid.reshape(1, grid_side, grid_side, dim)
        pos_grid = pos_grid.permute(0, 3, 1, 2)

        pos_grid = nn.functional.interpolate(
            pos_grid,
            size=target_size,
            mode="bicubic",
            align_corners=False,
        )

        # Back to channels last, flattened over the grid.
        return pos_grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        return SiglipVisionEmbeddings_forward(self, pixel_values, interpolate_pos_encoding)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def SiglipTextEmbeddings_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
    """Sum token embeddings and position embeddings.

    Args:
        self: module providing `token_embedding`, `position_embedding` and `position_ids`.
        input_ids: token ids; used for the sequence length and, when
            `inputs_embeds` is `None`, looked up in `self.token_embedding`.
        position_ids: optional position ids; defaults to the first
            `seq_length` entries of `self.position_ids`.
        inputs_embeds: optional precomputed token embeddings.

    Returns:
        `inputs_embeds + position_embedding(position_ids)`.

    Raises:
        ValueError: if the sequence length exceeds the number of rows in
            `self.position_embedding.weight`.
    """
    if input_ids is not None:
        seq_length = input_ids.shape[-1]
    else:
        seq_length = inputs_embeds.shape[-2]
    max_position_embedding = self.position_embedding.weight.shape[0]

    if seq_length > max_position_embedding:
        raise ValueError(
            f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
            f"{seq_length} and max_position_embeddings: {max_position_embedding}"
        )

    if position_ids is None:
        position_ids = self.position_ids[:, :seq_length]

    if inputs_embeds is None:
        inputs_embeds = self.token_embedding(input_ids)

    return inputs_embeds + self.position_embedding(position_ids)
|
|
372
|
+
|
|
373
|
+
class SiglipTextEmbeddings(nn.Module):
    """Token + position embedding module; `forward` delegates to `SiglipTextEmbeddings_forward`."""

    def __init__(self, config: SiglipTextConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, hidden_size)

        # position_ids (1, len position emb) is registered as a non-persistent buffer
        all_position_ids = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_position_ids, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        return SiglipTextEmbeddings_forward(self, input_ids, position_ids, inputs_embeds)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Scaled dot-product attention written out with explicit matmuls.

    Args:
        module: only `module.training` is read, to gate dropout.
        query, key, value: attention inputs; `key` is transposed on its last two axes.
        attention_mask: optional additive mask applied to the scaled scores.
        scaling: multiplier applied to the raw scores.
        dropout: dropout probability applied to the attention weights.
        **kwargs: accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`; the output has axes 1 and 2 swapped
        and is made contiguous.
    """
    scores = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask

    # Softmax is evaluated in float32, then cast back to the score dtype and the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype = torch.float32)
    probs = probs.to(scores.dtype).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value).transpose(1, 2).contiguous()
    return context, probs
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
@torch.compiler.disable(recursive = False)
def SiglipAttention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Multi-head self-attention. Input shape: Batch x Time x Channel

    Projects `hidden_states` to queries, keys and values, splits them into
    `self.num_heads` heads, runs the attention implementation selected by
    `self.config._attn_implementation`, and applies `self.out_proj`.
    `**kwargs` are accepted but not forwarded.

    Returns:
        `(attn_output, attn_weights)` as produced by the attention
        implementation, with the output projected by `self.out_proj`.
    """
    batch_size, seq_length, embed_dim = hidden_states.shape

    def split_heads(projected):
        # (batch, seq, embed) -> (batch, heads, seq, head_dim)
        return projected.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    queries = self.q_proj(hidden_states)
    keys = self.k_proj(hidden_states)
    values = self.v_proj(hidden_states)

    queries = split_heads(queries)
    keys = split_heads(keys)
    values = split_heads(values)

    attention_interface: Callable
    if self.config._attn_implementation == "eager":
        attention_interface = eager_attention_forward
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    # Dropout is only active in training mode.
    dropout_p = self.dropout if self.training else 0.0

    attn_output, attn_weights = attention_interface(
        self,
        queries,
        keys,
        values,
        attention_mask,
        is_causal=self.is_causal,
        scaling=self.scale,
        dropout=dropout_p,
    )

    merged = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
    return self.out_proj(merged), attn_weights
|
|
457
|
+
|
|
458
|
+
class SiglipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # The floor division above must have been exact.
        heads_cover_embed_dim = self.head_dim * self.num_heads == self.embed_dim
        if not heads_cover_embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        # Square projections; creation order is kept as k, v, q, out.
        width = self.embed_dim
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.q_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, width)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        return SiglipAttention_forward(self, hidden_states, attention_mask, **kwargs)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipMLP_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Apply `self.fc1`, then `self.activation_fn`, then `self.fc2` to `hidden_states`."""
    expanded = self.fc1(hidden_states)
    activated = self.activation_fn(expanded)
    return self.fc2(activated)
|
|
496
|
+
|
|
497
|
+
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block; `forward` delegates to `SiglipMLP_forward`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation is looked up by name in ACT2FN.
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(hidden, inner)
        self.fc2 = nn.Linear(inner, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return SiglipMLP_forward(self, hidden_states)
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def SiglipMultiheadAttentionPoolingHead_forward(self, hidden_state):
    """Pool `hidden_state` by attending from the learned probe.

    The probe is repeated once per batch element and used as the query of
    `self.attention` with `hidden_state` as key and value. A residual MLP
    (layer norm, then `self.mlp`) is applied to the result and index 0 along
    axis 1 is returned.
    """
    batch_size = hidden_state.shape[0]
    probe_queries = self.probe.repeat(batch_size, 1, 1)

    # Only the first element of the attention result is used.
    pooled = self.attention(probe_queries, hidden_state, hidden_state)[0]

    # Residual connection around layernorm + MLP.
    pooled = pooled + self.mlp(self.layernorm(pooled))

    return pooled[:, 0]
|
|
521
|
+
|
|
522
|
+
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()

        hidden = config.hidden_size
        # Learned query of shape (1, 1, hidden_size), randomly initialised.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden))
        self.attention = torch.nn.MultiheadAttention(hidden, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        return SiglipMultiheadAttentionPoolingHead_forward(self, hidden_state)
|