cciwon-code-review-cli 2.0.2 → 2.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/lib/chat-mode.js +7 -2
  2. package/package.json +1 -1
  3. package/unsloth_compiled_cache/.locks/.lock.AqlmLoraLinear_peft_forward.py +0 -0
  4. package/unsloth_compiled_cache/.locks/.lock.AwqLoraLinear_peft_forward.py +0 -0
  5. package/unsloth_compiled_cache/.locks/.lock.BatchNorm1d.py +0 -0
  6. package/unsloth_compiled_cache/.locks/.lock.BatchNorm2d.py +0 -0
  7. package/unsloth_compiled_cache/.locks/.lock.BatchNorm3d.py +0 -0
  8. package/unsloth_compiled_cache/.locks/.lock.Conv1d.py +0 -0
  9. package/unsloth_compiled_cache/.locks/.lock.Conv2d.py +0 -0
  10. package/unsloth_compiled_cache/.locks/.lock.Conv3d.py +0 -0
  11. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose1d.py +0 -0
  12. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose2d.py +0 -0
  13. package/unsloth_compiled_cache/.locks/.lock.ConvTranspose3d.py +0 -0
  14. package/unsloth_compiled_cache/.locks/.lock.GPTQLoraLinear_peft_forward.py +0 -0
  15. package/unsloth_compiled_cache/.locks/.lock.GroupNorm.py +0 -0
  16. package/unsloth_compiled_cache/.locks/.lock.LayerNorm.py +0 -0
  17. package/unsloth_compiled_cache/.locks/.lock.Linear4bit_peft_forward.py +0 -0
  18. package/unsloth_compiled_cache/.locks/.lock.Linear8bitLt_peft_forward.py +0 -0
  19. package/unsloth_compiled_cache/.locks/.lock.Linear_peft_forward.py +0 -0
  20. package/unsloth_compiled_cache/.locks/.lock.LoraParallelLinear_peft_forward.py +0 -0
  21. package/unsloth_compiled_cache/.locks/.lock.RMSNorm.py +0 -0
  22. package/unsloth_compiled_cache/.locks/.lock.UnslothBCOTrainer.py +0 -0
  23. package/unsloth_compiled_cache/.locks/.lock.UnslothCPOTrainer.py +0 -0
  24. package/unsloth_compiled_cache/.locks/.lock.UnslothDPOTrainer.py +0 -0
  25. package/unsloth_compiled_cache/.locks/.lock.UnslothGKDTrainer.py +0 -0
  26. package/unsloth_compiled_cache/.locks/.lock.UnslothGRPOTrainer.py +0 -0
  27. package/unsloth_compiled_cache/.locks/.lock.UnslothKTOTrainer.py +0 -0
  28. package/unsloth_compiled_cache/.locks/.lock.UnslothNashMDTrainer.py +0 -0
  29. package/unsloth_compiled_cache/.locks/.lock.UnslothORPOTrainer.py +0 -0
  30. package/unsloth_compiled_cache/.locks/.lock.UnslothOnlineDPOTrainer.py +0 -0
  31. package/unsloth_compiled_cache/.locks/.lock.UnslothPPOTrainer.py +0 -0
  32. package/unsloth_compiled_cache/.locks/.lock.UnslothPRMTrainer.py +0 -0
  33. package/unsloth_compiled_cache/.locks/.lock.UnslothRLOOTrainer.py +0 -0
  34. package/unsloth_compiled_cache/.locks/.lock.UnslothRewardTrainer.py +0 -0
  35. package/unsloth_compiled_cache/.locks/.lock.UnslothSFTTrainer.py +0 -0
  36. package/unsloth_compiled_cache/.locks/.lock.UnslothXPOTrainer.py +0 -0
  37. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_qwen3_moe.py +0 -0
  38. package/unsloth_compiled_cache/.locks/.lock.unsloth_compiled_module_siglip.py +0 -0
  39. package/unsloth_compiled_cache/AqlmLoraLinear_peft_forward.py +88 -0
  40. package/unsloth_compiled_cache/AwqLoraLinear_peft_forward.py +87 -0
  41. package/unsloth_compiled_cache/BatchNorm1d.py +117 -0
  42. package/unsloth_compiled_cache/BatchNorm2d.py +117 -0
  43. package/unsloth_compiled_cache/BatchNorm3d.py +117 -0
  44. package/unsloth_compiled_cache/Conv1d.py +70 -0
  45. package/unsloth_compiled_cache/Conv2d.py +70 -0
  46. package/unsloth_compiled_cache/Conv3d.py +70 -0
  47. package/unsloth_compiled_cache/ConvTranspose1d.py +97 -0
  48. package/unsloth_compiled_cache/ConvTranspose2d.py +106 -0
  49. package/unsloth_compiled_cache/ConvTranspose3d.py +98 -0
  50. package/unsloth_compiled_cache/GPTQLoraLinear_peft_forward.py +95 -0
  51. package/unsloth_compiled_cache/GroupNorm.py +70 -0
  52. package/unsloth_compiled_cache/LayerNorm.py +72 -0
  53. package/unsloth_compiled_cache/Linear4bit_peft_forward.py +115 -0
  54. package/unsloth_compiled_cache/Linear8bitLt_peft_forward.py +113 -0
  55. package/unsloth_compiled_cache/Linear_peft_forward.py +104 -0
  56. package/unsloth_compiled_cache/LoraParallelLinear_peft_forward.py +91 -0
  57. package/unsloth_compiled_cache/RMSNorm.py +73 -0
  58. package/unsloth_compiled_cache/UnslothBCOTrainer.py +2026 -0
  59. package/unsloth_compiled_cache/UnslothCPOTrainer.py +1806 -0
  60. package/unsloth_compiled_cache/UnslothDPOTrainer.py +2750 -0
  61. package/unsloth_compiled_cache/UnslothGKDTrainer.py +1157 -0
  62. package/unsloth_compiled_cache/UnslothGRPOTrainer.py +3607 -0
  63. package/unsloth_compiled_cache/UnslothKTOTrainer.py +2220 -0
  64. package/unsloth_compiled_cache/UnslothNashMDTrainer.py +1210 -0
  65. package/unsloth_compiled_cache/UnslothORPOTrainer.py +1730 -0
  66. package/unsloth_compiled_cache/UnslothOnlineDPOTrainer.py +2313 -0
  67. package/unsloth_compiled_cache/UnslothPPOTrainer.py +1504 -0
  68. package/unsloth_compiled_cache/UnslothPRMTrainer.py +979 -0
  69. package/unsloth_compiled_cache/UnslothRLOOTrainer.py +2674 -0
  70. package/unsloth_compiled_cache/UnslothRewardTrainer.py +1197 -0
  71. package/unsloth_compiled_cache/UnslothSFTTrainer.py +1416 -0
  72. package/unsloth_compiled_cache/UnslothXPOTrainer.py +1255 -0
  73. package/unsloth_compiled_cache/__pycache__/AqlmLoraLinear_peft_forward.cpython-312.pyc +0 -0
  74. package/unsloth_compiled_cache/__pycache__/AwqLoraLinear_peft_forward.cpython-312.pyc +0 -0
  75. package/unsloth_compiled_cache/__pycache__/BatchNorm1d.cpython-312.pyc +0 -0
  76. package/unsloth_compiled_cache/__pycache__/BatchNorm2d.cpython-312.pyc +0 -0
  77. package/unsloth_compiled_cache/__pycache__/BatchNorm3d.cpython-312.pyc +0 -0
  78. package/unsloth_compiled_cache/__pycache__/Conv1d.cpython-312.pyc +0 -0
  79. package/unsloth_compiled_cache/__pycache__/Conv2d.cpython-312.pyc +0 -0
  80. package/unsloth_compiled_cache/__pycache__/Conv3d.cpython-312.pyc +0 -0
  81. package/unsloth_compiled_cache/__pycache__/ConvTranspose1d.cpython-312.pyc +0 -0
  82. package/unsloth_compiled_cache/__pycache__/ConvTranspose2d.cpython-312.pyc +0 -0
  83. package/unsloth_compiled_cache/__pycache__/ConvTranspose3d.cpython-312.pyc +0 -0
  84. package/unsloth_compiled_cache/__pycache__/GPTQLoraLinear_peft_forward.cpython-312.pyc +0 -0
  85. package/unsloth_compiled_cache/__pycache__/GroupNorm.cpython-312.pyc +0 -0
  86. package/unsloth_compiled_cache/__pycache__/LayerNorm.cpython-312.pyc +0 -0
  87. package/unsloth_compiled_cache/__pycache__/Linear4bit_peft_forward.cpython-312.pyc +0 -0
  88. package/unsloth_compiled_cache/__pycache__/Linear8bitLt_peft_forward.cpython-312.pyc +0 -0
  89. package/unsloth_compiled_cache/__pycache__/Linear_peft_forward.cpython-312.pyc +0 -0
  90. package/unsloth_compiled_cache/__pycache__/LoraParallelLinear_peft_forward.cpython-312.pyc +0 -0
  91. package/unsloth_compiled_cache/__pycache__/RMSNorm.cpython-312.pyc +0 -0
  92. package/unsloth_compiled_cache/__pycache__/UnslothBCOTrainer.cpython-312.pyc +0 -0
  93. package/unsloth_compiled_cache/__pycache__/UnslothCPOTrainer.cpython-312.pyc +0 -0
  94. package/unsloth_compiled_cache/__pycache__/UnslothDPOTrainer.cpython-312.pyc +0 -0
  95. package/unsloth_compiled_cache/__pycache__/UnslothGKDTrainer.cpython-312.pyc +0 -0
  96. package/unsloth_compiled_cache/__pycache__/UnslothGRPOTrainer.cpython-312.pyc +0 -0
  97. package/unsloth_compiled_cache/__pycache__/UnslothKTOTrainer.cpython-312.pyc +0 -0
  98. package/unsloth_compiled_cache/__pycache__/UnslothNashMDTrainer.cpython-312.pyc +0 -0
  99. package/unsloth_compiled_cache/__pycache__/UnslothORPOTrainer.cpython-312.pyc +0 -0
  100. package/unsloth_compiled_cache/__pycache__/UnslothOnlineDPOTrainer.cpython-312.pyc +0 -0
  101. package/unsloth_compiled_cache/__pycache__/UnslothPPOTrainer.cpython-312.pyc +0 -0
  102. package/unsloth_compiled_cache/__pycache__/UnslothPRMTrainer.cpython-312.pyc +0 -0
  103. package/unsloth_compiled_cache/__pycache__/UnslothRLOOTrainer.cpython-312.pyc +0 -0
  104. package/unsloth_compiled_cache/__pycache__/UnslothRewardTrainer.cpython-312.pyc +0 -0
  105. package/unsloth_compiled_cache/__pycache__/UnslothSFTTrainer.cpython-312.pyc +0 -0
  106. package/unsloth_compiled_cache/__pycache__/UnslothXPOTrainer.cpython-312.pyc +0 -0
  107. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_qwen3_moe.cpython-312.pyc +0 -0
  108. package/unsloth_compiled_cache/__pycache__/unsloth_compiled_module_siglip.cpython-312.pyc +0 -0
  109. package/unsloth_compiled_cache/unsloth_compiled_module_qwen3_moe.py +726 -0
  110. package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py +534 -0
package/unsloth_compiled_cache/unsloth_compiled_module_siglip.py
@@ -0,0 +1,534 @@
+ """
+ 2025.12.6
+ 2025.12.7
+ 4.57.1
+ 0.24.0
+ __UNSLOTH_VERSIONING__
+ """
+
+ # Unsloth auto generated code
+ # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+
+ import os
+ import torch
+ import importlib.util
+ import math
+ if importlib.util.find_spec("unsloth_studio") is None:
+     UNSLOTH_STUDIO_ENABLED = False
+ else:
+     UNSLOTH_STUDIO_ENABLED = os.environ.get("UNSLOTH_STUDIO_DISABLED", "0") == "0"
+ pass
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+ import math
+
+ UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
+ UNSLOTH_ENABLE_CCE = os.environ.get("UNSLOTH_ENABLE_CCE", "1") == "1"
+ UNSLOTH_COMPILE_DISABLE = os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") in ("1", "partial",)
+
+ import logging
+ logger_compiler = logging.getLogger(__name__)
+ if UNSLOTH_ENABLE_LOGGING:
+     logger_compiler.setLevel(logging.DEBUG)
+
+ global INFERENCE_RUNS
+ INFERENCE_RUNS = 0
+
+ try:
+     import torch._dynamo.eval_frame as torch_dynamo_eval_frame
+     torch_dynamo_eval_frame._stance.stance
+     torch_compiler_set_stance = torch.compiler.set_stance
+ except:
+     torch_dynamo_eval_frame = None
+     torch_compiler_set_stance = None
+ pass
+
+ from unsloth_zoo import DEVICE_TYPE_TORCH, DEVICE_COUNT
+
+
+ from unsloth_zoo.loss_utils import (
+     fused_linear_cross_entropy,
+     unsloth_fused_ce_loss,
+ )
+
+ if UNSLOTH_STUDIO_ENABLED:
+     from unsloth_zoo.loss_utils import fast_linear_cross_entropy
+
+ scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+ @torch.compiler.disable(recursive = False)
+ def disable_compile_scaled_dot_product_attention(*args, **kwargs):
+     return scaled_dot_product_attention(*args, **kwargs)
+ pass
+
+
+ from transformers.modeling_flash_attention_utils import is_flash_attn_available
+
+ if is_flash_attn_available():
+     try:
+         from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask
+     except:
+         flash_attn_supports_top_left_mask = None
+     try:
+         from transformers.modeling_flash_attention_utils import _flash_attention_forward
+     except:
+         _flash_attention_forward = None
+     try:
+         from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+     except:
+         FlashAttentionKwargs = None
+     try:
+         from transformers.modeling_flash_attention_utils import flash_attn_varlen_func
+     except:
+         flash_attn_varlen_func = None
+ else:
+     flash_attn_supports_top_left_mask = None
+     _flash_attention_forward = None
+     FlashAttentionKwargs = None
+     flash_attn_varlen_func = None
+ pass
+
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False, 'debug': False, 'dce': True, 'memory_planning': True, 'coordinate_descent_tuning': False, 'trace.graph_diagram': False, 'compile_threads': 32, 'group_fusion': True, 'disable_progress': True, 'verbose_progress': False, 'triton.multi_kernel': 0, 'triton.use_block_ptr': False, 'triton.enable_persistent_tma_matmul': True, 'triton.autotune_at_compile_time': False, 'triton.cooperative_reductions': False, 'cuda.compile_opt_level': '-O2', 'cuda.enable_cuda_lto': True, 'combo_kernels': False, 'benchmark_combo_kernel': True, 'combo_kernel_foreach_dynamic_shapes': True}
+
+ from torch.nn import CrossEntropyLoss
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def normal_cross_entropy_loss(self, hidden_states, labels):
+     logits = self.lm_head(hidden_states)
+     logits = logits.float()
+     # Shift so that tokens < n predict n
+     shift_logits = logits[..., :-1, :].contiguous()
+     shift_labels = labels[..., 1:].contiguous()
+     # Flatten the tokens
+     loss_fct = CrossEntropyLoss()
+     shift_logits = shift_logits.view(-1, self.config.vocab_size)
+     shift_labels = shift_labels.view(-1)
+     # Enable model parallelism
+     shift_labels = shift_labels.to(shift_logits.device)
+     loss = loss_fct(shift_logits, shift_labels)
+     return loss, logits
+ pass
+
+ # We need an empty logits flag to warn people logits will not be returned anymore unless asked ie
+ # os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
+ LOGITS_ERROR_STRING = \
+     "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "\
+     'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:\n'\
+     "```\nimport os\n"\
+     "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"\
+     "trainer.train()\n```\n"\
+     "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
+
+ def raise_logits_error(*args, **kwargs): raise NotImplementedError(LOGITS_ERROR_STRING)
+ def return_none(*args, **kwargs): return None
+ class EmptyLogits:
+     def __init__(self): return
+     def raise_getattr_error(self, attr): return return_none if attr == "to" else raise_logits_error
+     __getitem__ = raise_logits_error
+     __getattr__ = raise_getattr_error
+     def __repr__(self): return LOGITS_ERROR_STRING
+     def __str__ (self): return LOGITS_ERROR_STRING
+ pass
+ EMPTY_LOGITS = EmptyLogits()
+ functions = dir(torch.Tensor)
+ for j, function in enumerate(functions):
+     if function.startswith("__") and function.endswith("__"):
+         exec(f"def raise_{j}(*args, **kwargs): print('{function}')", globals(), locals())
+         try: exec(f"EMPTY_LOGITS.{function} = raise_{j}", globals(), locals())
+         except: continue
+ pass
+
+
+ def mask_attention_mask_out(labels = None, attention_mask = None):
+     if labels is not None and attention_mask is not None:
+         attention_mask = attention_mask.to(device = labels.device)
+         labels[attention_mask == 0] = -100
+     return labels
+ pass
+
+
+ from torch import Tensor
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from typing import Any, List, Optional, Tuple, Union, Dict, Set, Callable
+ from transformers.models.siglip.modeling_siglip import (math, warnings, Callable, Optional, np, torch, nn, _calculate_fan_in_and_fan_out, ACT2FN, ALL_ATTENTION_FUNCTIONS, torch_int, SiglipTextConfig, SiglipVisionConfig)
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def _trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+             "The distribution of values may be incorrect.",
+             stacklevel=2,
+         )
+
+     # Values are generated by using a truncated uniform distribution and
+     # then using the inverse CDF for the normal distribution.
+     # Get upper and lower cdf values
+     l = norm_cdf((a - mean) / std)
+     u = norm_cdf((b - mean) / std)
+
+     # Uniformly fill tensor with values from [l, u], then translate to
+     # [2l-1, 2u-1].
+     tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+     # Use inverse cdf transform for normal distribution to get truncated
+     # standard normal
+     tensor.erfinv_()
+
+     # Transform to proper mean, std
+     tensor.mul_(std * math.sqrt(2.0))
+     tensor.add_(mean)
+
+     # Clamp to ensure it's in the proper range
+     tensor.clamp_(min=a, max=b)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def trunc_normal_tf_(
+     tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
+ ) -> torch.Tensor:
+     """Fills the input Tensor with values drawn from a truncated
+     normal distribution. The values are effectively drawn from the
+     normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \\leq \text{mean} \\leq b`.
+
+     NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
+     bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
+     and the result is subsequently scaled and shifted by the mean and std args.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     """
+     with torch.no_grad():
+         _trunc_normal_(tensor, 0, 1.0, a, b)
+         tensor.mul_(std).add_(mean)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+     if mode == "fan_in":
+         denom = fan_in
+     elif mode == "fan_out":
+         denom = fan_out
+     elif mode == "fan_avg":
+         denom = (fan_in + fan_out) / 2
+
+     variance = scale / denom
+
+     if distribution == "truncated_normal":
+         # constant is stddev of standard normal truncated to (-2, 2)
+         trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+     elif distribution == "normal":
+         with torch.no_grad():
+             tensor.normal_(std=math.sqrt(variance))
+     elif distribution == "uniform":
+         bound = math.sqrt(3 * variance)
+         with torch.no_grad():
+             tensor.uniform_(-bound, bound)
+     else:
+         raise ValueError(f"invalid distribution {distribution}")
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def lecun_normal_(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def default_flax_embed_init(tensor):
+     variance_scaling_(tensor, mode="fan_in", distribution="normal")
+
+
+ @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def SiglipVisionEmbeddings_forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+     _, _, height, width = pixel_values.shape
+     target_dtype = self.patch_embedding.weight.dtype
+     patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+     embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+     if interpolate_pos_encoding:
+         embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+     else:
+         embeddings = embeddings + self.position_embedding(self.position_ids)
+     return embeddings
+
+ class SiglipVisionEmbeddings(nn.Module):
+     def __init__(self, config: SiglipVisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             padding="valid",
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+         self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+     def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+         """
+         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+         images. This method is also adapted to support torch.jit tracing and no class embeddings.
+
+         Adapted from:
+         - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+         - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+         """
+
+         num_patches = embeddings.shape[1]
+         num_positions = self.position_embedding.weight.shape[0]
+
+         # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+         if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+             return self.position_embedding(self.position_ids)
+
+         patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
+
+         dim = embeddings.shape[-1]
+
+         new_height = height // self.patch_size
+         new_width = width // self.patch_size
+
+         sqrt_num_positions = torch_int(num_positions**0.5)
+         patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed,
+             size=(new_height, new_width),
+             mode="bicubic",
+             align_corners=False,
+         )
+
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return patch_pos_embed
+
+     def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+         return SiglipVisionEmbeddings_forward(self, pixel_values, interpolate_pos_encoding)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def SiglipTextEmbeddings_forward(
+     self,
+     input_ids: Optional[torch.LongTensor] = None,
+     position_ids: Optional[torch.LongTensor] = None,
+     inputs_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+     seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+     max_position_embedding = self.position_embedding.weight.shape[0]
+
+     if seq_length > max_position_embedding:
+         raise ValueError(
+             f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
+             f"{seq_length} and max_position_embeddings: {max_position_embedding}"
+         )
+
+     if position_ids is None:
+         position_ids = self.position_ids[:, :seq_length]
+
+     if inputs_embeds is None:
+         inputs_embeds = self.token_embedding(input_ids)
+
+     position_embeddings = self.position_embedding(position_ids)
+     embeddings = inputs_embeds + position_embeddings
+
+     return embeddings
+
+ class SiglipTextEmbeddings(nn.Module):
+     def __init__(self, config: SiglipTextConfig):
+         super().__init__()
+         embed_dim = config.hidden_size
+
+         self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+         self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
+
+         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+         self.register_buffer(
+             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+         )
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+     ) -> torch.Tensor:
+         return SiglipTextEmbeddings_forward(self, input_ids, position_ids, inputs_embeds)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Optional[torch.Tensor],
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs,
+ ):
+     attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+     if attention_mask is not None:
+         attn_weights = attn_weights + attention_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype = torch.float32).to(attn_weights.dtype).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+     attn_output = torch.matmul(attn_weights, value)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
+ @torch.compiler.disable(recursive = False)
+ def SiglipAttention_forward(
+     self,
+     hidden_states: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     **kwargs,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+     """Input shape: Batch x Time x Channel"""
+
+     batch_size, seq_length, embed_dim = hidden_states.shape
+
+     queries = self.q_proj(hidden_states)
+     keys = self.k_proj(hidden_states)
+     values = self.v_proj(hidden_states)
+
+     queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+     keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+     values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+     attention_interface: Callable = eager_attention_forward
+     if self.config._attn_implementation != "eager":
+         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+     attn_output, attn_weights = attention_interface(
+         self,
+         queries,
+         keys,
+         values,
+         attention_mask,
+         is_causal=self.is_causal,
+         scaling=self.scale,
+         dropout=0.0 if not self.training else self.dropout,
+     )
+
+     attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
+     attn_output = self.out_proj(attn_output)
+
+     return attn_output, attn_weights
+
+ class SiglipAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.embed_dim // self.num_heads
+         if self.head_dim * self.num_heads != self.embed_dim:
+             raise ValueError(
+                 f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                 f" {self.num_heads})."
+             )
+         self.scale = self.head_dim**-0.5
+         self.dropout = config.attention_dropout
+         self.is_causal = False
+
+         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+         return SiglipAttention_forward(self, hidden_states, attention_mask, **kwargs)
+
+
+ @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def SiglipMLP_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+     hidden_states = self.fc1(hidden_states)
+     hidden_states = self.activation_fn(hidden_states)
+     hidden_states = self.fc2(hidden_states)
+     return hidden_states
+
+ class SiglipMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         return SiglipMLP_forward(self, hidden_states)
+
+
+ @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def SiglipMultiheadAttentionPoolingHead_forward(self, hidden_state):
+     batch_size = hidden_state.shape[0]
+     probe = self.probe.repeat(batch_size, 1, 1)
+
+     hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+     residual = hidden_state
+     hidden_state = self.layernorm(hidden_state)
+     hidden_state = residual + self.mlp(hidden_state)
+
+     return hidden_state[:, 0]
+
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
+     """Multihead Attention Pooling."""
+
+     def __init__(self, config: SiglipVisionConfig):
+         super().__init__()
+
+         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+         self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.mlp = SiglipMLP(config)
+
+     def forward(self, hidden_state):
+         return SiglipMultiheadAttentionPoolingHead_forward(self, hidden_state)