sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2127 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ # Copied and Adapted from:
16
+ # https://github.com/deepseek-ai/Janus
17
+
18
+
19
+ import collections
20
+ import math
21
+ import os
22
+ from dataclasses import field
23
+ from enum import Enum
24
+ from functools import partial
25
+ from itertools import repeat
26
+ from typing import (
27
+ Callable,
28
+ Final,
29
+ Iterable,
30
+ Literal,
31
+ Optional,
32
+ Sequence,
33
+ Set,
34
+ Tuple,
35
+ Type,
36
+ Union,
37
+ )
38
+
39
+ import torch
40
+ import torch.nn.functional as F
41
+ from einops import rearrange
42
+ from torch import Tensor, _assert, nn
43
+ from torch.nn.init import trunc_normal_
44
+ from transformers import AutoModel, PreTrainedModel
45
+
46
+ from sglang.srt.configs.janus_pro import *
47
+ from sglang.srt.layers.attention.vision import VisionAttention
48
+ from sglang.srt.layers.logits_processor import LogitsProcessor
49
+ from sglang.srt.layers.quantization import QuantizationConfig
50
+ from sglang.srt.managers.multi_modality_padding import (
51
+ MultiModalityDataPaddingPatternTokenPairs,
52
+ )
53
+ from sglang.srt.managers.schedule_batch import ImageInputs
54
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
55
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
56
+ from sglang.srt.models.llama import LlamaForCausalLM
57
+ from sglang.utils import logger
58
+
59
+ #################################################################################
60
+ # VQ Model Configs #
61
+ #################################################################################
62
+
63
+
64
+ # Copied from:
65
+ # https://github.com/deepseek-ai/Janus/tree/main/janus/models/vq_model.py
66
+ @dataclass
67
+ class ModelArgs:
68
+ codebook_size: int = 16384
69
+ codebook_embed_dim: int = 8
70
+ codebook_l2_norm: bool = True
71
+ codebook_show_usage: bool = True
72
+ commit_loss_beta: float = 0.25
73
+ entropy_loss_ratio: float = 0.0
74
+
75
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
76
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
77
+ z_channels: int = 256
78
+ dropout_p: float = 0.0
79
+
80
+
81
+ def named_apply(
82
+ fn: Callable,
83
+ module: nn.Module,
84
+ name="",
85
+ depth_first: bool = True,
86
+ include_root: bool = False,
87
+ ) -> nn.Module:
88
+ if not depth_first and include_root:
89
+ fn(module=module, name=name)
90
+ for child_name, child_module in module.named_children():
91
+ child_name = ".".join((name, child_name)) if name else child_name
92
+ named_apply(
93
+ fn=fn,
94
+ module=child_module,
95
+ name=child_name,
96
+ depth_first=depth_first,
97
+ include_root=True,
98
+ )
99
+ if depth_first and include_root:
100
+ fn(module=module, name=name)
101
+ return module
102
+
103
+
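named_apply is the traversal helper used for weight init further down (init_weights_vit_timm is applied to every submodule through it). A minimal illustration, not part of the package diff, assuming this file's definitions are in scope:

    import torch.nn as nn

    seen = []

    def record(module, name=""):
        # named_apply calls fn(module=..., name=...) on each child, depth first
        if isinstance(module, nn.Linear):
            seen.append(name)

    tree = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
    named_apply(record, tree)
    assert seen == ["0", "1.0"]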
104
+ def VQ_16(**kwargs):
105
+ return VQModel(
106
+ ModelArgs(
107
+ encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
108
+ )
109
+ )
110
+
111
+
112
+ VQ_models = {"VQ-16": VQ_16}
113
+
114
+ import collections.abc
115
+
116
+
117
+ # From PyTorch internals
118
+ def _ntuple(n):
119
+ def parse(x):
120
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
121
+ return tuple(x)
122
+ return tuple(repeat(x, n))
123
+
124
+ return parse
125
+
126
+
127
+ def _trunc_normal_(tensor, mean, std, a, b):
128
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
129
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
130
+ def norm_cdf(x):
131
+ # Computes standard normal cumulative distribution function
132
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
133
+
134
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
135
+ logger.warning(
136
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
137
+ "The distribution of values may be incorrect.",
138
+ stacklevel=2,
139
+ )
140
+
141
+ # Values are generated by using a truncated uniform distribution and
142
+ # then using the inverse CDF for the normal distribution.
143
+ # Get upper and lower cdf values
144
+ l = norm_cdf((a - mean) / std)
145
+ u = norm_cdf((b - mean) / std)
146
+
147
+ # Uniformly fill tensor with values from [l, u], then translate to
148
+ # [2l-1, 2u-1].
149
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
150
+
151
+ # Use inverse cdf transform for normal distribution to get truncated
152
+ # standard normal
153
+ if tensor.dtype in [torch.float16, torch.bfloat16]:
154
+ # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
155
+ og_dtype = tensor.dtype
156
+ tensor = tensor.to(torch.float32)
157
+ tensor.erfinv_()
158
+ tensor = tensor.to(og_dtype)
159
+ else:
160
+ tensor.erfinv_()
161
+
162
+ # Transform to proper mean, std
163
+ tensor.mul_(std * math.sqrt(2.0))
164
+ tensor.add_(mean)
165
+
166
+ # Clamp to ensure it's in the proper range
167
+ if tensor.dtype == torch.float16:
168
+ # The `clamp_` op is not (yet?) defined in float16+cpu
169
+ tensor = tensor.to(torch.float32)
170
+ tensor.clamp_(min=a, max=b)
171
+ else:
172
+ tensor.clamp_(min=a, max=b)
173
+
174
+
175
+ def trunc_normal_tf_(
176
+ tensor: torch.Tensor,
177
+ mean: float = 0.0,
178
+ std: float = 1.0,
179
+ a: float = -2.0,
180
+ b: float = 2.0,
181
+ ):
182
+ """Fills the input Tensor with values drawn from a truncated
183
+ normal distribution. The values are effectively drawn from the
184
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
185
+ with values outside :math:`[a, b]` redrawn until they are within
186
+ the bounds. The method used for generating the random values works
187
+ best when :math:`a \\leq \text{mean} \\leq b`.
188
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
189
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
190
+ and the result is subsequently scaled and shifted by the mean and std args.
191
+ Args:
192
+ tensor: an n-dimensional `torch.Tensor`
193
+ mean: the mean of the normal distribution
194
+ std: the standard deviation of the normal distribution
195
+ a: the minimum cutoff value
196
+ b: the maximum cutoff value
197
+ """
198
+ with torch.no_grad():
199
+ _trunc_normal_(tensor, 0, 1.0, a, b)
200
+ tensor.mul_(std).add_(mean)
201
+
202
+
203
+ to_2tuple = _ntuple(2)
204
+
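The _ntuple helper simply broadcasts scalars to fixed-length tuples and passes iterables through unchanged; a standalone illustration (not from the package):

    assert to_2tuple(16) == (16, 16)
    assert to_2tuple((14, 16)) == (14, 16)
    to_3tuple = _ntuple(3)
    assert to_3tuple(7) == (7, 7, 7)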
205
+
206
+ class Format(str, Enum):
207
+ NCHW = "NCHW"
208
+ NHWC = "NHWC"
209
+ NCL = "NCL"
210
+ NLC = "NLC"
211
+
212
+
213
+ def nchw_to(x: torch.Tensor, fmt: Format):
214
+ if fmt == Format.NHWC:
215
+ x = x.permute(0, 2, 3, 1)
216
+ elif fmt == Format.NLC:
217
+ x = x.flatten(2).transpose(1, 2)
218
+ elif fmt == Format.NCL:
219
+ x = x.flatten(2)
220
+ return x
221
+
222
+
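nchw_to only permutes or flattens the layout; a quick shape check using the Format enum above (illustrative, assumes torch):

    import torch

    x = torch.randn(2, 3, 4, 4)                      # NCHW
    assert nchw_to(x, Format.NHWC).shape == (2, 4, 4, 3)
    assert nchw_to(x, Format.NLC).shape == (2, 16, 3)
    assert nchw_to(x, Format.NCL).shape == (2, 3, 16)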
223
+ def resample_patch_embed(
224
+ patch_embed,
225
+ new_size: List[int],
226
+ interpolation: str = "bicubic",
227
+ antialias: bool = True,
228
+ verbose: bool = False,
229
+ ):
230
+ """Resample the weights of the patch embedding kernel to target resolution.
231
+ We resample the patch embedding kernel by approximately inverting the effect
232
+ of patch resizing.
233
+
234
+ Code based on:
235
+ https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
236
+
237
+ With this resizing, we can for example load a B/8 filter into a B/16 model
238
+ and, on 2x larger input image, the result will match.
239
+
240
+ Args:
241
+ patch_embed: original parameter to be resized.
242
+ new_size (tuple(int, int)): target shape (height, width) only.
243
+ interpolation (str): interpolation for resize
244
+ antialias (bool): use anti-aliasing filter in resize
245
+ verbose (bool): log operation
246
+ Returns:
247
+ Resized patch embedding kernel.
248
+ """
249
+ import numpy as np
250
+
251
+ try:
252
+ from torch import vmap
253
+ except ImportError:
254
+ from functorch import vmap
255
+
256
+ assert len(patch_embed.shape) == 4, "Four dimensions expected"
257
+ assert len(new_size) == 2, "New shape should only be hw"
258
+ old_size = patch_embed.shape[-2:]
259
+ if tuple(old_size) == tuple(new_size):
260
+ return patch_embed
261
+
262
+ if verbose:
263
+ logger.info(
264
+ f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation."
265
+ )
266
+
267
+ def resize(x_np, _new_size):
268
+ x_tf = torch.Tensor(x_np)[None, None, ...]
269
+ x_upsampled = F.interpolate(
270
+ x_tf, size=_new_size, mode=interpolation, antialias=antialias
271
+ )[0, 0, ...].numpy()
272
+ return x_upsampled
273
+
274
+ def get_resize_mat(_old_size, _new_size):
275
+ mat = []
276
+ for i in range(np.prod(_old_size)):
277
+ basis_vec = np.zeros(_old_size)
278
+ basis_vec[np.unravel_index(i, _old_size)] = 1.0
279
+ mat.append(resize(basis_vec, _new_size).reshape(-1))
280
+ return np.stack(mat).T
281
+
282
+ resize_mat = get_resize_mat(old_size, new_size)
283
+ resize_mat_pinv = torch.tensor(
284
+ np.linalg.pinv(resize_mat.T), device=patch_embed.device
285
+ )
286
+
287
+ def resample_kernel(kernel):
288
+ resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
289
+ return resampled_kernel.reshape(new_size)
290
+
291
+ v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
292
+ orig_dtype = patch_embed.dtype
293
+ patch_embed = patch_embed.float()
294
+ patch_embed = v_resample_kernel(patch_embed)
295
+ patch_embed = patch_embed.to(orig_dtype)
296
+ return patch_embed
297
+
298
+
299
+ # Copied from:
300
+ # https://github.com/deepseek-ai/Janus/tree/main/janus/models/siglip_vit.py
301
+ class PatchEmbed(nn.Module):
302
+ """2D Image to Patch Embedding"""
303
+
304
+ output_fmt: Format
305
+ dynamic_img_pad: torch.jit.Final[bool]
306
+
307
+ def __init__(
308
+ self,
309
+ img_size: Optional[int] = 224,
310
+ patch_size: int = 16,
311
+ in_chans: int = 3,
312
+ embed_dim: int = 768,
313
+ norm_layer: Optional[Callable] = None,
314
+ flatten: bool = True,
315
+ output_fmt: Optional[str] = None,
316
+ bias: bool = True,
317
+ strict_img_size: bool = True,
318
+ dynamic_img_pad: bool = False,
319
+ ):
320
+ super().__init__()
321
+ self.patch_size = tuple(to_2tuple(patch_size))
322
+ self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
323
+
324
+ if output_fmt is not None:
325
+ self.flatten = False
326
+ self.output_fmt = Format(output_fmt)
327
+ else:
328
+ # flatten spatial dim and transpose to channels last, kept for bwd compat
329
+ self.flatten = flatten
330
+ self.output_fmt = Format.NCHW
331
+ self.strict_img_size = strict_img_size
332
+ self.dynamic_img_pad = dynamic_img_pad
333
+
334
+ self.proj = nn.Conv2d(
335
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
336
+ )
337
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
338
+
339
+ def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
340
+ assert self.patch_size
341
+ if img_size is None:
342
+ return None, None, None
343
+ img_size = to_2tuple(img_size)
344
+ grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
345
+ num_patches = grid_size[0] * grid_size[1]
346
+ return img_size, grid_size, num_patches
347
+
348
+ def set_input_size(
349
+ self,
350
+ img_size: Optional[Union[int, Tuple[int, int]]] = None,
351
+ patch_size: Optional[Union[int, Tuple[int, int]]] = None,
352
+ ):
353
+ new_patch_size = None
354
+ if patch_size is not None:
355
+ new_patch_size = to_2tuple(patch_size)
356
+ if new_patch_size is not None and new_patch_size != self.patch_size:
357
+ with torch.no_grad():
358
+ new_proj = nn.Conv2d(
359
+ self.proj.in_channels,
360
+ self.proj.out_channels,
361
+ kernel_size=new_patch_size,
362
+ stride=new_patch_size,
363
+ bias=self.proj.bias is not None,
364
+ )
365
+ new_proj.weight.copy_(
366
+ resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)
367
+ )
368
+ if self.proj.bias is not None:
369
+ new_proj.bias.copy_(self.proj.bias)
370
+ self.proj = new_proj
371
+ self.patch_size = new_patch_size
372
+ img_size = img_size or self.img_size
373
+ if img_size != self.img_size or new_patch_size is not None:
374
+ self.img_size, self.grid_size, self.num_patches = self._init_img_size(
375
+ img_size
376
+ )
377
+
378
+ def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
379
+ if as_scalar:
380
+ return max(self.patch_size)
381
+ else:
382
+ return self.patch_size
383
+
384
+ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
385
+ """Get grid (feature) size for given image size taking account of dynamic padding.
386
+ NOTE: must be torchscript compatible so using fixed tuple indexing
387
+ """
388
+ if self.dynamic_img_pad:
389
+ return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(
390
+ img_size[1] / self.patch_size[1]
391
+ )
392
+ else:
393
+ return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
394
+
395
+ def forward(self, x):
396
+ B, C, H, W = x.shape
397
+ if self.img_size is not None:
398
+ if self.strict_img_size:
399
+ _assert(
400
+ H == self.img_size[0],
401
+ f"Input height ({H}) doesn't match model ({self.img_size[0]}).",
402
+ )
403
+ _assert(
404
+ W == self.img_size[1],
405
+ f"Input width ({W}) doesn't match model ({self.img_size[1]}).",
406
+ )
407
+ elif not self.dynamic_img_pad:
408
+ _assert(
409
+ H % self.patch_size[0] == 0,
410
+ f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).",
411
+ )
412
+ _assert(
413
+ W % self.patch_size[1] == 0,
414
+ f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).",
415
+ )
416
+ if self.dynamic_img_pad:
417
+ pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
418
+ pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
419
+ x = F.pad(x, (0, pad_w, 0, pad_h))
420
+ x = self.proj(x)
421
+ if self.flatten:
422
+ x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
423
+ elif self.output_fmt != Format.NCHW:
424
+ x = nchw_to(x, self.output_fmt)
425
+ x = self.norm(x)
426
+ return x
427
+
428
+
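A hedged shape sketch for the PatchEmbed class above: a 224x224 RGB batch becomes 196 patch tokens of width 768 (the values are illustrative defaults, not read from any config):

    import torch

    pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    tokens = pe(torch.randn(1, 3, 224, 224))
    assert pe.num_patches == (224 // 16) ** 2 == 196
    assert tokens.shape == (1, 196, 768)             # flatten=True -> NLC layout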
429
+ class Mlp(nn.Module):
430
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks
431
+
432
+ NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
433
+ """
434
+
435
+ def __init__(
436
+ self,
437
+ in_features,
438
+ hidden_features=None,
439
+ out_features=None,
440
+ act_layer=nn.GELU,
441
+ norm_layer=None,
442
+ bias=True,
443
+ drop=0.0,
444
+ use_conv=False,
445
+ ):
446
+ super().__init__()
447
+ out_features = out_features or in_features
448
+ hidden_features = hidden_features or in_features
449
+ bias = to_2tuple(bias)
450
+ drop_probs = to_2tuple(drop)
451
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
452
+
453
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
454
+ self.act = act_layer()
455
+ self.drop1 = nn.Dropout(drop_probs[0])
456
+ self.norm = (
457
+ norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
458
+ )
459
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
460
+ self.drop2 = nn.Dropout(drop_probs[1])
461
+
462
+ def forward(self, x):
463
+ x = self.fc1(x)
464
+ x = self.act(x)
465
+ x = self.drop1(x)
466
+ x = self.norm(x)
467
+ x = self.fc2(x)
468
+ x = self.drop2(x)
469
+ return x
470
+
471
+
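Mlp keeps the token shape and only widens the hidden layer; a quick check against the class above (illustrative dimensions):

    import torch

    mlp = Mlp(in_features=768, hidden_features=3072)  # 4x expansion, as in ViT
    out = mlp(torch.randn(2, 196, 768))
    assert out.shape == (2, 196, 768)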
472
+ def drop_path(
473
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
474
+ ):
475
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
476
+
477
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
478
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
479
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
480
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
481
+ 'survival rate' as the argument.
482
+
483
+ """
484
+ if drop_prob == 0.0 or not training:
485
+ return x
486
+ keep_prob = 1 - drop_prob
487
+ shape = (x.shape[0],) + (1,) * (
488
+ x.ndim - 1
489
+ ) # work with diff dim tensors, not just 2D ConvNets
490
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
491
+ if keep_prob > 0.0 and scale_by_keep:
492
+ random_tensor.div_(keep_prob)
493
+ return x * random_tensor
494
+
495
+
496
+ class DropPath(nn.Module):
497
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
498
+
499
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
500
+ super(DropPath, self).__init__()
501
+ self.drop_prob = drop_prob
502
+ self.scale_by_keep = scale_by_keep
503
+
504
+ def forward(self, x):
505
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
506
+
507
+ def extra_repr(self):
508
+ return f"drop_prob={round(self.drop_prob, 3):0.3f}"
509
+
510
+
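DropPath zeroes whole samples at random during training and rescales the survivors by 1/keep_prob; in eval mode it is a no-op. A small sketch using the class above:

    import torch

    dp = DropPath(drop_prob=0.5)
    x = torch.ones(8, 4)
    dp.eval()
    assert torch.equal(dp(x), x)                     # identity at inference
    dp.train()
    y = dp(x)                                        # each row is all 0 or all 2.0
    assert set(y.unique().tolist()) <= {0.0, 2.0}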
511
+ class VisionTransformerBlock(nn.Module):
512
+ def __init__(
513
+ self,
514
+ dim: int,
515
+ num_heads: int,
516
+ mlp_ratio: float = 4.0,
517
+ qkv_bias: bool = False,
518
+ qk_norm: bool = False,
519
+ proj_drop: float = 0.0,
520
+ attn_drop: float = 0.0,
521
+ init_values: Optional[float] = None,
522
+ drop_path: float = 0.0,
523
+ act_layer: nn.Module = nn.GELU,
524
+ norm_layer: nn.Module = nn.LayerNorm,
525
+ mlp_layer: nn.Module = Mlp,
526
+ ) -> None:
527
+ super().__init__()
528
+ self.norm1 = norm_layer(dim)
529
+ self.attn = VisionAttention(
530
+ embed_dim=dim,
531
+ num_heads=num_heads,
532
+ projection_size=dim,
533
+ use_qkv_parallel=True,
534
+ use_context_forward=False,
535
+ softmax_in_single_precision=False,
536
+ dropout=attn_drop,
537
+ )
538
+
539
+ self.ls1 = (
540
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
541
+ )
542
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
543
+
544
+ self.norm2 = norm_layer(dim)
545
+ self.mlp = mlp_layer(
546
+ in_features=dim,
547
+ hidden_features=int(dim * mlp_ratio),
548
+ act_layer=act_layer,
549
+ drop=proj_drop,
550
+ )
551
+ self.ls2 = (
552
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
553
+ )
554
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
555
+
556
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
557
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
558
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
559
+ return x
560
+
561
+
562
+ LayerType = Union[str, Callable, Type[torch.nn.Module]]
563
+
564
+
565
+ class PatchDropout(nn.Module):
566
+ """
567
+ https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
568
+ """
569
+
570
+ return_indices: torch.jit.Final[bool]
571
+
572
+ def __init__(
573
+ self,
574
+ prob: float = 0.5,
575
+ num_prefix_tokens: int = 1,
576
+ ordered: bool = False,
577
+ return_indices: bool = False,
578
+ ):
579
+ super().__init__()
580
+ assert 0 <= prob < 1.0
581
+ self.prob = prob
582
+ self.num_prefix_tokens = (
583
+ num_prefix_tokens # exclude CLS token (or other prefix tokens)
584
+ )
585
+ self.ordered = ordered
586
+ self.return_indices = return_indices
587
+
588
+ def forward(
589
+ self, x
590
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
591
+ if not self.training or self.prob == 0.0:
592
+ if self.return_indices:
593
+ return x, None
594
+ return x
595
+
596
+ if self.num_prefix_tokens:
597
+ prefix_tokens, x = (
598
+ x[:, : self.num_prefix_tokens],
599
+ x[:, self.num_prefix_tokens :],
600
+ )
601
+ else:
602
+ prefix_tokens = None
603
+
604
+ B = x.shape[0]
605
+ L = x.shape[1]
606
+ num_keep = max(1, int(L * (1.0 - self.prob)))
607
+ keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[
608
+ :, :num_keep
609
+ ]
610
+ if self.ordered:
611
+ # NOTE does not need to maintain patch order in typical transformer use,
612
+ # but possibly useful for debug / visualization
613
+ keep_indices = keep_indices.sort(dim=-1)[0]
614
+ x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
615
+
616
+ if prefix_tokens is not None:
617
+ x = torch.cat((prefix_tokens, x), dim=1)
618
+
619
+ if self.return_indices:
620
+ return x, keep_indices
621
+ return x
622
+
623
+
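PatchDropout keeps the prefix (CLS) tokens and drops a fraction of the patch tokens only while training; a rough shape check (illustrative sizes):

    import torch

    pd = PatchDropout(prob=0.25, num_prefix_tokens=1)
    x = torch.randn(2, 1 + 196, 768)                 # CLS + 196 patches
    pd.train()
    assert pd(x).shape == (2, 1 + 147, 768)          # keeps int(196 * 0.75) = 147
    pd.eval()
    assert pd(x).shape == x.shape                    # no-op at inference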
624
+ def resample_abs_pos_embed(
625
+ posemb: torch.Tensor,
626
+ new_size: List[int],
627
+ old_size: Optional[List[int]] = None,
628
+ num_prefix_tokens: int = 1,
629
+ interpolation: str = "bicubic",
630
+ antialias: bool = True,
631
+ verbose: bool = False,
632
+ ):
633
+ # sort out sizes, assume square if old size not provided
634
+ num_pos_tokens = posemb.shape[1]
635
+ num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
636
+ if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
637
+ return posemb
638
+
639
+ if old_size is None:
640
+ hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
641
+ old_size = hw, hw
642
+
643
+ if num_prefix_tokens:
644
+ posemb_prefix, posemb = (
645
+ posemb[:, :num_prefix_tokens],
646
+ posemb[:, num_prefix_tokens:],
647
+ )
648
+ else:
649
+ posemb_prefix, posemb = None, posemb
650
+
651
+ # do the interpolation
652
+ embed_dim = posemb.shape[-1]
653
+ orig_dtype = posemb.dtype
654
+ posemb = posemb.float() # interpolate needs float32
655
+ posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
656
+ posemb = F.interpolate(
657
+ posemb, size=new_size, mode=interpolation, antialias=antialias
658
+ )
659
+ posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
660
+ posemb = posemb.to(orig_dtype)
661
+
662
+ # add back extra (class, etc) prefix tokens
663
+ if posemb_prefix is not None:
664
+ posemb = torch.cat([posemb_prefix, posemb], dim=1)
665
+
666
+ if not torch.jit.is_scripting() and verbose:
667
+ logger.info(f"Resized position embedding: {old_size} to {new_size}.")
668
+
669
+ return posemb
670
+
671
+
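resample_abs_pos_embed interpolates only the grid part of a learned position embedding and re-attaches the prefix (class) slot; for example, resizing a 14x14 grid to 16x16 (illustrative dimensions):

    import torch

    pos = torch.randn(1, 1 + 14 * 14, 1024)          # [prefix + grid, dim]
    new = resample_abs_pos_embed(pos, new_size=[16, 16], num_prefix_tokens=1)
    assert new.shape == (1, 1 + 16 * 16, 1024)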
672
+ def init_weights(self):
673
+ if self.pos_embed is not None:
674
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
675
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
676
+
677
+
678
+ def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
679
+ """ViT weight initialization, original timm impl (for reproducibility)"""
680
+ if isinstance(module, nn.Linear):
681
+ trunc_normal_(module.weight, std=0.02)
682
+ if module.bias is not None:
683
+ nn.init.zeros_(module.bias)
684
+ elif hasattr(module, "init_weights"):
685
+ module.init_weights()
686
+
687
+
688
+ class VisionTransformer(nn.Module):
689
+ """Vision Transformer
690
+
691
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
692
+ - https://arxiv.org/abs/2010.11929
693
+ """
694
+
695
+ dynamic_img_size: Final[bool]
696
+
697
+ def __init__(
698
+ self,
699
+ img_size: Union[int, Tuple[int, int]] = 224,
700
+ patch_size: Union[int, Tuple[int, int]] = 16,
701
+ in_chans: int = 3,
702
+ num_classes: int = 1000,
703
+ global_pool: Literal["", "avg", "token", "map"] = "token",
704
+ embed_dim: int = 768,
705
+ depth: int = 12,
706
+ num_heads: int = 12,
707
+ mlp_ratio: float = 4.0,
708
+ qkv_bias: bool = True,
709
+ qk_norm: bool = False,
710
+ init_values: Optional[float] = None,
711
+ class_token: bool = True,
712
+ no_embed_class: bool = False,
713
+ reg_tokens: int = 0,
714
+ pre_norm: bool = False,
715
+ fc_norm: Optional[bool] = None,
716
+ dynamic_img_size: bool = False,
717
+ dynamic_img_pad: bool = False,
718
+ drop_rate: float = 0.0,
719
+ pos_drop_rate: float = 0.0,
720
+ patch_drop_rate: float = 0.0,
721
+ proj_drop_rate: float = 0.0,
722
+ attn_drop_rate: float = 0.0,
723
+ drop_path_rate: float = 0.0,
724
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
725
+ embed_layer: Callable = PatchEmbed,
726
+ _norm_layer: Optional[LayerType] = None,
727
+ _act_layer: Optional[LayerType] = None,
728
+ block_fn: Type[nn.Module] = VisionTransformerBlock,
729
+ mlp_layer: Type[nn.Module] = Mlp,
730
+ ignore_head: bool = False,
731
+ ) -> None:
732
+ """
733
+ Args:
734
+ img_size: Input image size.
735
+ patch_size: Patch size.
736
+ in_chans: Number of image input channels.
737
+ num_classes: Number of classes for classification head.
738
+ global_pool: Type of global pooling for final sequence (default: 'token').
739
+ embed_dim: Transformer embedding dimension.
740
+ depth: Depth of transformer.
741
+ num_heads: Number of attention heads.
742
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
743
+ qkv_bias: Enable bias for qkv projections if True.
744
+ init_values: Layer-scale init values (layer-scale enabled if not None).
745
+ class_token: Use class token.
746
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
747
+ reg_tokens: Number of register tokens.
748
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
749
+ drop_rate: Head dropout rate.
750
+ pos_drop_rate: Position embedding dropout rate.
751
+ attn_drop_rate: Attention dropout rate.
752
+ drop_path_rate: Stochastic depth rate.
753
+ weight_init: Weight initialization scheme.
754
+ embed_layer: Patch embedding layer.
755
+ _norm_layer: Normalization layer.
756
+ _act_layer: MLP activation layer.
757
+ block_fn: Transformer block layer.
758
+ """
759
+ super().__init__()
760
+ assert global_pool in ("", "avg", "token", "map")
761
+ assert class_token or global_pool != "token"
762
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
763
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
764
+ # act_layer = get_act_layer(act_layer) or nn.GELU
765
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
766
+ act_layer = nn.GELU
767
+
768
+ self.num_classes = num_classes
769
+ self.global_pool = global_pool
770
+ self.num_features = self.embed_dim = (
771
+ embed_dim # num_features for consistency with other models
772
+ )
773
+ self.num_prefix_tokens = 1 if class_token else 0
774
+ self.num_prefix_tokens += reg_tokens
775
+ self.num_reg_tokens = reg_tokens
776
+ self.has_class_token = class_token
777
+ self.no_embed_class = (
778
+ no_embed_class # don't embed prefix positions (includes reg)
779
+ )
780
+ self.dynamic_img_size = dynamic_img_size
781
+ self.grad_checkpointing = False
782
+ self.ignore_head = ignore_head
783
+
784
+ embed_args = {}
785
+ if dynamic_img_size:
786
+ # flatten deferred until after pos embed
787
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
788
+ self.patch_embed = embed_layer(
789
+ img_size=img_size,
790
+ patch_size=patch_size,
791
+ in_chans=in_chans,
792
+ embed_dim=embed_dim,
793
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
794
+ dynamic_img_pad=dynamic_img_pad,
795
+ **embed_args,
796
+ )
797
+ num_patches = self.patch_embed.num_patches
798
+
799
+ self.cls_token = (
800
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
801
+ )
802
+ self.reg_token = (
803
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
804
+ )
805
+ embed_len = (
806
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
807
+ )
808
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
809
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
810
+ if patch_drop_rate > 0:
811
+ self.patch_drop = PatchDropout(
812
+ patch_drop_rate,
813
+ num_prefix_tokens=self.num_prefix_tokens,
814
+ )
815
+ else:
816
+ self.patch_drop = nn.Identity()
817
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
818
+
819
+ dpr = [
820
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
821
+ ] # stochastic depth decay rule
822
+ self.blocks = nn.Sequential(
823
+ *[
824
+ block_fn(
825
+ dim=embed_dim,
826
+ num_heads=num_heads,
827
+ mlp_ratio=mlp_ratio,
828
+ qkv_bias=qkv_bias,
829
+ qk_norm=qk_norm,
830
+ init_values=init_values,
831
+ proj_drop=proj_drop_rate,
832
+ attn_drop=attn_drop_rate,
833
+ drop_path=dpr[i],
834
+ norm_layer=norm_layer,
835
+ act_layer=act_layer,
836
+ mlp_layer=mlp_layer,
837
+ )
838
+ for i in range(depth)
839
+ ]
840
+ )
841
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
842
+
843
+ # Classifier Head
844
+ if global_pool == "map":
845
+ AttentionPoolLatent.init_weights = init_weights
846
+ self.attn_pool = AttentionPoolLatent(
847
+ self.embed_dim,
848
+ num_heads=num_heads,
849
+ mlp_ratio=mlp_ratio,
850
+ norm_layer=norm_layer,
851
+ )
852
+ else:
853
+ self.attn_pool = None
854
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
855
+ self.head_drop = nn.Dropout(drop_rate)
856
+ self.head = (
857
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
858
+ )
859
+
860
+ if weight_init != "skip":
861
+ self.init_weights(weight_init)
862
+
863
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
864
+ assert mode in ("jax", "jax_nlhb", "moco", "")
865
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
866
+ trunc_normal_(self.pos_embed, std=0.02)
867
+ if self.cls_token is not None:
868
+ nn.init.normal_(self.cls_token, std=1e-6)
869
+ named_apply(init_weights_vit_timm, self)
870
+
871
+ @torch.jit.ignore
872
+ def no_weight_decay(self) -> Set:
873
+ return {"pos_embed", "cls_token", "dist_token"}
874
+
875
+ @torch.jit.ignore
876
+ def group_matcher(self, coarse: bool = False) -> Dict:
877
+ return dict(
878
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
879
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
880
+ )
881
+
882
+ @torch.jit.ignore
883
+ def get_classifier(self) -> nn.Module:
884
+ return self.head
885
+
886
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
887
+ self.num_classes = num_classes
888
+ if global_pool is not None:
889
+ assert global_pool in ("", "avg", "token", "map")
890
+ if global_pool == "map" and self.attn_pool is None:
891
+ assert (
892
+ False
893
+ ), "Cannot currently add attention pooling in reset_classifier()."
894
+ elif global_pool != "map" and self.attn_pool is not None:
895
+ self.attn_pool = None # remove attention pooling
896
+ self.global_pool = global_pool
897
+ self.head = (
898
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
899
+ )
900
+
901
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
902
+ if self.dynamic_img_size:
903
+ B, H, W, C = x.shape
904
+ pos_embed = resample_abs_pos_embed(
905
+ self.pos_embed,
906
+ [H, W],
907
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
908
+ )
909
+ x = x.view(B, -1, C)
910
+ else:
911
+ pos_embed = self.pos_embed
912
+
913
+ to_cat = []
914
+ if self.cls_token is not None:
915
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
916
+ if self.reg_token is not None:
917
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
918
+
919
+ if self.no_embed_class:
920
+ # deit-3, updated JAX (big vision)
921
+ # position embedding does not overlap with class token, add then concat
922
+ x = x + pos_embed
923
+ if to_cat:
924
+ x = torch.cat(to_cat + [x], dim=1)
925
+ else:
926
+ # original timm, JAX, and deit vit impl
927
+ # pos_embed has entry for class token, concat then add
928
+ if to_cat:
929
+ x = torch.cat(to_cat + [x], dim=1)
930
+ x = x + pos_embed
931
+
932
+ return self.pos_drop(x)
933
+
934
+ def _intermediate_layers(
935
+ self,
936
+ x: torch.Tensor,
937
+ n: Union[int, Sequence] = 1,
938
+ ) -> List[torch.Tensor]:
939
+ outputs, num_blocks = [], len(self.blocks)
940
+ take_indices = set(
941
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
942
+ )
943
+
944
+ # forward pass
945
+ x = self.patch_embed(x)
946
+ x = self._pos_embed(x)
947
+ x = self.patch_drop(x)
948
+ x = self.norm_pre(x)
949
+ for i, blk in enumerate(self.blocks):
950
+ x = blk(x)
951
+ if i in take_indices:
952
+ outputs.append(x)
953
+
954
+ return outputs
955
+
956
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
957
+ x = self.patch_embed(x)
958
+ x = self._pos_embed(x)
959
+ x = self.patch_drop(x)
960
+ x = self.norm_pre(x)
961
+ x = self.blocks(x)
962
+ x = self.norm(x)
963
+ return x
964
+
965
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
966
+ if self.attn_pool is not None:
967
+ x = self.attn_pool(x)
968
+ elif self.global_pool == "avg":
969
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
970
+ elif self.global_pool:
971
+ x = x[:, 0] # class token
972
+ x = self.fc_norm(x)
973
+ x = self.head_drop(x)
974
+ return x if pre_logits else self.head(x)
975
+
976
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
977
+ x = self.forward_features(x)
978
+ if not self.ignore_head:
979
+ x = self.forward_head(x)
980
+ return x
981
+
982
+
983
+ def model_name_to_cls(cls_name):
984
+ if "MlpProjector" in cls_name:
985
+ cls = MlpProjector
986
+
987
+ elif "CLIPVisionTower" in cls_name:
988
+ cls = CLIPVisionTower
989
+
990
+ elif "VQ" in cls_name:
991
+
992
+ cls = VQ_models[cls_name]
993
+ elif "vision_head" in cls_name:
994
+ cls = vision_head
995
+ else:
996
+ raise ValueError(f"class_name {cls_name} is invalid.")
997
+
998
+ return cls
999
+
1000
+
1001
+ class vision_head(torch.nn.Module):
1002
+ def __init__(self, params):
1003
+ super().__init__()
1004
+ self.output_mlp_projector = torch.nn.Linear(
1005
+ params["n_embed"], params["image_token_embed"]
1006
+ )
1007
+ self.vision_activation = torch.nn.GELU()
1008
+ self.vision_head = torch.nn.Linear(
1009
+ params["image_token_embed"], params["image_token_size"]
1010
+ )
1011
+
1012
+ def forward(self, x):
1013
+ x = self.output_mlp_projector(x)
1014
+ x = self.vision_activation(x)
1015
+ x = self.vision_head(x)
1016
+ return x
1017
+
1018
+
1019
+ SigLIP_MODEL_CONFIG = {
1020
+ "siglip_so400m_patch14_384": {
1021
+ "image_size": 336,
1022
+ "patch_size": 14,
1023
+ "width": 1152,
1024
+ "layers": 27,
1025
+ "heads": 16,
1026
+ "mlp_ratio": 3.7362,
1027
+ "global_pool": "map",
1028
+ "use_checkpoint": False,
1029
+ },
1030
+ "siglip_so400m_patch14_224": {
1031
+ "image_size": 224,
1032
+ "patch_size": 14,
1033
+ "width": 1152,
1034
+ "layers": 27,
1035
+ "heads": 16,
1036
+ "mlp_ratio": 3.7362,
1037
+ "global_pool": "map",
1038
+ "use_checkpoint": False,
1039
+ },
1040
+ "siglip_large_patch16_384": {
1041
+ "image_size": 384,
1042
+ "patch_size": 16,
1043
+ "width": 1024,
1044
+ "layers": 24,
1045
+ "heads": 16,
1046
+ "mlp_ratio": 4,
1047
+ "global_pool": "map",
1048
+ "use_checkpoint": False,
1049
+ },
1050
+ }
1051
+
1052
+
1053
+ def create_siglip_vit(
1054
+ model_name: str = "siglip_so400m_patch14_384",
1055
+ image_size: int = 384,
1056
+ select_layer: int = -1,
1057
+ ckpt_path: str = "",
1058
+ **kwargs,
1059
+ ):
1060
+ assert (
1061
+ model_name in SigLIP_MODEL_CONFIG.keys()
1062
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
1063
+
1064
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
1065
+
1066
+ if select_layer <= 0:
1067
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
1068
+ else:
1069
+ layers = min(vision_cfg.layers, select_layer)
1070
+
1071
+ model = VisionTransformer(
1072
+ img_size=image_size,
1073
+ patch_size=vision_cfg.patch_size,
1074
+ embed_dim=vision_cfg.width,
1075
+ depth=layers,
1076
+ num_heads=vision_cfg.heads,
1077
+ mlp_ratio=vision_cfg.mlp_ratio,
1078
+ class_token=vision_cfg.class_token,
1079
+ global_pool=vision_cfg.global_pool,
1080
+ ignore_head=kwargs.get("ignore_head", True),
1081
+ weight_init=kwargs.get("weight_init", "skip"),
1082
+ num_classes=0,
1083
+ )
1084
+
1085
+ if ckpt_path:
1086
+ state_dict = torch.load(ckpt_path, map_location="cpu")
1087
+
1088
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
1089
+ print(
1090
+ f"SigLIP-ViT restores from {ckpt_path},\n"
1091
+ f"\tincompatible_keys:', {incompatible_keys}."
1092
+ )
1093
+
1094
+ return model
1095
+
1096
+
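The select_layer argument trims the ViT depth: non-positive values count back from the last block, positive values cap the depth. A pure-Python sketch of the arithmetic in create_siglip_vit above (effective_depth is a hypothetical helper, not from the package):

    def effective_depth(layers, select_layer):
        # mirrors the branch in create_siglip_vit
        if select_layer <= 0:
            return min(layers, layers + select_layer + 1)
        return min(layers, select_layer)

    assert effective_depth(27, -1) == 27             # siglip_so400m_patch14_384 default
    assert effective_depth(27, -2) == 26             # drop the last block
    assert effective_depth(27, 12) == 12             # keep only the first 12 blocks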
1097
+ class Normalize(torch.nn.Module):
1098
+ """Normalize a tensor image with mean and standard deviation.
1099
+ This transform does not support PIL Image.
1100
+ Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
1101
+ channels, this transform will normalize each channel of the input
1102
+ ``torch.*Tensor`` i.e.,
1103
+ ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
1104
+
1105
+ .. note::
1106
+ This transform acts out of place, i.e., it does not mutate the input tensor.
1107
+
1108
+ Args:
1109
+ mean (sequence): Sequence of means for each channel.
1110
+ std (sequence): Sequence of standard deviations for each channel.
1111
+ inplace(bool,optional): Bool to make this operation in-place.
1112
+
1113
+ """
1114
+
1115
+ def __init__(self, mean, std, inplace=False):
1116
+ super().__init__()
1117
+ # _log_api_usage_once(self)
1118
+ self.mean = mean
1119
+ self.std = std
1120
+ self.inplace = inplace
1121
+
1122
+ def forward(self, tensor: Tensor) -> Tensor:
1123
+ """
1124
+ Args:
1125
+ tensor (Tensor): Tensor image to be normalized.
1126
+
1127
+ Returns:
1128
+ Tensor: Normalized Tensor image.
1129
+ """
1130
+ return F.normalize(tensor, self.mean, self.std, self.inplace)
1131
+
1132
+ def __repr__(self) -> str:
1133
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
1134
+
1135
+
1136
+ class CLIPVisionTower(nn.Module):
1137
+ def __init__(
1138
+ self,
1139
+ model_name: str = "siglip_large_patch16_384",
1140
+ image_size: Union[Tuple[int, int], int] = 336,
1141
+ select_feature: str = "patch",
1142
+ select_layer: int = -2,
1143
+ select_layers: list = None,
1144
+ ckpt_path: str = "",
1145
+ pixel_mean: Optional[List[float]] = None,
1146
+ pixel_std: Optional[List[float]] = None,
1147
+ **kwargs,
1148
+ ):
1149
+ super().__init__()
1150
+
1151
+ self.model_name = model_name
1152
+ self.select_feature = select_feature
1153
+ self.select_layer = select_layer
1154
+ self.select_layers = select_layers
1155
+
1156
+ vision_tower_params = {
1157
+ "model_name": model_name,
1158
+ "image_size": image_size,
1159
+ "ckpt_path": ckpt_path,
1160
+ "select_layer": select_layer,
1161
+ }
1162
+ vision_tower_params.update(kwargs)
1163
+ self.vision_tower, self.forward_kwargs = self.build_vision_tower(
1164
+ vision_tower_params
1165
+ )
1166
+
1167
+ if pixel_mean is not None and pixel_std is not None:
1168
+ image_norm = Normalize(mean=pixel_mean, std=pixel_std)
1169
+ else:
1170
+ image_norm = None
1171
+
1172
+ self.image_norm = image_norm
1173
+
1174
+ @property
1175
+ def device(self) -> torch.device:
1176
+ return next(self.vision_tower.parameters()).device
1177
+
1178
+ @property
1179
+ def dtype(self):
1180
+ return next(self.vision_tower.parameters()).dtype
1181
+
1182
+ def build_vision_tower(self, vision_tower_params):
1183
+ if self.model_name.startswith("siglip"):
1184
+ self.select_feature = "same"
1185
+ vision_tower = create_siglip_vit(**vision_tower_params)
1186
+ forward_kwargs = dict()
1187
+
1188
+ elif self.model_name.startswith("sam"):
1189
+ # vision_tower = create_sam_vit(**vision_tower_params)
1190
+ forward_kwargs = dict()
1191
+
1192
+ else: # huggingface
1193
+ from transformers import CLIPVisionModel
1194
+
1195
+ vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
1196
+ forward_kwargs = dict(output_hidden_states=True)
1197
+
1198
+ return vision_tower, forward_kwargs
1199
+
1200
+ def feature_select(self, image_forward_outs):
1201
+ if isinstance(image_forward_outs, torch.Tensor):
1202
+ # the output is already the features of self.select_layer
1203
+ image_features = image_forward_outs
1204
+ else:
1205
+ image_features = image_forward_outs.hidden_states[self.select_layer]
1206
+
1207
+ if self.select_feature == "patch":
1208
+ # if the output has cls_token
1209
+ image_features = image_features[:, 1:]
1210
+ elif self.select_feature == "cls_patch":
1211
+ image_features = image_features
1212
+ elif self.select_feature == "same":
1213
+ image_features = image_features
1214
+
1215
+ else:
1216
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
1217
+ return image_features
1218
+
1219
+ def forward(self, images):
1220
+ """
1221
+
1222
+ Args:
1223
+ images (torch.Tensor): [b, 3, H, W]
1224
+
1225
+ Returns:
1226
+ image_features (torch.Tensor): [b, n_patch, d]
1227
+ """
1228
+
1229
+ if self.image_norm is not None:
1230
+ images = self.image_norm(images)
1231
+
1232
+ image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
1233
+ image_features = self.feature_select(image_forward_outs)
1234
+ return image_features
1235
+
1236
+
1237
+ class MlpProjector(nn.Module):
1238
+ def __init__(self, cfg):
1239
+ super().__init__()
1240
+
1241
+ self.cfg = cfg
1242
+
1243
+ if cfg["projector_type"] == "identity":
1244
+ modules = nn.Identity()
1245
+
1246
+ elif cfg["projector_type"] == "linear":
1247
+ modules = nn.Linear(cfg["input_dim"], cfg["n_embed"])
1248
+
1249
+ elif cfg["projector_type"] == "mlp_gelu":
1250
+ mlp_depth = cfg.get("depth", 1)
1251
+ modules = [nn.Linear(cfg["input_dim"], cfg["n_embed"])]
1252
+ for _ in range(1, mlp_depth):
1253
+ modules.append(nn.GELU())
1254
+ modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
1255
+ modules = nn.Sequential(*modules)
1256
+
1257
+ elif cfg["projector_type"] == "low_high_hybrid_split_mlp_gelu":
1258
+ mlp_depth = cfg.get("depth", 1)
1259
+ self.high_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
1260
+ self.low_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
1261
+
1262
+ modules = []
1263
+ for _ in range(1, mlp_depth):
1264
+ modules.append(nn.GELU())
1265
+ modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
1266
+ modules = nn.Sequential(*modules)
1267
+
1268
+ else:
1269
+ raise ValueError(f"Unknown projector type: {cfg['projector_type']}")
1270
+
1271
+ self.layers = modules
1272
+
1273
+ def forward(
1274
+ self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
1275
+ ):
1276
+ """
1277
+
1278
+ Args:
1279
+ x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of torch.Tensor,
1280
+ it comes from the hybrid vision encoder and unpacks as (high_res_x, low_res_x);
1281
+ otherwise it is the feature from the single vision encoder.
1282
+
1283
+ Returns:
1284
+ x (torch.Tensor): [b, s, c]
1285
+ """
1286
+
1287
+ if isinstance(x_or_tuple, tuple):
1288
+ # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
1289
+ high_x, low_x = x_or_tuple
1290
+ high_x = self.high_up_proj(high_x)
1291
+ low_x = self.low_up_proj(low_x)
1292
+ x = torch.concat([high_x, low_x], dim=-1)
1293
+ else:
1294
+ x = x_or_tuple
1295
+
1296
+ return self.layers(x)
1297
+
1298
+
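MlpProjector is driven by a plain config dict; a hedged example of the "mlp_gelu" variant projecting 1152-dim SigLIP features into a 4096-dim embedding space (dimensions are illustrative, not read from any checkpoint):

    import torch

    cfg = {"projector_type": "mlp_gelu", "input_dim": 1152, "n_embed": 4096, "depth": 2}
    proj = MlpProjector(cfg)
    out = proj(torch.randn(2, 576, 1152))            # [batch, image tokens, vision dim]
    assert out.shape == (2, 576, 4096)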
1299
+ class LayerScale(nn.Module):
1300
+ def __init__(
1301
+ self,
1302
+ dim: int,
1303
+ init_values: float = 1e-5,
1304
+ inplace: bool = False,
1305
+ ) -> None:
1306
+ super().__init__()
1307
+ self.inplace = inplace
1308
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
1309
+
1310
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1311
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
1312
+
1313
+
1314
+ # use torch.scaled_dot_product_attention where possible
1315
+ _HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
1316
+ if "TIMM_FUSED_ATTN" in os.environ:
1317
+ _USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"])
1318
+ else:
1319
+ _USE_FUSED_ATTN = (
1320
+ 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
1321
+ )
1322
+
1323
+ # Set to True if exporting a model with Same padding via ONNX
1324
+ _EXPORTABLE = False
1325
+
1326
+
1327
+ def use_fused_attn(experimental: bool = False) -> bool:
1328
+ # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
1329
+ if not _HAS_FUSED_ATTN or _EXPORTABLE:
1330
+ return False
1331
+ if experimental:
1332
+ return _USE_FUSED_ATTN > 1
1333
+ return _USE_FUSED_ATTN > 0
1334
+
1335
+
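use_fused_attn only gates between F.scaled_dot_product_attention and the manual softmax path in AttentionPoolLatent.forward below; the two are numerically equivalent up to float tolerance (sketch assumes a torch build that provides scaled_dot_product_attention):

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 8, 5, 64)
    k = torch.randn(1, 8, 7, 64)
    v = torch.randn(1, 8, 7, 64)
    fused = F.scaled_dot_product_attention(q, k, v)
    manual = ((q * 64**-0.5) @ k.transpose(-2, -1)).softmax(dim=-1) @ v
    assert torch.allclose(fused, manual, atol=1e-5)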
1336
+ class AttentionPoolLatent(nn.Module):
1337
+ """Attention pooling w/ latent query"""
1338
+
1339
+ fused_attn: torch.jit.Final[bool]
1340
+
1341
+ def __init__(
1342
+ self,
1343
+ in_features: int,
1344
+ out_features: int = None,
1345
+ embed_dim: int = None,
1346
+ num_heads: int = 8,
1347
+ feat_size: Optional[int] = None,
1348
+ mlp_ratio: float = 4.0,
1349
+ qkv_bias: bool = True,
1350
+ qk_norm: bool = False,
1351
+ latent_len: int = 1,
1352
+ latent_dim: int = None,
1353
+ pos_embed: str = "",
1354
+ pool_type: str = "token",
1355
+ norm_layer: Optional[nn.Module] = None,
1356
+ drop: float = 0.0,
1357
+ ):
1358
+ super().__init__()
1359
+ embed_dim = embed_dim or in_features
1360
+ out_features = out_features or in_features
1361
+ assert embed_dim % num_heads == 0
1362
+ self.num_heads = num_heads
1363
+ self.head_dim = embed_dim // num_heads
1364
+ self.feat_size = feat_size
1365
+ self.scale = self.head_dim**-0.5
1366
+ self.pool = pool_type
1367
+ self.fused_attn = use_fused_attn()
1368
+
1369
+ if pos_embed == "abs":
1370
+ assert feat_size is not None
1371
+ self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
1372
+ else:
1373
+ self.pos_embed = None
1374
+
1375
+ self.latent_dim = latent_dim or embed_dim
1376
+ self.latent_len = latent_len
1377
+ self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
1378
+
1379
+ self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
1380
+ self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
1381
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
1382
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
1383
+ self.proj = nn.Linear(embed_dim, embed_dim)
1384
+ self.proj_drop = nn.Dropout(drop)
1385
+
1386
+ self.norm = (
1387
+ norm_layer(out_features) if norm_layer is not None else nn.Identity()
1388
+ )
1389
+ self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
1390
+
1391
+ self.init_weights()
1392
+
1393
+ def init_weights(self):
1394
+ if self.pos_embed is not None:
1395
+ trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
1396
+ trunc_normal_tf_(self.latent, std=self.latent_dim**-0.5)
1397
+
1398
+ def forward(self, x):
1399
+ B, N, C = x.shape
1400
+
1401
+ if self.pos_embed is not None:
1402
+ # FIXME interpolate
1403
+ x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
1404
+
1405
+ q_latent = self.latent.expand(B, -1, -1)
1406
+ q = (
1407
+ self.q(q_latent)
1408
+ .reshape(B, self.latent_len, self.num_heads, self.head_dim)
1409
+ .transpose(1, 2)
1410
+ )
1411
+
1412
+ kv = (
1413
+ self.kv(x)
1414
+ .reshape(B, N, 2, self.num_heads, self.head_dim)
1415
+ .permute(2, 0, 3, 1, 4)
1416
+ )
1417
+ k, v = kv.unbind(0)
1418
+
1419
+ q, k = self.q_norm(q), self.k_norm(k)
1420
+
1421
+ if self.fused_attn:
1422
+ x = F.scaled_dot_product_attention(q, k, v)
1423
+ else:
1424
+ q = q * self.scale
1425
+ attn = q @ k.transpose(-2, -1)
1426
+ attn = attn.softmax(dim=-1)
1427
+ x = attn @ v
1428
+ x = x.transpose(1, 2).reshape(B, self.latent_len, C)
1429
+ x = self.proj(x)
1430
+ x = self.proj_drop(x)
1431
+
1432
+ x = x + self.mlp(self.norm(x))
1433
+
1434
+ # optional pool if latent seq_len > 1 and pooled output is desired
1435
+ if self.pool == "token":
1436
+ x = x[:, 0]
1437
+ elif self.pool == "avg":
1438
+ x = x.mean(1)
+ return x
1439
+
1440
+
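`AttentionPoolLatent` pools the vision tower's patch tokens into a fixed-size summary by letting a small set of learned latent queries attend over them. A toy sketch of that pattern with made-up sizes; it skips the kv projection, MLP, and norms of the class above:

```python
import torch
import torch.nn.functional as F

B, N, D, H = 2, 576, 1024, 8            # batch, tokens, width, heads (example sizes)
x = torch.randn(B, N, D)                # patch tokens
latent = torch.zeros(1, 1, D)           # learned latent query (latent_len == 1)

q = latent.expand(B, -1, -1).reshape(B, 1, H, D // H).transpose(1, 2)  # [B, H, 1, d]
kv = x.reshape(B, N, H, D // H).transpose(1, 2)                        # [B, H, N, d]
pooled = F.scaled_dot_product_attention(q, kv, kv)                     # [B, H, 1, d]
pooled = pooled.transpose(1, 2).reshape(B, 1, D)[:, 0]                 # [B, D]
print(pooled.shape)  # torch.Size([2, 1024])
```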
1441
+ class Encoder(nn.Module):
1442
+ def __init__(
1443
+ self,
1444
+ in_channels=3,
1445
+ ch=128,
1446
+ ch_mult=(1, 1, 2, 2, 4),
1447
+ num_res_blocks=2,
1448
+ norm_type="group",
1449
+ dropout=0.0,
1450
+ resamp_with_conv=True,
1451
+ z_channels=256,
1452
+ ):
1453
+ super().__init__()
1454
+ self.num_resolutions = len(ch_mult)
1455
+ self.num_res_blocks = num_res_blocks
1456
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
1457
+
1458
+ # downsampling
1459
+ in_ch_mult = (1,) + tuple(ch_mult)
1460
+ self.conv_blocks = nn.ModuleList()
1461
+ for i_level in range(self.num_resolutions):
1462
+ conv_block = nn.Module()
1463
+ # res & attn
1464
+ res_block = nn.ModuleList()
1465
+ attn_block = nn.ModuleList()
1466
+ block_in = ch * in_ch_mult[i_level]
1467
+ block_out = ch * ch_mult[i_level]
1468
+ for _ in range(self.num_res_blocks):
1469
+ res_block.append(
1470
+ ResnetBlock(
1471
+ block_in, block_out, dropout=dropout, norm_type=norm_type
1472
+ )
1473
+ )
1474
+ block_in = block_out
1475
+ if i_level == self.num_resolutions - 1:
1476
+ attn_block.append(AttnBlock(block_in, norm_type))
1477
+ conv_block.res = res_block
1478
+ conv_block.attn = attn_block
1479
+ # downsample
1480
+ if i_level != self.num_resolutions - 1:
1481
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
1482
+ self.conv_blocks.append(conv_block)
1483
+
1484
+ # middle
1485
+ self.mid = nn.ModuleList()
1486
+ self.mid.append(
1487
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
1488
+ )
1489
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
1490
+ self.mid.append(
1491
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
1492
+ )
1493
+
1494
+ # end
1495
+ self.norm_out = Normalize(block_in, norm_type)
1496
+ self.conv_out = nn.Conv2d(
1497
+ block_in, z_channels, kernel_size=3, stride=1, padding=1
1498
+ )
1499
+
1500
+ def forward(self, x):
1501
+ h = self.conv_in(x)
1502
+ # downsampling
1503
+ for i_level, block in enumerate(self.conv_blocks):
1504
+ for i_block in range(self.num_res_blocks):
1505
+ h = block.res[i_block](h)
1506
+ if len(block.attn) > 0:
1507
+ h = block.attn[i_block](h)
1508
+ if i_level != self.num_resolutions - 1:
1509
+ h = block.downsample(h)
1510
+
1511
+ # middle
1512
+ for mid_block in self.mid:
1513
+ h = mid_block(h)
1514
+
1515
+ # end
1516
+ h = self.norm_out(h)
1517
+ h = nonlinearity(h)
1518
+ h = self.conv_out(h)
1519
+ return h
1520
+
1521
+
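The `Encoder` above applies one `Downsample` after every level except the last, so its overall spatial reduction is 2 ** (len(ch_mult) - 1); the `Decoder` below mirrors it with one `Upsample` per level. A quick arithmetic check with the default ch_mult (the 384x384 input size is only an example):

```python
ch_mult = (1, 1, 2, 2, 4)                # defaults shared by the Encoder and Decoder here
factor = 2 ** (len(ch_mult) - 1)         # one Downsample per level except the last
h = w = 384                              # example input resolution
print(factor, h // factor, w // factor)  # 16 24 24 -> a 24x24 latent grid
```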
1522
+ class Decoder(nn.Module):
1523
+ def __init__(
1524
+ self,
1525
+ z_channels=256,
1526
+ ch=128,
1527
+ ch_mult=(1, 1, 2, 2, 4),
1528
+ num_res_blocks=2,
1529
+ norm_type="group",
1530
+ dropout=0.0,
1531
+ resamp_with_conv=True,
1532
+ out_channels=3,
1533
+ ):
1534
+ super().__init__()
1535
+ self.num_resolutions = len(ch_mult)
1536
+ self.num_res_blocks = num_res_blocks
1537
+
1538
+ block_in = ch * ch_mult[self.num_resolutions - 1]
1539
+ # z to block_in
1540
+ self.conv_in = nn.Conv2d(
1541
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
1542
+ )
1543
+
1544
+ # middle
1545
+ self.mid = nn.ModuleList()
1546
+ self.mid.append(
1547
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
1548
+ )
1549
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
1550
+ self.mid.append(
1551
+ ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
1552
+ )
1553
+
1554
+ # upsampling
1555
+ self.conv_blocks = nn.ModuleList()
1556
+ for i_level in reversed(range(self.num_resolutions)):
1557
+ conv_block = nn.Module()
1558
+ # res & attn
1559
+ res_block = nn.ModuleList()
1560
+ attn_block = nn.ModuleList()
1561
+ block_out = ch * ch_mult[i_level]
1562
+ for _ in range(self.num_res_blocks + 1):
1563
+ res_block.append(
1564
+ ResnetBlock(
1565
+ block_in, block_out, dropout=dropout, norm_type=norm_type
1566
+ )
1567
+ )
1568
+ block_in = block_out
1569
+ if i_level == self.num_resolutions - 1:
1570
+ attn_block.append(AttnBlock(block_in, norm_type))
1571
+ conv_block.res = res_block
1572
+ conv_block.attn = attn_block
1573
+ # upsample
1574
+ if i_level != 0:
1575
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
1576
+ self.conv_blocks.append(conv_block)
1577
+
1578
+ # end
1579
+ self.norm_out = Normalize(block_in, norm_type)
1580
+ self.conv_out = nn.Conv2d(
1581
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
1582
+ )
1583
+
1584
+ @property
1585
+ def last_layer(self):
1586
+ return self.conv_out.weight
1587
+
1588
+ def forward(self, z):
1589
+ # z to block_in
1590
+ h = self.conv_in(z)
1591
+
1592
+ # middle
1593
+ for mid_block in self.mid:
1594
+ h = mid_block(h)
1595
+
1596
+ # upsampling
1597
+ for i_level, block in enumerate(self.conv_blocks):
1598
+ for i_block in range(self.num_res_blocks + 1):
1599
+ h = block.res[i_block](h)
1600
+ if len(block.attn) > 0:
1601
+ h = block.attn[i_block](h)
1602
+ if i_level != self.num_resolutions - 1:
1603
+ h = block.upsample(h)
1604
+
1605
+ # end
1606
+ h = self.norm_out(h)
1607
+ h = nonlinearity(h)
1608
+ h = self.conv_out(h)
1609
+ return h
1610
+
1611
+
1612
+ class VectorQuantizer(nn.Module):
1613
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
1614
+ super().__init__()
1615
+ self.n_e = n_e
1616
+ self.e_dim = e_dim
1617
+ self.beta = beta
1618
+ self.entropy_loss_ratio = entropy_loss_ratio
1619
+ self.l2_norm = l2_norm
1620
+ self.show_usage = show_usage
1621
+
1622
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
1623
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
1624
+ if self.l2_norm:
1625
+ self.embedding.weight.data = F.normalize(
1626
+ self.embedding.weight.data, p=2, dim=-1
1627
+ )
1628
+ if self.show_usage:
1629
+ # self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
1630
+ self.codebook_used = nn.Parameter(torch.zeros(65536))
1631
+
1632
+ def forward(self, z):
1633
+ # reshape z -> (batch, height, width, channel) and flatten
1634
+ z = torch.einsum("b c h w -> b h w c", z).contiguous()
1635
+ z_flattened = z.view(-1, self.e_dim)
1636
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
1637
+
1638
+ if self.l2_norm:
1639
+ z = F.normalize(z, p=2, dim=-1)
1640
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
1641
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
1642
+ else:
1643
+ embedding = self.embedding.weight
1644
+
1645
+ d = (
1646
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
1647
+ + torch.sum(embedding**2, dim=1)
1648
+ - 2
1649
+ * torch.einsum(
1650
+ "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
1651
+ )
1652
+ )
1653
+
1654
+ min_encoding_indices = torch.argmin(d, dim=1)
1655
+ z_q = embedding[min_encoding_indices].view(z.shape)
1656
+ perplexity = None
1657
+ min_encodings = None
1658
+ vq_loss = None
1659
+ commit_loss = None
1660
+ entropy_loss = None
1661
+
1662
+ # compute loss for embedding
1663
+ if self.training:
1664
+ vq_loss = torch.mean((z_q - z.detach()) ** 2)
1665
+ commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
1666
+ entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
1667
+
1668
+ # preserve gradients
1669
+ z_q = z + (z_q - z).detach()
1670
+
1671
+ # reshape back to match original input shape
1672
+ z_q = torch.einsum("b h w c -> b c h w", z_q)
1673
+
1674
+ return (
1675
+ z_q,
1676
+ (vq_loss, commit_loss, entropy_loss),
1677
+ (perplexity, min_encodings, min_encoding_indices),
1678
+ )
1679
+
1680
+ def get_codebook_entry(self, indices, shape=None, channel_first=True):
1681
+ # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
1682
+ if self.l2_norm:
1683
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
1684
+ else:
1685
+ embedding = self.embedding.weight
1686
+ z_q = embedding[indices] # (b*h*w, c)
1687
+
1688
+ if shape is not None:
1689
+ if channel_first:
1690
+ z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
1691
+ # reshape back to match original input shape
1692
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
1693
+ else:
1694
+ z_q = z_q.view(shape)
1695
+ return z_q
1696
+
1697
+
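`VectorQuantizer.forward` snaps each spatial vector to its nearest codebook entry and routes gradients around the non-differentiable argmin with the straight-through estimator. A self-contained sketch of that core step with toy sizes; `torch.cdist` stands in for the expanded z^2 + e^2 - 2ez distance above:

```python
import torch
import torch.nn.functional as F

n_e, e_dim = 16384, 8                                     # example codebook size / dim
codebook = F.normalize(torch.randn(n_e, e_dim), dim=-1)   # l2-normalized entries
z = F.normalize(torch.randn(2, 24, 24, e_dim), dim=-1)    # encoder output, channels-last

d = torch.cdist(z.reshape(-1, e_dim), codebook)           # pairwise Euclidean distances
indices = d.argmin(dim=1)                                 # nearest codebook ids
z_q = codebook[indices].view(z.shape)                     # quantized latents

z_q_ste = z + (z_q - z).detach()                          # straight-through estimator
print(indices.shape, z_q_ste.shape)  # torch.Size([1152]) torch.Size([2, 24, 24, 8])
```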
1698
+ class ResnetBlock(nn.Module):
1699
+ def __init__(
1700
+ self,
1701
+ in_channels,
1702
+ out_channels=None,
1703
+ conv_shortcut=False,
1704
+ dropout=0.0,
1705
+ norm_type="group",
1706
+ ):
1707
+ super().__init__()
1708
+ self.in_channels = in_channels
1709
+ out_channels = in_channels if out_channels is None else out_channels
1710
+ self.out_channels = out_channels
1711
+ self.use_conv_shortcut = conv_shortcut
1712
+
1713
+ self.norm1 = Normalize(in_channels, norm_type)
1714
+ self.conv1 = nn.Conv2d(
1715
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
1716
+ )
1717
+ self.norm2 = Normalize(out_channels, norm_type)
1718
+ self.dropout = nn.Dropout(dropout)
1719
+ self.conv2 = nn.Conv2d(
1720
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
1721
+ )
1722
+
1723
+ if self.in_channels != self.out_channels:
1724
+ if self.use_conv_shortcut:
1725
+ self.conv_shortcut = nn.Conv2d(
1726
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
1727
+ )
1728
+ else:
1729
+ self.nin_shortcut = nn.Conv2d(
1730
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
1731
+ )
1732
+
1733
+ def forward(self, x):
1734
+ h = x
1735
+ h = self.norm1(h)
1736
+ h = nonlinearity(h)
1737
+ h = self.conv1(h)
1738
+ h = self.norm2(h)
1739
+ h = nonlinearity(h)
1740
+ h = self.dropout(h)
1741
+ h = self.conv2(h)
1742
+
1743
+ if self.in_channels != self.out_channels:
1744
+ if self.use_conv_shortcut:
1745
+ x = self.conv_shortcut(x)
1746
+ else:
1747
+ x = self.nin_shortcut(x)
1748
+ return x + h
1749
+
1750
+
1751
+ class AttnBlock(nn.Module):
1752
+ def __init__(self, in_channels, norm_type="group"):
1753
+ super().__init__()
1754
+ self.norm = Normalize(in_channels, norm_type)
1755
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1756
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1757
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1758
+ self.proj_out = nn.Conv2d(
1759
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
1760
+ )
1761
+
1762
+ def forward(self, x):
1763
+ h_ = x
1764
+ h_ = self.norm(h_)
1765
+ q = self.q(h_)
1766
+ k = self.k(h_)
1767
+ v = self.v(h_)
1768
+
1769
+ # compute attention
1770
+ b, c, h, w = q.shape
1771
+ q = q.reshape(b, c, h * w)
1772
+ q = q.permute(0, 2, 1) # b,hw,c
1773
+ k = k.reshape(b, c, h * w) # b,c,hw
1774
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
1775
+ w_ = w_ * (int(c) ** (-0.5))
1776
+ w_ = F.softmax(w_, dim=2)
1777
+
1778
+ # attend to values
1779
+ v = v.reshape(b, c, h * w)
1780
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
1781
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
1782
+ h_ = h_.reshape(b, c, h, w)
1783
+
1784
+ h_ = self.proj_out(h_)
1785
+
1786
+ return x + h_
1787
+
1788
+
1789
+ def nonlinearity(x):
1790
+ # swish
1791
+ return x * torch.sigmoid(x)
1792
+
1793
+
1794
+ def Normalize(in_channels, norm_type="group"):
1795
+ assert norm_type in ["group", "batch"]
1796
+ if norm_type == "group":
1797
+ return nn.GroupNorm(
1798
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
1799
+ )
1800
+ elif norm_type == "batch":
1801
+ return nn.SyncBatchNorm(in_channels)
1802
+
1803
+
1804
+ class Upsample(nn.Module):
1805
+ def __init__(self, in_channels, with_conv):
1806
+ super().__init__()
1807
+ self.with_conv = with_conv
1808
+ if self.with_conv:
1809
+ self.conv = nn.Conv2d(
1810
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
1811
+ )
1812
+
1813
+ def forward(self, x):
1814
+ if x.dtype != torch.float32:
1815
+ x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
1816
+ torch.bfloat16
1817
+ )
1818
+ else:
1819
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
1820
+
1821
+ if self.with_conv:
1822
+ x = self.conv(x)
1823
+ return x
1824
+
1825
+
1826
+ class Downsample(nn.Module):
1827
+ def __init__(self, in_channels, with_conv):
1828
+ super().__init__()
1829
+ self.with_conv = with_conv
1830
+ if self.with_conv:
1831
+ # no asymmetric padding in torch conv, must do it ourselves
1832
+ self.conv = nn.Conv2d(
1833
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
1834
+ )
1835
+
1836
+ def forward(self, x):
1837
+ if self.with_conv:
1838
+ pad = (0, 1, 0, 1)
1839
+ x = F.pad(x, pad, mode="constant", value=0)
1840
+ x = self.conv(x)
1841
+ else:
1842
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
1843
+ return x
1844
+
1845
+
1846
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
1847
+ flat_affinity = affinity.reshape(-1, affinity.shape[-1])
1848
+ flat_affinity /= temperature
1849
+ probs = F.softmax(flat_affinity, dim=-1)
1850
+ log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
1851
+ if loss_type == "softmax":
1852
+ target_probs = probs
1853
+ else:
1854
+ raise ValueError("Entropy loss {} not supported".format(loss_type))
1855
+ avg_probs = torch.mean(target_probs, dim=0)
1856
+ avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
1857
+ sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
1858
+ loss = sample_entropy - avg_entropy
1859
+ return loss
1860
+
1861
+
1862
+ class VQModel(nn.Module):
1863
+ def __init__(self, config: ModelArgs):
1864
+ super().__init__()
1865
+ self.config = config
1866
+ self.encoder = Encoder(
1867
+ ch_mult=config.encoder_ch_mult,
1868
+ z_channels=config.z_channels,
1869
+ dropout=config.dropout_p,
1870
+ )
1871
+ self.decoder = Decoder(
1872
+ ch_mult=config.decoder_ch_mult,
1873
+ z_channels=config.z_channels,
1874
+ dropout=config.dropout_p,
1875
+ )
1876
+
1877
+ self.quantize = VectorQuantizer(
1878
+ config.codebook_size,
1879
+ config.codebook_embed_dim,
1880
+ config.commit_loss_beta,
1881
+ config.entropy_loss_ratio,
1882
+ config.codebook_l2_norm,
1883
+ config.codebook_show_usage,
1884
+ )
1885
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
1886
+ self.post_quant_conv = nn.Conv2d(
1887
+ config.codebook_embed_dim, config.z_channels, 1
1888
+ )
1889
+
1890
+ def encode(self, x):
1891
+ h = self.encoder(x)
1892
+ h = self.quant_conv(h)
1893
+ quant, emb_loss, info = self.quantize(h)
1894
+ return quant, emb_loss, info
1895
+
1896
+ def decode(self, quant):
1897
+ quant = self.post_quant_conv(quant)
1898
+ dec = self.decoder(quant)
1899
+ return dec
1900
+
1901
+ def decode_code(self, code_b, shape=None, channel_first=True):
1902
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
1903
+ dec = self.decode(quant_b)
1904
+ return dec
1905
+
1906
+ def forward(self, input):
1907
+ quant, diff, _ = self.encode(input)
1908
+ dec = self.decode(quant)
1909
+ return dec, diff
1910
+
1911
+
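Putting the pieces together, `VQModel` is the image tokenizer on the generation side: `encode` maps pixels to a quantized latent plus codebook indices, and `decode` / `decode_code` map them back to pixels. A hedged usage sketch; `ModelArgs` is defined elsewhere in this file and its defaults are assumed to be usable as-is:

```python
import torch

vq = VQModel(ModelArgs())                       # assumption: ModelArgs() defaults are valid
vq.eval()

img = torch.randn(1, 3, 384, 384)               # example resolution
with torch.no_grad():
    quant, _, (_, _, indices) = vq.encode(img)  # quant: [1, codebook_embed_dim, 24, 24]
    recon = vq.decode(quant)                    # recon: [1, 3, 384, 384]
print(quant.shape, indices.shape, recon.shape)
```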
1912
+ class MultiModalityPreTrainedModel(PreTrainedModel):
1913
+ config_class = MultiModalityConfig
1914
+ base_model_prefix = "multi_modality"
1915
+ _no_split_modules = []
1916
+ _skip_keys_device_placement = "past_key_values"
1917
+
1918
+
1919
+ # Copied and adapted from:
1920
+ # https://github.com/deepseek-ai/Janus/tree/main/janus/models/modeling_vlm.py
1921
+ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
1922
+
1923
+ def __init__(
1924
+ self,
1925
+ config: MultiModalityConfig,
1926
+ quant_config: Optional[QuantizationConfig] = None,
1927
+ ):
1928
+ super().__init__(config)
1929
+
1930
+ vision_config = config.vision_config
1931
+ vision_cls = model_name_to_cls(vision_config.cls)
1932
+ self.vision_model = vision_cls(**vision_config.params)
1933
+
1934
+ aligner_config = config.aligner_config
1935
+ aligner_cls = model_name_to_cls(aligner_config.cls)
1936
+ self.aligner = aligner_cls(aligner_config.params)
1937
+
1938
+ gen_vision_config = config.gen_vision_config
1939
+ gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
1940
+ self.gen_vision_model = gen_vision_cls()
1941
+
1942
+ gen_aligner_config = config.gen_aligner_config
1943
+ gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
1944
+ self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
1945
+
1946
+ gen_head_config = config.gen_head_config
1947
+ gen_head_cls = model_name_to_cls(gen_head_config.cls)
1948
+ self.gen_head = gen_head_cls(gen_head_config.params)
1949
+
1950
+ self.gen_embed = torch.nn.Embedding(
1951
+ gen_vision_config.params["image_token_size"],
1952
+ gen_vision_config.params["n_embed"],
1953
+ )
1954
+
1955
+ language_config = config.language_config
1956
+ self.language_model = LlamaForCausalLM(
1957
+ language_config, quant_config=quant_config
1958
+ )
1959
+ self.logits_processor = LogitsProcessor(config)
1960
+
1961
+ def prepare_images_seq_mask(
1962
+ self, input_ids: torch.Tensor, image_inputs: ImageInputs
1963
+ ) -> Optional[torch.LongTensor]:
1964
+ images_seq_mask = torch.isin(
1965
+ input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
1966
+ )
1967
+ if images_seq_mask.sum() == 0:
1968
+ # sometimes image_inputs is not empty, but input_ids contains no image tokens because of prefix caching
1969
+ return None
1970
+ else:
1971
+ return images_seq_mask
1972
+
1973
+ @torch.no_grad()
1974
+ def forward(
1975
+ self,
1976
+ input_ids: torch.LongTensor,
1977
+ positions: torch.Tensor,
1978
+ forward_batch: ForwardBatch,
1979
+ ) -> torch.Tensor:
1980
+
1981
+ inputs_embeds = None
1982
+ if (
1983
+ forward_batch.image_inputs is not None
1984
+ and len(forward_batch.image_inputs) != 0
1985
+ and forward_batch.image_inputs[0] is not None
1986
+ ):
1987
+
1988
+ image_inputs = forward_batch.image_inputs[0]
1989
+
1990
+ images_seq_mask = self.prepare_images_seq_mask(
1991
+ input_ids=input_ids, image_inputs=image_inputs
1992
+ )
1993
+
1994
+ if images_seq_mask is not None:
1995
+ input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
1996
+ inputs_embeds = self.prepare_inputs_embeds(
1997
+ input_ids=input_ids,
1998
+ pixel_values=image_inputs.pixel_values,
1999
+ images_seq_mask=images_seq_mask,
2000
+ images_emb_mask=image_inputs.images_emb_mask,
2001
+ )
2002
+ input_ids = None
2003
+
2004
+ if input_ids is not None:
2005
+ input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
2006
+
2007
+ return self.language_model(
2008
+ input_ids=input_ids,
2009
+ positions=positions,
2010
+ forward_batch=forward_batch,
2011
+ input_embeds=inputs_embeds,
2012
+ get_embedding=False,
2013
+ )
2014
+
2015
+ def prepare_inputs_embeds(
2016
+ self,
2017
+ input_ids: torch.LongTensor,
2018
+ pixel_values: torch.FloatTensor,
2019
+ images_seq_mask: torch.LongTensor,
2020
+ images_emb_mask: torch.BoolTensor,
2021
+ **_kwargs,
2022
+ ):
2023
+ """
2024
+
2025
+ Args:
2026
+ input_ids (torch.LongTensor): [b, T]
2027
+ pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
2028
+ images_seq_mask (torch.BoolTensor): [b, T]
2029
+ images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
2030
+
2031
+ assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
2032
+
2033
+ Returns:
2034
+ input_embeds (torch.Tensor): [b, T, D]
2035
+ """
2036
+
2037
+ bs, n = pixel_values.shape[0:2]
2038
+ pixel_values = pixel_values.to(
2039
+ device=self.vision_model.device, dtype=self.vision_model.dtype
2040
+ )
2041
+ images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
2042
+
2043
+ # [b x n, T2, D]
2044
+ images_embeds = self.aligner(self.vision_model(images))
2045
+
2046
+ # [b x n, T2, D] -> [b, n x T2, D]
2047
+ images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
2048
+ # [b, n, T2] -> [b, n x T2]
2049
+ images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
2050
+
2051
+ # [b, T, D]
2052
+ # ignore the image embeddings
2053
+ input_ids[input_ids < 0] = 0
2054
+ inputs_embeds = self.language_model.model.embed_tokens(input_ids)
2055
+
2056
+ # replace with the image embeddings
2057
+ inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
2058
+
2059
+ return inputs_embeds
2060
+
2061
+ def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
2062
+ return self.gen_aligner(self.gen_embed(image_ids))
2063
+
2064
+ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
2065
+ im_start_id = image_inputs.im_start_id
2066
+ im_end_id = image_inputs.im_end_id
2067
+ media_token_pairs = [(im_start_id, im_end_id)]
2068
+
2069
+ helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
2070
+
2071
+ return helper.pad_input_tokens(input_ids, image_inputs)
2072
+
2073
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2074
+ stacked_params_mapping = [
2075
+ # (param_name, shard_name, shard_id)
2076
+ (".qkv_proj", ".q_proj", "q"),
2077
+ (".qkv_proj", ".k_proj", "k"),
2078
+ (".qkv_proj", ".v_proj", "v"),
2079
+ ("gate_up_proj", "gate_proj", 0),
2080
+ ("gate_up_proj", "up_proj", 1),
2081
+ ]
2082
+
2083
+ params_dict = dict(self.named_parameters())
2084
+ for name, loaded_weight in weights:
2085
+ if "rotary_emb.inv_freq~" in name or "projector" in name:
2086
+ continue
2087
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
2088
+ # Models trained using ColossalAI may include these tensors in
2089
+ # the checkpoint. Skip them.
2090
+ continue
2091
+ if name.startswith("model.vision_tower") and name not in params_dict:
2092
+ continue
2093
+
2094
+ # skip generation sub model
2095
+ if "gen" in name:
2096
+ continue
2097
+
2098
+ # adapt to VisionAttention
2099
+ name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
2100
+ if "vision_model.vision_tower" in name:
2101
+ name = name.replace("attn.qkv", "attn.qkv_proj")
2102
+
2103
+ for param_name, weight_name, shard_id in stacked_params_mapping:
2104
+ # replace the name and load with customized loader
2105
+ if weight_name not in name:
2106
+ continue
2107
+ name = name.replace(weight_name, param_name)
2108
+
2109
+ # Skip loading extra bias for GPTQ models.
2110
+ if name.endswith(".bias") and name not in params_dict:
2111
+ continue
2112
+ param = params_dict[name]
2113
+ weight_loader = getattr(param, "weight_loader", None)
2114
+ weight_loader(param, loaded_weight, shard_id)
2115
+ break
2116
+ else:
2117
+ # Skip loading extra bias for GPTQ models.
2118
+ if name.endswith(".bias") and name not in params_dict:
2119
+ continue
2120
+
2121
+ param = params_dict[name]
2122
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
2123
+ weight_loader(param, loaded_weight)
2124
+
2125
+
2126
+ AutoModel.register(config_class=MultiModalityConfig, model_class=MultiModalityCausalLM)
2127
+ EntryClass = [MultiModalityCausalLM]
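For reference, the core of `prepare_inputs_embeds` above is a boolean masked assignment: positions of image placeholder tokens in the text embedding sequence are overwritten with the aligner's image embeddings. A toy, self-contained illustration with made-up sizes (not sglang's API):

```python
import torch

T, D, n_img_tokens = 10, 4, 3
inputs_embeds = torch.zeros(1, T, D)                    # text token embeddings
images_embeds = torch.ones(1, n_img_tokens, D)          # aligner output for one image
images_seq_mask = torch.zeros(1, T, dtype=torch.bool)
images_seq_mask[0, 2:5] = True                          # <image> placeholder positions
images_emb_mask = torch.ones(1, n_img_tokens, dtype=torch.bool)

inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
print(inputs_embeds[0, :6, 0])  # tensor([0., 0., 1., 1., 1., 0.])
```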