sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2127 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
# Copied and Adapted from:
|
16
|
+
# https://github.com/deepseek-ai/Janus
|
17
|
+
|
18
|
+
|
19
|
+
import collections
|
20
|
+
import math
|
21
|
+
import os
|
22
|
+
from dataclasses import field
|
23
|
+
from enum import Enum
|
24
|
+
from functools import partial
|
25
|
+
from itertools import repeat
|
26
|
+
from typing import (
|
27
|
+
Callable,
|
28
|
+
Final,
|
29
|
+
Iterable,
|
30
|
+
Literal,
|
31
|
+
Optional,
|
32
|
+
Sequence,
|
33
|
+
Set,
|
34
|
+
Tuple,
|
35
|
+
Type,
|
36
|
+
Union,
|
37
|
+
)
|
38
|
+
|
39
|
+
import torch
|
40
|
+
import torch.nn.functional as F
|
41
|
+
from einops import rearrange
|
42
|
+
from torch import Tensor, _assert, nn
|
43
|
+
from torch.nn.init import trunc_normal_
|
44
|
+
from transformers import AutoModel, PreTrainedModel
|
45
|
+
|
46
|
+
from sglang.srt.configs.janus_pro import *
|
47
|
+
from sglang.srt.layers.attention.vision import VisionAttention
|
48
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
49
|
+
from sglang.srt.layers.quantization import QuantizationConfig
|
50
|
+
from sglang.srt.managers.multi_modality_padding import (
|
51
|
+
MultiModalityDataPaddingPatternTokenPairs,
|
52
|
+
)
|
53
|
+
from sglang.srt.managers.schedule_batch import ImageInputs
|
54
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
55
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
56
|
+
from sglang.srt.models.llama import LlamaForCausalLM
|
57
|
+
from sglang.utils import logger
|
58
|
+
|
59
|
+
#################################################################################
|
60
|
+
# VQ Model Configs #
|
61
|
+
#################################################################################
|
62
|
+
|
63
|
+
|
64
|
+
# Copied from:
|
65
|
+
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/vq_model.py
|
66
|
+
@dataclass
|
67
|
+
class ModelArgs:
|
68
|
+
codebook_size: int = 16384
|
69
|
+
codebook_embed_dim: int = 8
|
70
|
+
codebook_l2_norm: bool = True
|
71
|
+
codebook_show_usage: bool = True
|
72
|
+
commit_loss_beta: float = 0.25
|
73
|
+
entropy_loss_ratio: float = 0.0
|
74
|
+
|
75
|
+
encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
|
76
|
+
decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
|
77
|
+
z_channels: int = 256
|
78
|
+
dropout_p: float = 0.0
|
79
|
+
|
80
|
+
|
81
|
+
def named_apply(
|
82
|
+
fn: Callable,
|
83
|
+
module: nn.Module,
|
84
|
+
name="",
|
85
|
+
depth_first: bool = True,
|
86
|
+
include_root: bool = False,
|
87
|
+
) -> nn.Module:
|
88
|
+
if not depth_first and include_root:
|
89
|
+
fn(module=module, name=name)
|
90
|
+
for child_name, child_module in module.named_children():
|
91
|
+
child_name = ".".join((name, child_name)) if name else child_name
|
92
|
+
named_apply(
|
93
|
+
fn=fn,
|
94
|
+
module=child_module,
|
95
|
+
name=child_name,
|
96
|
+
depth_first=depth_first,
|
97
|
+
include_root=True,
|
98
|
+
)
|
99
|
+
if depth_first and include_root:
|
100
|
+
fn(module=module, name=name)
|
101
|
+
return module
|
102
|
+
|
103
|
+
|
104
|
+
def VQ_16(**kwargs):
|
105
|
+
return VQModel(
|
106
|
+
ModelArgs(
|
107
|
+
encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs
|
108
|
+
)
|
109
|
+
)
|
110
|
+
|
111
|
+
|
112
|
+
VQ_models = {"VQ-16": VQ_16}
|
113
|
+
|
114
|
+
import collections.abc
|
115
|
+
|
116
|
+
|
117
|
+
# From PyTorch internals
|
118
|
+
def _ntuple(n):
|
119
|
+
def parse(x):
|
120
|
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
121
|
+
return tuple(x)
|
122
|
+
return tuple(repeat(x, n))
|
123
|
+
|
124
|
+
return parse
|
125
|
+
|
126
|
+
|
127
|
+
def _trunc_normal_(tensor, mean, std, a, b):
|
128
|
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
129
|
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
130
|
+
def norm_cdf(x):
|
131
|
+
# Computes standard normal cumulative distribution function
|
132
|
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
133
|
+
|
134
|
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
135
|
+
logger.warn(
|
136
|
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
137
|
+
"The distribution of values may be incorrect.",
|
138
|
+
stacklevel=2,
|
139
|
+
)
|
140
|
+
|
141
|
+
# Values are generated by using a truncated uniform distribution and
|
142
|
+
# then using the inverse CDF for the normal distribution.
|
143
|
+
# Get upper and lower cdf values
|
144
|
+
l = norm_cdf((a - mean) / std)
|
145
|
+
u = norm_cdf((b - mean) / std)
|
146
|
+
|
147
|
+
# Uniformly fill tensor with values from [l, u], then translate to
|
148
|
+
# [2l-1, 2u-1].
|
149
|
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
150
|
+
|
151
|
+
# Use inverse cdf transform for normal distribution to get truncated
|
152
|
+
# standard normal
|
153
|
+
if tensor.dtype in [torch.float16, torch.bfloat16]:
|
154
|
+
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
|
155
|
+
og_dtype = tensor.dtype
|
156
|
+
tensor = tensor.to(torch.float32)
|
157
|
+
tensor.erfinv_()
|
158
|
+
tensor = tensor.to(og_dtype)
|
159
|
+
else:
|
160
|
+
tensor.erfinv_()
|
161
|
+
|
162
|
+
# Transform to proper mean, std
|
163
|
+
tensor.mul_(std * math.sqrt(2.0))
|
164
|
+
tensor.add_(mean)
|
165
|
+
|
166
|
+
# Clamp to ensure it's in the proper range
|
167
|
+
if tensor.dtype == torch.float16:
|
168
|
+
# The `clamp_` op is not (yet?) defined in float16+cpu
|
169
|
+
tensor = tensor.to(torch.float32)
|
170
|
+
tensor.clamp_(min=a, max=b)
|
171
|
+
else:
|
172
|
+
tensor.clamp_(min=a, max=b)
|
173
|
+
|
174
|
+
|
175
|
+
def trunc_normal_tf_(
|
176
|
+
tensor: torch.Tensor,
|
177
|
+
mean: float = 0.0,
|
178
|
+
std: float = 1.0,
|
179
|
+
a: float = -2.0,
|
180
|
+
b: float = 2.0,
|
181
|
+
):
|
182
|
+
"""Fills the input Tensor with values drawn from a truncated
|
183
|
+
normal distribution. The values are effectively drawn from the
|
184
|
+
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
|
185
|
+
with values outside :math:`[a, b]` redrawn until they are within
|
186
|
+
the bounds. The method used for generating the random values works
|
187
|
+
best when :math:`a \\leq \text{mean} \\leq b`.
|
188
|
+
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
189
|
+
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
190
|
+
and the result is subsquently scaled and shifted by the mean and std args.
|
191
|
+
Args:
|
192
|
+
tensor: an n-dimensional `torch.Tensor`
|
193
|
+
mean: the mean of the normal distribution
|
194
|
+
std: the standard deviation of the normal distribution
|
195
|
+
a: the minimum cutoff value
|
196
|
+
b: the maximum cutoff value
|
197
|
+
"""
|
198
|
+
with torch.no_grad():
|
199
|
+
_trunc_normal_(tensor, 0, 1.0, a, b)
|
200
|
+
tensor.mul_(std).add_(mean)
|
201
|
+
|
202
|
+
|
203
|
+
to_2tuple = _ntuple(2)
|
204
|
+
|
205
|
+
|
206
|
+
class Format(str, Enum):
|
207
|
+
NCHW = "NCHW"
|
208
|
+
NHWC = "NHWC"
|
209
|
+
NCL = "NCL"
|
210
|
+
NLC = "NLC"
|
211
|
+
|
212
|
+
|
213
|
+
def nchw_to(x: torch.Tensor, fmt: Format):
|
214
|
+
if fmt == Format.NHWC:
|
215
|
+
x = x.permute(0, 2, 3, 1)
|
216
|
+
elif fmt == Format.NLC:
|
217
|
+
x = x.flatten(2).transpose(1, 2)
|
218
|
+
elif fmt == Format.NCL:
|
219
|
+
x = x.flatten(2)
|
220
|
+
return x
|
221
|
+
|
222
|
+
|
223
|
+
def resample_patch_embed(
|
224
|
+
patch_embed,
|
225
|
+
new_size: List[int],
|
226
|
+
interpolation: str = "bicubic",
|
227
|
+
antialias: bool = True,
|
228
|
+
verbose: bool = False,
|
229
|
+
):
|
230
|
+
"""Resample the weights of the patch embedding kernel to target resolution.
|
231
|
+
We resample the patch embedding kernel by approximately inverting the effect
|
232
|
+
of patch resizing.
|
233
|
+
|
234
|
+
Code based on:
|
235
|
+
https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
|
236
|
+
|
237
|
+
With this resizing, we can for example load a B/8 filter into a B/16 model
|
238
|
+
and, on 2x larger input image, the result will match.
|
239
|
+
|
240
|
+
Args:
|
241
|
+
patch_embed: original parameter to be resized.
|
242
|
+
new_size (tuple(int, int): target shape (height, width)-only.
|
243
|
+
interpolation (str): interpolation for resize
|
244
|
+
antialias (bool): use anti-aliasing filter in resize
|
245
|
+
verbose (bool): log operation
|
246
|
+
Returns:
|
247
|
+
Resized patch embedding kernel.
|
248
|
+
"""
|
249
|
+
import numpy as np
|
250
|
+
|
251
|
+
try:
|
252
|
+
from torch import vmap
|
253
|
+
except ImportError:
|
254
|
+
from functorch import vmap
|
255
|
+
|
256
|
+
assert len(patch_embed.shape) == 4, "Four dimensions expected"
|
257
|
+
assert len(new_size) == 2, "New shape should only be hw"
|
258
|
+
old_size = patch_embed.shape[-2:]
|
259
|
+
if tuple(old_size) == tuple(new_size):
|
260
|
+
return patch_embed
|
261
|
+
|
262
|
+
if verbose:
|
263
|
+
logger.info(
|
264
|
+
f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation."
|
265
|
+
)
|
266
|
+
|
267
|
+
def resize(x_np, _new_size):
|
268
|
+
x_tf = torch.Tensor(x_np)[None, None, ...]
|
269
|
+
x_upsampled = F.interpolate(
|
270
|
+
x_tf, size=_new_size, mode=interpolation, antialias=antialias
|
271
|
+
)[0, 0, ...].numpy()
|
272
|
+
return x_upsampled
|
273
|
+
|
274
|
+
def get_resize_mat(_old_size, _new_size):
|
275
|
+
mat = []
|
276
|
+
for i in range(np.prod(_old_size)):
|
277
|
+
basis_vec = np.zeros(_old_size)
|
278
|
+
basis_vec[np.unravel_index(i, _old_size)] = 1.0
|
279
|
+
mat.append(resize(basis_vec, _new_size).reshape(-1))
|
280
|
+
return np.stack(mat).T
|
281
|
+
|
282
|
+
resize_mat = get_resize_mat(old_size, new_size)
|
283
|
+
resize_mat_pinv = torch.tensor(
|
284
|
+
np.linalg.pinv(resize_mat.T), device=patch_embed.device
|
285
|
+
)
|
286
|
+
|
287
|
+
def resample_kernel(kernel):
|
288
|
+
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
|
289
|
+
return resampled_kernel.reshape(new_size)
|
290
|
+
|
291
|
+
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
|
292
|
+
orig_dtype = patch_embed.dtype
|
293
|
+
patch_embed = patch_embed.float()
|
294
|
+
patch_embed = v_resample_kernel(patch_embed)
|
295
|
+
patch_embed = patch_embed.to(orig_dtype)
|
296
|
+
return patch_embed
|
297
|
+
|
298
|
+
|
299
|
+
# Copied from:
|
300
|
+
# https://github.com/deepseek-ai/Janus/tree/main/janus/models/siglip_vit.py
|
301
|
+
class PatchEmbed(nn.Module):
|
302
|
+
"""2D Image to Patch Embedding"""
|
303
|
+
|
304
|
+
output_fmt: Format
|
305
|
+
dynamic_img_pad: torch.jit.Final[bool]
|
306
|
+
|
307
|
+
def __init__(
|
308
|
+
self,
|
309
|
+
img_size: Optional[int] = 224,
|
310
|
+
patch_size: int = 16,
|
311
|
+
in_chans: int = 3,
|
312
|
+
embed_dim: int = 768,
|
313
|
+
norm_layer: Optional[Callable] = None,
|
314
|
+
flatten: bool = True,
|
315
|
+
output_fmt: Optional[str] = None,
|
316
|
+
bias: bool = True,
|
317
|
+
strict_img_size: bool = True,
|
318
|
+
dynamic_img_pad: bool = False,
|
319
|
+
):
|
320
|
+
super().__init__()
|
321
|
+
self.patch_size = tuple(to_2tuple(patch_size))
|
322
|
+
self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
|
323
|
+
|
324
|
+
if output_fmt is not None:
|
325
|
+
self.flatten = False
|
326
|
+
self.output_fmt = Format(output_fmt)
|
327
|
+
else:
|
328
|
+
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
329
|
+
self.flatten = flatten
|
330
|
+
self.output_fmt = Format.NCHW
|
331
|
+
self.strict_img_size = strict_img_size
|
332
|
+
self.dynamic_img_pad = dynamic_img_pad
|
333
|
+
|
334
|
+
self.proj = nn.Conv2d(
|
335
|
+
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
|
336
|
+
)
|
337
|
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
338
|
+
|
339
|
+
def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
|
340
|
+
assert self.patch_size
|
341
|
+
if img_size is None:
|
342
|
+
return None, None, None
|
343
|
+
img_size = to_2tuple(img_size)
|
344
|
+
grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
|
345
|
+
num_patches = grid_size[0] * grid_size[1]
|
346
|
+
return img_size, grid_size, num_patches
|
347
|
+
|
348
|
+
def set_input_size(
|
349
|
+
self,
|
350
|
+
img_size: Optional[Union[int, Tuple[int, int]]] = None,
|
351
|
+
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
|
352
|
+
):
|
353
|
+
new_patch_size = None
|
354
|
+
if patch_size is not None:
|
355
|
+
new_patch_size = to_2tuple(patch_size)
|
356
|
+
if new_patch_size is not None and new_patch_size != self.patch_size:
|
357
|
+
with torch.no_grad():
|
358
|
+
new_proj = nn.Conv2d(
|
359
|
+
self.proj.in_channels,
|
360
|
+
self.proj.out_channels,
|
361
|
+
kernel_size=new_patch_size,
|
362
|
+
stride=new_patch_size,
|
363
|
+
bias=self.proj.bias is not None,
|
364
|
+
)
|
365
|
+
new_proj.weight.copy_(
|
366
|
+
resample_patch_embed(self.proj.weight, new_patch_size, verbose=True)
|
367
|
+
)
|
368
|
+
if self.proj.bias is not None:
|
369
|
+
new_proj.bias.copy_(self.proj.bias)
|
370
|
+
self.proj = new_proj
|
371
|
+
self.patch_size = new_patch_size
|
372
|
+
img_size = img_size or self.img_size
|
373
|
+
if img_size != self.img_size or new_patch_size is not None:
|
374
|
+
self.img_size, self.grid_size, self.num_patches = self._init_img_size(
|
375
|
+
img_size
|
376
|
+
)
|
377
|
+
|
378
|
+
def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
|
379
|
+
if as_scalar:
|
380
|
+
return max(self.patch_size)
|
381
|
+
else:
|
382
|
+
return self.patch_size
|
383
|
+
|
384
|
+
def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
|
385
|
+
"""Get grid (feature) size for given image size taking account of dynamic padding.
|
386
|
+
NOTE: must be torchscript compatible so using fixed tuple indexing
|
387
|
+
"""
|
388
|
+
if self.dynamic_img_pad:
|
389
|
+
return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(
|
390
|
+
img_size[1] / self.patch_size[1]
|
391
|
+
)
|
392
|
+
else:
|
393
|
+
return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
|
394
|
+
|
395
|
+
def forward(self, x):
|
396
|
+
B, C, H, W = x.shape
|
397
|
+
if self.img_size is not None:
|
398
|
+
if self.strict_img_size:
|
399
|
+
_assert(
|
400
|
+
H == self.img_size[0],
|
401
|
+
f"Input height ({H}) doesn't match model ({self.img_size[0]}).",
|
402
|
+
)
|
403
|
+
_assert(
|
404
|
+
W == self.img_size[1],
|
405
|
+
f"Input width ({W}) doesn't match model ({self.img_size[1]}).",
|
406
|
+
)
|
407
|
+
elif not self.dynamic_img_pad:
|
408
|
+
_assert(
|
409
|
+
H % self.patch_size[0] == 0,
|
410
|
+
f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]}).",
|
411
|
+
)
|
412
|
+
_assert(
|
413
|
+
W % self.patch_size[1] == 0,
|
414
|
+
f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]}).",
|
415
|
+
)
|
416
|
+
if self.dynamic_img_pad:
|
417
|
+
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
|
418
|
+
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
|
419
|
+
x = F.pad(x, (0, pad_w, 0, pad_h))
|
420
|
+
x = self.proj(x)
|
421
|
+
if self.flatten:
|
422
|
+
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
423
|
+
elif self.output_fmt != Format.NCHW:
|
424
|
+
x = nchw_to(x, self.output_fmt)
|
425
|
+
x = self.norm(x)
|
426
|
+
return x
|
427
|
+
|
428
|
+
|
429
|
+
class Mlp(nn.Module):
|
430
|
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks
|
431
|
+
|
432
|
+
NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
|
433
|
+
"""
|
434
|
+
|
435
|
+
def __init__(
|
436
|
+
self,
|
437
|
+
in_features,
|
438
|
+
hidden_features=None,
|
439
|
+
out_features=None,
|
440
|
+
act_layer=nn.GELU,
|
441
|
+
norm_layer=None,
|
442
|
+
bias=True,
|
443
|
+
drop=0.0,
|
444
|
+
use_conv=False,
|
445
|
+
):
|
446
|
+
super().__init__()
|
447
|
+
out_features = out_features or in_features
|
448
|
+
hidden_features = hidden_features or in_features
|
449
|
+
bias = to_2tuple(bias)
|
450
|
+
drop_probs = to_2tuple(drop)
|
451
|
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
452
|
+
|
453
|
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
454
|
+
self.act = act_layer()
|
455
|
+
self.drop1 = nn.Dropout(drop_probs[0])
|
456
|
+
self.norm = (
|
457
|
+
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
458
|
+
)
|
459
|
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
460
|
+
self.drop2 = nn.Dropout(drop_probs[1])
|
461
|
+
|
462
|
+
def forward(self, x):
|
463
|
+
x = self.fc1(x)
|
464
|
+
x = self.act(x)
|
465
|
+
x = self.drop1(x)
|
466
|
+
x = self.norm(x)
|
467
|
+
x = self.fc2(x)
|
468
|
+
x = self.drop2(x)
|
469
|
+
return x
|
470
|
+
|
471
|
+
|
472
|
+
def drop_path(
|
473
|
+
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
|
474
|
+
):
|
475
|
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
476
|
+
|
477
|
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
478
|
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
479
|
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
480
|
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
481
|
+
'survival rate' as the argument.
|
482
|
+
|
483
|
+
"""
|
484
|
+
if drop_prob == 0.0 or not training:
|
485
|
+
return x
|
486
|
+
keep_prob = 1 - drop_prob
|
487
|
+
shape = (x.shape[0],) + (1,) * (
|
488
|
+
x.ndim - 1
|
489
|
+
) # work with diff dim tensors, not just 2D ConvNets
|
490
|
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
491
|
+
if keep_prob > 0.0 and scale_by_keep:
|
492
|
+
random_tensor.div_(keep_prob)
|
493
|
+
return x * random_tensor
|
494
|
+
|
495
|
+
|
496
|
+
class DropPath(nn.Module):
|
497
|
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
498
|
+
|
499
|
+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
|
500
|
+
super(DropPath, self).__init__()
|
501
|
+
self.drop_prob = drop_prob
|
502
|
+
self.scale_by_keep = scale_by_keep
|
503
|
+
|
504
|
+
def forward(self, x):
|
505
|
+
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
506
|
+
|
507
|
+
def extra_repr(self):
|
508
|
+
return f"drop_prob={round(self.drop_prob, 3):0.3f}"
|
509
|
+
|
510
|
+
|
511
|
+
class VisionTransformerBlock(nn.Module):
|
512
|
+
def __init__(
|
513
|
+
self,
|
514
|
+
dim: int,
|
515
|
+
num_heads: int,
|
516
|
+
mlp_ratio: float = 4.0,
|
517
|
+
qkv_bias: bool = False,
|
518
|
+
qk_norm: bool = False,
|
519
|
+
proj_drop: float = 0.0,
|
520
|
+
attn_drop: float = 0.0,
|
521
|
+
init_values: Optional[float] = None,
|
522
|
+
drop_path: float = 0.0,
|
523
|
+
act_layer: nn.Module = nn.GELU,
|
524
|
+
norm_layer: nn.Module = nn.LayerNorm,
|
525
|
+
mlp_layer: nn.Module = Mlp,
|
526
|
+
) -> None:
|
527
|
+
super().__init__()
|
528
|
+
self.norm1 = norm_layer(dim)
|
529
|
+
self.attn = VisionAttention(
|
530
|
+
embed_dim=dim,
|
531
|
+
num_heads=num_heads,
|
532
|
+
projection_size=dim,
|
533
|
+
use_qkv_parallel=True,
|
534
|
+
use_context_forward=False,
|
535
|
+
softmax_in_single_precision=False,
|
536
|
+
dropout=attn_drop,
|
537
|
+
)
|
538
|
+
|
539
|
+
self.ls1 = (
|
540
|
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
541
|
+
)
|
542
|
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
543
|
+
|
544
|
+
self.norm2 = norm_layer(dim)
|
545
|
+
self.mlp = mlp_layer(
|
546
|
+
in_features=dim,
|
547
|
+
hidden_features=int(dim * mlp_ratio),
|
548
|
+
act_layer=act_layer,
|
549
|
+
drop=proj_drop,
|
550
|
+
)
|
551
|
+
self.ls2 = (
|
552
|
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
553
|
+
)
|
554
|
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
555
|
+
|
556
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
557
|
+
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
|
558
|
+
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
559
|
+
return x
|
560
|
+
|
561
|
+
|
562
|
+
LayerType = Union[str, Callable, Type[torch.nn.Module]]
|
563
|
+
|
564
|
+
|
565
|
+
class PatchDropout(nn.Module):
|
566
|
+
"""
|
567
|
+
https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
|
568
|
+
"""
|
569
|
+
|
570
|
+
return_indices: torch.jit.Final[bool]
|
571
|
+
|
572
|
+
def __init__(
|
573
|
+
self,
|
574
|
+
prob: float = 0.5,
|
575
|
+
num_prefix_tokens: int = 1,
|
576
|
+
ordered: bool = False,
|
577
|
+
return_indices: bool = False,
|
578
|
+
):
|
579
|
+
super().__init__()
|
580
|
+
assert 0 <= prob < 1.0
|
581
|
+
self.prob = prob
|
582
|
+
self.num_prefix_tokens = (
|
583
|
+
num_prefix_tokens # exclude CLS token (or other prefix tokens)
|
584
|
+
)
|
585
|
+
self.ordered = ordered
|
586
|
+
self.return_indices = return_indices
|
587
|
+
|
588
|
+
def forward(
|
589
|
+
self, x
|
590
|
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
|
591
|
+
if not self.training or self.prob == 0.0:
|
592
|
+
if self.return_indices:
|
593
|
+
return x, None
|
594
|
+
return x
|
595
|
+
|
596
|
+
if self.num_prefix_tokens:
|
597
|
+
prefix_tokens, x = (
|
598
|
+
x[:, : self.num_prefix_tokens],
|
599
|
+
x[:, self.num_prefix_tokens :],
|
600
|
+
)
|
601
|
+
else:
|
602
|
+
prefix_tokens = None
|
603
|
+
|
604
|
+
B = x.shape[0]
|
605
|
+
L = x.shape[1]
|
606
|
+
num_keep = max(1, int(L * (1.0 - self.prob)))
|
607
|
+
keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[
|
608
|
+
:, :num_keep
|
609
|
+
]
|
610
|
+
if self.ordered:
|
611
|
+
# NOTE does not need to maintain patch order in typical transformer use,
|
612
|
+
# but possibly useful for debug / visualization
|
613
|
+
keep_indices = keep_indices.sort(dim=-1)[0]
|
614
|
+
x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
|
615
|
+
|
616
|
+
if prefix_tokens is not None:
|
617
|
+
x = torch.cat((prefix_tokens, x), dim=1)
|
618
|
+
|
619
|
+
if self.return_indices:
|
620
|
+
return x, keep_indices
|
621
|
+
return x
|
622
|
+
|
623
|
+
|
624
|
+
def resample_abs_pos_embed(
|
625
|
+
posemb: torch.Tensor,
|
626
|
+
new_size: List[int],
|
627
|
+
old_size: Optional[List[int]] = None,
|
628
|
+
num_prefix_tokens: int = 1,
|
629
|
+
interpolation: str = "bicubic",
|
630
|
+
antialias: bool = True,
|
631
|
+
verbose: bool = False,
|
632
|
+
):
|
633
|
+
# sort out sizes, assume square if old size not provided
|
634
|
+
num_pos_tokens = posemb.shape[1]
|
635
|
+
num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
|
636
|
+
if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
|
637
|
+
return posemb
|
638
|
+
|
639
|
+
if old_size is None:
|
640
|
+
hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
|
641
|
+
old_size = hw, hw
|
642
|
+
|
643
|
+
if num_prefix_tokens:
|
644
|
+
posemb_prefix, posemb = (
|
645
|
+
posemb[:, :num_prefix_tokens],
|
646
|
+
posemb[:, num_prefix_tokens:],
|
647
|
+
)
|
648
|
+
else:
|
649
|
+
posemb_prefix, posemb = None, posemb
|
650
|
+
|
651
|
+
# do the interpolation
|
652
|
+
embed_dim = posemb.shape[-1]
|
653
|
+
orig_dtype = posemb.dtype
|
654
|
+
posemb = posemb.float() # interpolate needs float32
|
655
|
+
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
|
656
|
+
posemb = F.interpolate(
|
657
|
+
posemb, size=new_size, mode=interpolation, antialias=antialias
|
658
|
+
)
|
659
|
+
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
|
660
|
+
posemb = posemb.to(orig_dtype)
|
661
|
+
|
662
|
+
# add back extra (class, etc) prefix tokens
|
663
|
+
if posemb_prefix is not None:
|
664
|
+
posemb = torch.cat([posemb_prefix, posemb], dim=1)
|
665
|
+
|
666
|
+
if not torch.jit.is_scripting() and verbose:
|
667
|
+
logger.info(f"Resized position embedding: {old_size} to {new_size}.")
|
668
|
+
|
669
|
+
return posemb
|
670
|
+
|
671
|
+
|
672
|
+
def init_weights(self):
|
673
|
+
if self.pos_embed is not None:
|
674
|
+
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
|
675
|
+
trunc_normal_(self.latent, std=self.latent_dim**-0.5)
|
676
|
+
|
677
|
+
|
678
|
+
def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
|
679
|
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
680
|
+
if isinstance(module, nn.Linear):
|
681
|
+
trunc_normal_(module.weight, std=0.02)
|
682
|
+
if module.bias is not None:
|
683
|
+
nn.init.zeros_(module.bias)
|
684
|
+
elif hasattr(module, "init_weights"):
|
685
|
+
module.init_weights()
|
686
|
+
|
687
|
+
|
688
|
+
class VisionTransformer(nn.Module):
|
689
|
+
"""Vision Transformer
|
690
|
+
|
691
|
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
692
|
+
- https://arxiv.org/abs/2010.11929
|
693
|
+
"""
|
694
|
+
|
695
|
+
dynamic_img_size: Final[bool]
|
696
|
+
|
697
|
+
def __init__(
|
698
|
+
self,
|
699
|
+
img_size: Union[int, Tuple[int, int]] = 224,
|
700
|
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
701
|
+
in_chans: int = 3,
|
702
|
+
num_classes: int = 1000,
|
703
|
+
global_pool: Literal["", "avg", "token", "map"] = "token",
|
704
|
+
embed_dim: int = 768,
|
705
|
+
depth: int = 12,
|
706
|
+
num_heads: int = 12,
|
707
|
+
mlp_ratio: float = 4.0,
|
708
|
+
qkv_bias: bool = True,
|
709
|
+
qk_norm: bool = False,
|
710
|
+
init_values: Optional[float] = None,
|
711
|
+
class_token: bool = True,
|
712
|
+
no_embed_class: bool = False,
|
713
|
+
reg_tokens: int = 0,
|
714
|
+
pre_norm: bool = False,
|
715
|
+
fc_norm: Optional[bool] = None,
|
716
|
+
dynamic_img_size: bool = False,
|
717
|
+
dynamic_img_pad: bool = False,
|
718
|
+
drop_rate: float = 0.0,
|
719
|
+
pos_drop_rate: float = 0.0,
|
720
|
+
patch_drop_rate: float = 0.0,
|
721
|
+
proj_drop_rate: float = 0.0,
|
722
|
+
attn_drop_rate: float = 0.0,
|
723
|
+
drop_path_rate: float = 0.0,
|
724
|
+
weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
|
725
|
+
embed_layer: Callable = PatchEmbed,
|
726
|
+
_norm_layer: Optional[LayerType] = None,
|
727
|
+
_act_layer: Optional[LayerType] = None,
|
728
|
+
block_fn: Type[nn.Module] = VisionTransformerBlock,
|
729
|
+
mlp_layer: Type[nn.Module] = Mlp,
|
730
|
+
ignore_head: bool = False,
|
731
|
+
) -> None:
|
732
|
+
"""
|
733
|
+
Args:
|
734
|
+
img_size: Input image size.
|
735
|
+
patch_size: Patch size.
|
736
|
+
in_chans: Number of image input channels.
|
737
|
+
num_classes: Mumber of classes for classification head.
|
738
|
+
global_pool: Type of global pooling for final sequence (default: 'token').
|
739
|
+
embed_dim: Transformer embedding dimension.
|
740
|
+
depth: Depth of transformer.
|
741
|
+
num_heads: Number of attention heads.
|
742
|
+
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
743
|
+
qkv_bias: Enable bias for qkv projections if True.
|
744
|
+
init_values: Layer-scale init values (layer-scale enabled if not None).
|
745
|
+
class_token: Use class token.
|
746
|
+
no_embed_class: Don't include position embeddings for class (or reg) tokens.
|
747
|
+
reg_tokens: Number of register tokens.
|
748
|
+
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
749
|
+
drop_rate: Head dropout rate.
|
750
|
+
pos_drop_rate: Position embedding dropout rate.
|
751
|
+
attn_drop_rate: Attention dropout rate.
|
752
|
+
drop_path_rate: Stochastic depth rate.
|
753
|
+
weight_init: Weight initialization scheme.
|
754
|
+
embed_layer: Patch embedding layer.
|
755
|
+
_norm_layer: Normalization layer.
|
756
|
+
_act_layer: MLP activation layer.
|
757
|
+
block_fn: Transformer block layer.
|
758
|
+
"""
|
759
|
+
super().__init__()
|
760
|
+
assert global_pool in ("", "avg", "token", "map")
|
761
|
+
assert class_token or global_pool != "token"
|
762
|
+
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
|
763
|
+
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
|
764
|
+
# act_layer = get_act_layer(act_layer) or nn.GELU
|
765
|
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
766
|
+
act_layer = nn.GELU
|
767
|
+
|
768
|
+
self.num_classes = num_classes
|
769
|
+
self.global_pool = global_pool
|
770
|
+
self.num_features = self.embed_dim = (
|
771
|
+
embed_dim # num_features for consistency with other models
|
772
|
+
)
|
773
|
+
self.num_prefix_tokens = 1 if class_token else 0
|
774
|
+
self.num_prefix_tokens += reg_tokens
|
775
|
+
self.num_reg_tokens = reg_tokens
|
776
|
+
self.has_class_token = class_token
|
777
|
+
self.no_embed_class = (
|
778
|
+
no_embed_class # don't embed prefix positions (includes reg)
|
779
|
+
)
|
780
|
+
self.dynamic_img_size = dynamic_img_size
|
781
|
+
self.grad_checkpointing = False
|
782
|
+
self.ignore_head = ignore_head
|
783
|
+
|
784
|
+
embed_args = {}
|
785
|
+
if dynamic_img_size:
|
786
|
+
# flatten deferred until after pos embed
|
787
|
+
embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
|
788
|
+
self.patch_embed = embed_layer(
|
789
|
+
img_size=img_size,
|
790
|
+
patch_size=patch_size,
|
791
|
+
in_chans=in_chans,
|
792
|
+
embed_dim=embed_dim,
|
793
|
+
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
794
|
+
dynamic_img_pad=dynamic_img_pad,
|
795
|
+
**embed_args,
|
796
|
+
)
|
797
|
+
num_patches = self.patch_embed.num_patches
|
798
|
+
|
799
|
+
self.cls_token = (
|
800
|
+
nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
801
|
+
)
|
802
|
+
self.reg_token = (
|
803
|
+
nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
|
804
|
+
)
|
805
|
+
embed_len = (
|
806
|
+
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
807
|
+
)
|
808
|
+
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
|
809
|
+
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
810
|
+
if patch_drop_rate > 0:
|
811
|
+
self.patch_drop = PatchDropout(
|
812
|
+
patch_drop_rate,
|
813
|
+
num_prefix_tokens=self.num_prefix_tokens,
|
814
|
+
)
|
815
|
+
else:
|
816
|
+
self.patch_drop = nn.Identity()
|
817
|
+
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
818
|
+
|
819
|
+
dpr = [
|
820
|
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
821
|
+
] # stochastic depth decay rule
|
822
|
+
self.blocks = nn.Sequential(
|
823
|
+
*[
|
824
|
+
block_fn(
|
825
|
+
dim=embed_dim,
|
826
|
+
num_heads=num_heads,
|
827
|
+
mlp_ratio=mlp_ratio,
|
828
|
+
qkv_bias=qkv_bias,
|
829
|
+
qk_norm=qk_norm,
|
830
|
+
init_values=init_values,
|
831
|
+
proj_drop=proj_drop_rate,
|
832
|
+
attn_drop=attn_drop_rate,
|
833
|
+
drop_path=dpr[i],
|
834
|
+
norm_layer=norm_layer,
|
835
|
+
act_layer=act_layer,
|
836
|
+
mlp_layer=mlp_layer,
|
837
|
+
)
|
838
|
+
for i in range(depth)
|
839
|
+
]
|
840
|
+
)
|
841
|
+
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
|
842
|
+
|
843
|
+
# Classifier Head
|
844
|
+
if global_pool == "map":
|
845
|
+
AttentionPoolLatent.init_weights = init_weights
|
846
|
+
self.attn_pool = AttentionPoolLatent(
|
847
|
+
self.embed_dim,
|
848
|
+
num_heads=num_heads,
|
849
|
+
mlp_ratio=mlp_ratio,
|
850
|
+
norm_layer=norm_layer,
|
851
|
+
)
|
852
|
+
else:
|
853
|
+
self.attn_pool = None
|
854
|
+
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
855
|
+
self.head_drop = nn.Dropout(drop_rate)
|
856
|
+
self.head = (
|
857
|
+
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
858
|
+
)
|
859
|
+
|
860
|
+
if weight_init != "skip":
|
861
|
+
self.init_weights(weight_init)
|
862
|
+
|
863
|
+
def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
|
864
|
+
assert mode in ("jax", "jax_nlhb", "moco", "")
|
865
|
+
# head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
|
866
|
+
trunc_normal_(self.pos_embed, std=0.02)
|
867
|
+
if self.cls_token is not None:
|
868
|
+
nn.init.normal_(self.cls_token, std=1e-6)
|
869
|
+
named_apply(init_weights_vit_timm, self)
|
870
|
+
|
871
|
+
@torch.jit.ignore
|
872
|
+
def no_weight_decay(self) -> Set:
|
873
|
+
return {"pos_embed", "cls_token", "dist_token"}
|
874
|
+
|
875
|
+
@torch.jit.ignore
|
876
|
+
def group_matcher(self, coarse: bool = False) -> Dict:
|
877
|
+
return dict(
|
878
|
+
stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
|
879
|
+
blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
|
880
|
+
)
|
881
|
+
|
882
|
+
@torch.jit.ignore
|
883
|
+
def get_classifier(self) -> nn.Module:
|
884
|
+
return self.head
|
885
|
+
|
886
|
+
def reset_classifier(self, num_classes: int, global_pool=None) -> None:
|
887
|
+
self.num_classes = num_classes
|
888
|
+
if global_pool is not None:
|
889
|
+
assert global_pool in ("", "avg", "token", "map")
|
890
|
+
if global_pool == "map" and self.attn_pool is None:
|
891
|
+
assert (
|
892
|
+
False
|
893
|
+
), "Cannot currently add attention pooling in reset_classifier()."
|
894
|
+
elif global_pool != "map " and self.attn_pool is not None:
|
895
|
+
self.attn_pool = None # remove attention pooling
|
896
|
+
self.global_pool = global_pool
|
897
|
+
self.head = (
|
898
|
+
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
899
|
+
)
|
900
|
+
|
901
|
+
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
|
902
|
+
if self.dynamic_img_size:
|
903
|
+
B, H, W, C = x.shape
|
904
|
+
pos_embed = resample_abs_pos_embed(
|
905
|
+
self.pos_embed,
|
906
|
+
[H, W],
|
907
|
+
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
|
908
|
+
)
|
909
|
+
x = x.view(B, -1, C)
|
910
|
+
else:
|
911
|
+
pos_embed = self.pos_embed
|
912
|
+
|
913
|
+
to_cat = []
|
914
|
+
if self.cls_token is not None:
|
915
|
+
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
|
916
|
+
if self.reg_token is not None:
|
917
|
+
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
|
918
|
+
|
919
|
+
if self.no_embed_class:
|
920
|
+
# deit-3, updated JAX (big vision)
|
921
|
+
# position embedding does not overlap with class token, add then concat
|
922
|
+
x = x + pos_embed
|
923
|
+
if to_cat:
|
924
|
+
x = torch.cat(to_cat + [x], dim=1)
|
925
|
+
else:
|
926
|
+
# original timm, JAX, and deit vit impl
|
927
|
+
# pos_embed has entry for class token, concat then add
|
928
|
+
if to_cat:
|
929
|
+
x = torch.cat(to_cat + [x], dim=1)
|
930
|
+
x = x + pos_embed
|
931
|
+
|
932
|
+
return self.pos_drop(x)
|
933
|
+
|
934
|
+
def _intermediate_layers(
|
935
|
+
self,
|
936
|
+
x: torch.Tensor,
|
937
|
+
n: Union[int, Sequence] = 1,
|
938
|
+
) -> List[torch.Tensor]:
|
939
|
+
outputs, num_blocks = [], len(self.blocks)
|
940
|
+
take_indices = set(
|
941
|
+
range(num_blocks - n, num_blocks) if isinstance(n, int) else n
|
942
|
+
)
|
943
|
+
|
944
|
+
# forward pass
|
945
|
+
x = self.patch_embed(x)
|
946
|
+
x = self._pos_embed(x)
|
947
|
+
x = self.patch_drop(x)
|
948
|
+
x = self.norm_pre(x)
|
949
|
+
for i, blk in enumerate(self.blocks):
|
950
|
+
x = blk(x)
|
951
|
+
if i in take_indices:
|
952
|
+
outputs.append(x)
|
953
|
+
|
954
|
+
return outputs
|
955
|
+
|
956
|
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
957
|
+
x = self.patch_embed(x)
|
958
|
+
x = self._pos_embed(x)
|
959
|
+
x = self.patch_drop(x)
|
960
|
+
x = self.norm_pre(x)
|
961
|
+
x = self.blocks(x)
|
962
|
+
x = self.norm(x)
|
963
|
+
return x
|
964
|
+
|
965
|
+
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
|
966
|
+
if self.attn_pool is not None:
|
967
|
+
x = self.attn_pool(x)
|
968
|
+
elif self.global_pool == "avg":
|
969
|
+
x = x[:, self.num_prefix_tokens :].mean(dim=1)
|
970
|
+
elif self.global_pool:
|
971
|
+
x = x[:, 0] # class token
|
972
|
+
x = self.fc_norm(x)
|
973
|
+
x = self.head_drop(x)
|
974
|
+
return x if pre_logits else self.head(x)
|
975
|
+
|
976
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
977
|
+
x = self.forward_features(x)
|
978
|
+
if not self.ignore_head:
|
979
|
+
x = self.forward_head(x)
|
980
|
+
return x
|
981
|
+
|
982
|
+
|
983
|
+
def model_name_to_cls(cls_name):
|
984
|
+
if "MlpProjector" in cls_name:
|
985
|
+
cls = MlpProjector
|
986
|
+
|
987
|
+
elif "CLIPVisionTower" in cls_name:
|
988
|
+
cls = CLIPVisionTower
|
989
|
+
|
990
|
+
elif "VQ" in cls_name:
|
991
|
+
|
992
|
+
cls = VQ_models[cls_name]
|
993
|
+
elif "vision_head" in cls_name:
|
994
|
+
cls = vision_head
|
995
|
+
else:
|
996
|
+
raise ValueError(f"class_name {cls_name} is invalid.")
|
997
|
+
|
998
|
+
return cls
|
999
|
+
|
1000
|
+
|
1001
|
+
class vision_head(torch.nn.Module):
|
1002
|
+
def __init__(self, params):
|
1003
|
+
super().__init__()
|
1004
|
+
self.output_mlp_projector = torch.nn.Linear(
|
1005
|
+
params["n_embed"], params["image_token_embed"]
|
1006
|
+
)
|
1007
|
+
self.vision_activation = torch.nn.GELU()
|
1008
|
+
self.vision_head = torch.nn.Linear(
|
1009
|
+
params["image_token_embed"], params["image_token_size"]
|
1010
|
+
)
|
1011
|
+
|
1012
|
+
def forward(self, x):
|
1013
|
+
x = self.output_mlp_projector(x)
|
1014
|
+
x = self.vision_activation(x)
|
1015
|
+
x = self.vision_head(x)
|
1016
|
+
return x
|
1017
|
+
|
1018
|
+
|
1019
|
+
SigLIP_MODEL_CONFIG = {
|
1020
|
+
"siglip_so400m_patch14_384": {
|
1021
|
+
"image_size": 336,
|
1022
|
+
"patch_size": 14,
|
1023
|
+
"width": 1152,
|
1024
|
+
"layers": 27,
|
1025
|
+
"heads": 16,
|
1026
|
+
"mlp_ratio": 3.7362,
|
1027
|
+
"global_pool": "map",
|
1028
|
+
"use_checkpoint": False,
|
1029
|
+
},
|
1030
|
+
"siglip_so400m_patch14_224": {
|
1031
|
+
"image_size": 224,
|
1032
|
+
"patch_size": 14,
|
1033
|
+
"width": 1152,
|
1034
|
+
"layers": 27,
|
1035
|
+
"heads": 16,
|
1036
|
+
"mlp_ratio": 3.7362,
|
1037
|
+
"global_pool": "map",
|
1038
|
+
"use_checkpoint": False,
|
1039
|
+
},
|
1040
|
+
"siglip_large_patch16_384": {
|
1041
|
+
"image_size": 384,
|
1042
|
+
"patch_size": 16,
|
1043
|
+
"width": 1024,
|
1044
|
+
"layers": 24,
|
1045
|
+
"heads": 16,
|
1046
|
+
"mlp_ratio": 4,
|
1047
|
+
"global_pool": "map",
|
1048
|
+
"use_checkpoint": False,
|
1049
|
+
},
|
1050
|
+
}
|
1051
|
+
|
1052
|
+
|
1053
|
+
def create_siglip_vit(
|
1054
|
+
model_name: str = "siglip_so400m_patch14_384",
|
1055
|
+
image_size: int = 384,
|
1056
|
+
select_layer: int = -1,
|
1057
|
+
ckpt_path: str = "",
|
1058
|
+
**kwargs,
|
1059
|
+
):
|
1060
|
+
assert (
|
1061
|
+
model_name in SigLIP_MODEL_CONFIG.keys()
|
1062
|
+
), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
|
1063
|
+
|
1064
|
+
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
|
1065
|
+
|
1066
|
+
if select_layer <= 0:
|
1067
|
+
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
|
1068
|
+
else:
|
1069
|
+
layers = min(vision_cfg.layers, select_layer)
|
1070
|
+
|
1071
|
+
model = VisionTransformer(
|
1072
|
+
img_size=image_size,
|
1073
|
+
patch_size=vision_cfg.patch_size,
|
1074
|
+
embed_dim=vision_cfg.width,
|
1075
|
+
depth=layers,
|
1076
|
+
num_heads=vision_cfg.heads,
|
1077
|
+
mlp_ratio=vision_cfg.mlp_ratio,
|
1078
|
+
class_token=vision_cfg.class_token,
|
1079
|
+
global_pool=vision_cfg.global_pool,
|
1080
|
+
ignore_head=kwargs.get("ignore_head", True),
|
1081
|
+
weight_init=kwargs.get("weight_init", "skip"),
|
1082
|
+
num_classes=0,
|
1083
|
+
)
|
1084
|
+
|
1085
|
+
if ckpt_path:
|
1086
|
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
1087
|
+
|
1088
|
+
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
1089
|
+
print(
|
1090
|
+
f"SigLIP-ViT restores from {ckpt_path},\n"
|
1091
|
+
f"\tincompatible_keys:', {incompatible_keys}."
|
1092
|
+
)
|
1093
|
+
|
1094
|
+
return model
|
1095
|
+
|
1096
|
+
|
1097
|
+
class Normalize(torch.nn.Module):
|
1098
|
+
"""Normalize a tensor image with mean and standard deviation.
|
1099
|
+
This transform does not support PIL Image.
|
1100
|
+
Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
|
1101
|
+
channels, this transform will normalize each channel of the input
|
1102
|
+
``torch.*Tensor`` i.e.,
|
1103
|
+
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
|
1104
|
+
|
1105
|
+
.. note::
|
1106
|
+
This transform acts out of place, i.e., it does not mutate the input tensor.
|
1107
|
+
|
1108
|
+
Args:
|
1109
|
+
mean (sequence): Sequence of means for each channel.
|
1110
|
+
std (sequence): Sequence of standard deviations for each channel.
|
1111
|
+
inplace(bool,optional): Bool to make this operation in-place.
|
1112
|
+
|
1113
|
+
"""
|
1114
|
+
|
1115
|
+
def __init__(self, mean, std, inplace=False):
|
1116
|
+
super().__init__()
|
1117
|
+
# _log_api_usage_once(self)
|
1118
|
+
self.mean = mean
|
1119
|
+
self.std = std
|
1120
|
+
self.inplace = inplace
|
1121
|
+
|
1122
|
+
def forward(self, tensor: Tensor) -> Tensor:
|
1123
|
+
"""
|
1124
|
+
Args:
|
1125
|
+
tensor (Tensor): Tensor image to be normalized.
|
1126
|
+
|
1127
|
+
Returns:
|
1128
|
+
Tensor: Normalized Tensor image.
|
1129
|
+
"""
|
1130
|
+
return F.normalize(tensor, self.mean, self.std, self.inplace)
|
1131
|
+
|
1132
|
+
def __repr__(self) -> str:
|
1133
|
+
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
|
1134
|
+
|
1135
|
+
|
1136
|
+
class CLIPVisionTower(nn.Module):
|
1137
|
+
def __init__(
|
1138
|
+
self,
|
1139
|
+
model_name: str = "siglip_large_patch16_384",
|
1140
|
+
image_size: Union[Tuple[int, int], int] = 336,
|
1141
|
+
select_feature: str = "patch",
|
1142
|
+
select_layer: int = -2,
|
1143
|
+
select_layers: list = None,
|
1144
|
+
ckpt_path: str = "",
|
1145
|
+
pixel_mean: Optional[List[float]] = None,
|
1146
|
+
pixel_std: Optional[List[float]] = None,
|
1147
|
+
**kwargs,
|
1148
|
+
):
|
1149
|
+
super().__init__()
|
1150
|
+
|
1151
|
+
self.model_name = model_name
|
1152
|
+
self.select_feature = select_feature
|
1153
|
+
self.select_layer = select_layer
|
1154
|
+
self.select_layers = select_layers
|
1155
|
+
|
1156
|
+
vision_tower_params = {
|
1157
|
+
"model_name": model_name,
|
1158
|
+
"image_size": image_size,
|
1159
|
+
"ckpt_path": ckpt_path,
|
1160
|
+
"select_layer": select_layer,
|
1161
|
+
}
|
1162
|
+
vision_tower_params.update(kwargs)
|
1163
|
+
self.vision_tower, self.forward_kwargs = self.build_vision_tower(
|
1164
|
+
vision_tower_params
|
1165
|
+
)
|
1166
|
+
|
1167
|
+
if pixel_mean is not None and pixel_std is not None:
|
1168
|
+
image_norm = Normalize(mean=pixel_mean, std=pixel_std)
|
1169
|
+
else:
|
1170
|
+
image_norm = None
|
1171
|
+
|
1172
|
+
self.image_norm = image_norm
|
1173
|
+
|
1174
|
+
@property
|
1175
|
+
def device(self) -> torch.device:
|
1176
|
+
return next(self.vision_tower.parameters()).device
|
1177
|
+
|
1178
|
+
@property
|
1179
|
+
def dtype(self):
|
1180
|
+
return next(self.vision_tower.parameters()).dtype
|
1181
|
+
|
1182
|
+
def build_vision_tower(self, vision_tower_params):
|
1183
|
+
if self.model_name.startswith("siglip"):
|
1184
|
+
self.select_feature = "same"
|
1185
|
+
vision_tower = create_siglip_vit(**vision_tower_params)
|
1186
|
+
forward_kwargs = dict()
|
1187
|
+
|
1188
|
+
elif self.model_name.startswith("sam"):
|
1189
|
+
# vision_tower = create_sam_vit(**vision_tower_params)
|
1190
|
+
forward_kwargs = dict()
|
1191
|
+
|
1192
|
+
else: # huggingface
|
1193
|
+
from transformers import CLIPVisionModel
|
1194
|
+
|
1195
|
+
vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
|
1196
|
+
forward_kwargs = dict(output_hidden_states=True)
|
1197
|
+
|
1198
|
+
return vision_tower, forward_kwargs
|
1199
|
+
|
1200
|
+
def feature_select(self, image_forward_outs):
|
1201
|
+
if isinstance(image_forward_outs, torch.Tensor):
|
1202
|
+
# the output has been the self.select_layer"s features
|
1203
|
+
image_features = image_forward_outs
|
1204
|
+
else:
|
1205
|
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
1206
|
+
|
1207
|
+
if self.select_feature == "patch":
|
1208
|
+
# if the output has cls_token
|
1209
|
+
image_features = image_features[:, 1:]
|
1210
|
+
elif self.select_feature == "cls_patch":
|
1211
|
+
image_features = image_features
|
1212
|
+
elif self.select_feature == "same":
|
1213
|
+
image_features = image_features
|
1214
|
+
|
1215
|
+
else:
|
1216
|
+
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
1217
|
+
return image_features
|
1218
|
+
|
1219
|
+
def forward(self, images):
|
1220
|
+
"""
|
1221
|
+
|
1222
|
+
Args:
|
1223
|
+
images (torch.Tensor): [b, 3, H, W]
|
1224
|
+
|
1225
|
+
Returns:
|
1226
|
+
image_features (torch.Tensor): [b, n_patch, d]
|
1227
|
+
"""
|
1228
|
+
|
1229
|
+
if self.image_norm is not None:
|
1230
|
+
images = self.image_norm(images)
|
1231
|
+
|
1232
|
+
image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
|
1233
|
+
image_features = self.feature_select(image_forward_outs)
|
1234
|
+
return image_features
|
1235
|
+
|
1236
|
+
|
+class MlpProjector(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.cfg = cfg
+
+        if cfg["projector_type"] == "identity":
+            modules = nn.Identity()
+
+        elif cfg["projector_type"] == "linear":
+            modules = nn.Linear(cfg["input_dim"], cfg["n_embed"])
+
+        elif cfg["projector_type"] == "mlp_gelu":
+            mlp_depth = cfg.get("depth", 1)
+            modules = [nn.Linear(cfg["input_dim"], cfg["n_embed"])]
+            for _ in range(1, mlp_depth):
+                modules.append(nn.GELU())
+                modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
+            modules = nn.Sequential(*modules)
+
+        elif cfg["projector_type"] == "low_high_hybrid_split_mlp_gelu":
+            mlp_depth = cfg.get("depth", 1)
+            self.high_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
+            self.low_up_proj = nn.Linear(cfg["input_dim"], cfg["n_embed"] // 2)
+
+            modules = []
+            for _ in range(1, mlp_depth):
+                modules.append(nn.GELU())
+                modules.append(nn.Linear(cfg["n_embed"], cfg["n_embed"]))
+            modules = nn.Sequential(*modules)
+
+        else:
+            raise ValueError(f"Unknown projector type: {cfg['projector_type']}")
+
+        self.layers = modules
+
+    def forward(
+        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
+    ):
+        """
+
+        Args:
+            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of torch.Tensor,
+                then it comes from the hybrid vision encoder, and x = (high_res_x, low_res_x);
+                otherwise it is the feature from the single vision encoder.
+
+        Returns:
+            x (torch.Tensor): [b, s, c]
+        """
+
+        if isinstance(x_or_tuple, tuple):
+            # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
+            high_x, low_x = x_or_tuple
+            high_x = self.high_up_proj(high_x)
+            low_x = self.low_up_proj(low_x)
+            x = torch.concat([high_x, low_x], dim=-1)
+        else:
+            x = x_or_tuple
+
+        return self.layers(x)
+
+
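Editor's note (not part of the package diff): a minimal usage sketch for the MlpProjector defined above, assuming the class is in scope; the dimensions and config values are made up for illustration.

# Editor's sketch: project 1024-d vision features into a 2048-d embedding space
# with a 2-layer GELU MLP, using the MlpProjector defined above.
import torch

cfg = {"projector_type": "mlp_gelu", "input_dim": 1024, "n_embed": 2048, "depth": 2}
projector = MlpProjector(cfg)

vision_feats = torch.randn(1, 576, 1024)   # [batch, n_patches, input_dim]
lm_feats = projector(vision_feats)         # [batch, n_patches, n_embed]
assert lm_feats.shape == (1, 576, 2048)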
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: float = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+# use torch.scaled_dot_product_attention where possible
+_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
+if "TIMM_FUSED_ATTN" in os.environ:
+    _USE_FUSED_ATTN = int(os.environ["TIMM_FUSED_ATTN"])
+else:
+    _USE_FUSED_ATTN = (
+        1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
+    )
+
+# Set to True if exporting a model with Same padding via ONNX
+_EXPORTABLE = False
+
+
+def use_fused_attn(experimental: bool = False) -> bool:
+    # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
+    if not _HAS_FUSED_ATTN or _EXPORTABLE:
+        return False
+    if experimental:
+        return _USE_FUSED_ATTN > 1
+    return _USE_FUSED_ATTN > 0
+
+
+class AttentionPoolLatent(nn.Module):
+    """Attention pooling w/ latent query"""
+
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int = None,
+        embed_dim: int = None,
+        num_heads: int = 8,
+        feat_size: Optional[int] = None,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        qk_norm: bool = False,
+        latent_len: int = 1,
+        latent_dim: int = None,
+        pos_embed: str = "",
+        pool_type: str = "token",
+        norm_layer: Optional[nn.Module] = None,
+        drop: float = 0.0,
+    ):
+        super().__init__()
+        embed_dim = embed_dim or in_features
+        out_features = out_features or in_features
+        assert embed_dim % num_heads == 0
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.feat_size = feat_size
+        self.scale = self.head_dim**-0.5
+        self.pool = pool_type
+        self.fused_attn = use_fused_attn()
+
+        if pos_embed == "abs":
+            assert feat_size is not None
+            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
+        else:
+            self.pos_embed = None
+
+        self.latent_dim = latent_dim or embed_dim
+        self.latent_len = latent_len
+        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
+
+        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.proj = nn.Linear(embed_dim, embed_dim)
+        self.proj_drop = nn.Dropout(drop)
+
+        self.norm = (
+            norm_layer(out_features) if norm_layer is not None else nn.Identity()
+        )
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
+
+        self.init_weights()
+
+    def init_weights(self):
+        if self.pos_embed is not None:
+            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
+        trunc_normal_tf_(self.latent, std=self.latent_dim**-0.5)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        if self.pos_embed is not None:
+            # FIXME interpolate
+            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+        q_latent = self.latent.expand(B, -1, -1)
+        q = (
+            self.q(q_latent)
+            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+        )
+
+        kv = (
+            self.kv(x)
+            .reshape(B, N, 2, self.num_heads, self.head_dim)
+            .permute(2, 0, 3, 1, 4)
+        )
+        k, v = kv.unbind(0)
+
+        q, k = self.q_norm(q), self.k_norm(k)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        x = x + self.mlp(self.norm(x))
+
+        # optional pool if latent seq_len > 1 and pooled output is desired
+        if self.pool == "token":
+            x = x[:, 0]
+        elif self.pool == "avg":
+            x = x.mean(1)
+        return x
+
+
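Editor's note (not part of the package diff): AttentionPoolLatent.forward branches between F.scaled_dot_product_attention and an explicit softmax(q·kᵀ/√d)·v path. A small standalone sketch showing the two branches agree numerically, with made-up shapes:

# Editor's sketch: fused and manual attention branches produce the same result.
import torch
import torch.nn.functional as F

B, H, L, S, D = 2, 8, 1, 16, 64          # one latent query attending over 16 tokens
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

fused = F.scaled_dot_product_attention(q, k, v)

attn = (q * D**-0.5) @ k.transpose(-2, -1)
manual = attn.softmax(dim=-1) @ v

assert torch.allclose(fused, manual, atol=1e-5)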
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        in_channels=3,
+        ch=128,
+        ch_mult=(1, 1, 2, 2, 4),
+        num_res_blocks=2,
+        norm_type="group",
+        dropout=0.0,
+        resamp_with_conv=True,
+        z_channels=256,
+    ):
+        super().__init__()
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
+
+        # downsampling
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.conv_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            conv_block = nn.Module()
+            # res & attn
+            res_block = nn.ModuleList()
+            attn_block = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks):
+                res_block.append(
+                    ResnetBlock(
+                        block_in, block_out, dropout=dropout, norm_type=norm_type
+                    )
+                )
+                block_in = block_out
+            if i_level == self.num_resolutions - 1:
+                attn_block.append(AttnBlock(block_in, norm_type))
+            conv_block.res = res_block
+            conv_block.attn = attn_block
+            # downsample
+            if i_level != self.num_resolutions - 1:
+                conv_block.downsample = Downsample(block_in, resamp_with_conv)
+            self.conv_blocks.append(conv_block)
+
+        # middle
+        self.mid = nn.ModuleList()
+        self.mid.append(
+            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+        )
+        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
+        self.mid.append(
+            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+        )
+
+        # end
+        self.norm_out = Normalize(block_in, norm_type)
+        self.conv_out = nn.Conv2d(
+            block_in, z_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x):
+        h = self.conv_in(x)
+        # downsampling
+        for i_level, block in enumerate(self.conv_blocks):
+            for i_block in range(self.num_res_blocks):
+                h = block.res[i_block](h)
+                if len(block.attn) > 0:
+                    h = block.attn[i_block](h)
+            if i_level != self.num_resolutions - 1:
+                h = block.downsample(h)
+
+        # middle
+        for mid_block in self.mid:
+            h = mid_block(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        z_channels=256,
+        ch=128,
+        ch_mult=(1, 1, 2, 2, 4),
+        num_res_blocks=2,
+        norm_type="group",
+        dropout=0.0,
+        resamp_with_conv=True,
+        out_channels=3,
+    ):
+        super().__init__()
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        # z to block_in
+        self.conv_in = nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = nn.ModuleList()
+        self.mid.append(
+            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+        )
+        self.mid.append(AttnBlock(block_in, norm_type=norm_type))
+        self.mid.append(
+            ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type)
+        )
+
+        # upsampling
+        self.conv_blocks = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            conv_block = nn.Module()
+            # res & attn
+            res_block = nn.ModuleList()
+            attn_block = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks + 1):
+                res_block.append(
+                    ResnetBlock(
+                        block_in, block_out, dropout=dropout, norm_type=norm_type
+                    )
+                )
+                block_in = block_out
+            if i_level == self.num_resolutions - 1:
+                attn_block.append(AttnBlock(block_in, norm_type))
+            conv_block.res = res_block
+            conv_block.attn = attn_block
+            # upsample
+            if i_level != 0:
+                conv_block.upsample = Upsample(block_in, resamp_with_conv)
+            self.conv_blocks.append(conv_block)
+
+        # end
+        self.norm_out = Normalize(block_in, norm_type)
+        self.conv_out = nn.Conv2d(
+            block_in, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    @property
+    def last_layer(self):
+        return self.conv_out.weight
+
+    def forward(self, z):
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        for mid_block in self.mid:
+            h = mid_block(h)
+
+        # upsampling
+        for i_level, block in enumerate(self.conv_blocks):
+            for i_block in range(self.num_res_blocks + 1):
+                h = block.res[i_block](h)
+                if len(block.attn) > 0:
+                    h = block.attn[i_block](h)
+            if i_level != self.num_resolutions - 1:
+                h = block.upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class VectorQuantizer(nn.Module):
+    def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
+        super().__init__()
+        self.n_e = n_e
+        self.e_dim = e_dim
+        self.beta = beta
+        self.entropy_loss_ratio = entropy_loss_ratio
+        self.l2_norm = l2_norm
+        self.show_usage = show_usage
+
+        self.embedding = nn.Embedding(self.n_e, self.e_dim)
+        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+        if self.l2_norm:
+            self.embedding.weight.data = F.normalize(
+                self.embedding.weight.data, p=2, dim=-1
+            )
+        if self.show_usage:
+            # self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
+            self.codebook_used = nn.Parameter(torch.zeros(65536))
+
+    def forward(self, z):
+        # reshape z -> (batch, height, width, channel) and flatten
+        z = torch.einsum("b c h w -> b h w c", z).contiguous()
+        z_flattened = z.view(-1, self.e_dim)
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+        if self.l2_norm:
+            z = F.normalize(z, p=2, dim=-1)
+            z_flattened = F.normalize(z_flattened, p=2, dim=-1)
+            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+        else:
+            embedding = self.embedding.weight
+
+        d = (
+            torch.sum(z_flattened**2, dim=1, keepdim=True)
+            + torch.sum(embedding**2, dim=1)
+            - 2
+            * torch.einsum(
+                "bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding)
+            )
+        )
+
+        min_encoding_indices = torch.argmin(d, dim=1)
+        z_q = embedding[min_encoding_indices].view(z.shape)
+        perplexity = None
+        min_encodings = None
+        vq_loss = None
+        commit_loss = None
+        entropy_loss = None
+
+        # compute loss for embedding
+        if self.training:
+            vq_loss = torch.mean((z_q - z.detach()) ** 2)
+            commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
+            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
+
+        # preserve gradients
+        z_q = z + (z_q - z).detach()
+
+        # reshape back to match original input shape
+        z_q = torch.einsum("b h w c -> b c h w", z_q)
+
+        return (
+            z_q,
+            (vq_loss, commit_loss, entropy_loss),
+            (perplexity, min_encodings, min_encoding_indices),
+        )
+
+    def get_codebook_entry(self, indices, shape=None, channel_first=True):
+        # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
+        if self.l2_norm:
+            embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+        else:
+            embedding = self.embedding.weight
+        z_q = embedding[indices]  # (b*h*w, c)
+
+        if shape is not None:
+            if channel_first:
+                z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
+                # reshape back to match original input shape
+                z_q = z_q.permute(0, 3, 1, 2).contiguous()
+            else:
+                z_q = z_q.view(shape)
+        return z_q
+
+
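Editor's note (not part of the package diff): the quantizer above is a nearest-neighbour codebook lookup plus a straight-through estimator. A condensed standalone sketch of that core step, with made-up sizes:

# Editor's sketch of the nearest-codebook lookup and straight-through trick used in
# VectorQuantizer.forward above.
import torch

codebook = torch.randn(16384, 8)              # [n_e, e_dim]
z = torch.randn(4 * 24 * 24, 8)               # flattened encoder output, [b*h*w, e_dim]

# squared distances ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e
d = (z**2).sum(1, keepdim=True) + (codebook**2).sum(1) - 2 * z @ codebook.t()
indices = d.argmin(dim=1)
z_q = codebook[indices]                       # quantized vectors

# straight-through: forward pass uses z_q, gradients flow back to z
z_q_st = z + (z_q - z).detach()
assert z_q_st.shape == z.shape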
+class ResnetBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout=0.0,
+        norm_type="group",
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels, norm_type)
+        self.conv1 = nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.norm2 = Normalize(out_channels, norm_type)
+        self.dropout = nn.Dropout(dropout)
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = nn.Conv2d(
+                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
+                )
+            else:
+                self.nin_shortcut = nn.Conv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
+                )
+
+    def forward(self, x):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+        return x + h
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels, norm_type="group"):
+        super().__init__()
+        self.norm = Normalize(in_channels, norm_type)
+        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h * w)
+        q = q.permute(0, 2, 1)  # b,hw,c
+        k = k.reshape(b, c, h * w)  # b,c,hw
+        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = F.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h * w)
+        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b, c, h, w)
+
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
+
+def nonlinearity(x):
+    # swish
+    return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, norm_type="group"):
+    assert norm_type in ["group", "batch"]
+    if norm_type == "group":
+        return nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
+    elif norm_type == "batch":
+        return nn.SyncBatchNorm(in_channels)
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def forward(self, x):
+        if x.dtype != torch.float32:
+            x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
+                torch.bfloat16
+            )
+        else:
+            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=2, padding=0
+            )
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0, 1, 0, 1)
+            x = F.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = F.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
+    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
+    flat_affinity /= temperature
+    probs = F.softmax(flat_affinity, dim=-1)
+    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
+    if loss_type == "softmax":
+        target_probs = probs
+    else:
+        raise ValueError("Entropy loss {} not supported".format(loss_type))
+    avg_probs = torch.mean(target_probs, dim=0)
+    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
+    sample_entropy = -torch.mean(torch.sum(target_probs * log_probs, dim=-1))
+    loss = sample_entropy - avg_entropy
+    return loss
+
+
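Editor's note (not part of the package diff): compute_entropy_loss above rewards confident per-sample code assignments (low sample entropy) whose batch average still spreads over many codes (high average entropy). A toy numeric sketch of that intuition (it omits the temperature used in the diff):

# Editor's sketch: confident but diverse assignments give a negative (good) loss.
import torch
import torch.nn.functional as F

logits = torch.tensor([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]])   # two samples, three codes
probs = F.softmax(logits, dim=-1)
sample_entropy = -(probs * probs.clamp_min(1e-5).log()).sum(-1).mean()  # near 0
avg_probs = probs.mean(0)
avg_entropy = -(avg_probs * avg_probs.clamp_min(1e-5).log()).sum()      # near log 2
loss = sample_entropy - avg_entropy                                     # negative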
+class VQModel(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.config = config
+        self.encoder = Encoder(
+            ch_mult=config.encoder_ch_mult,
+            z_channels=config.z_channels,
+            dropout=config.dropout_p,
+        )
+        self.decoder = Decoder(
+            ch_mult=config.decoder_ch_mult,
+            z_channels=config.z_channels,
+            dropout=config.dropout_p,
+        )
+
+        self.quantize = VectorQuantizer(
+            config.codebook_size,
+            config.codebook_embed_dim,
+            config.commit_loss_beta,
+            config.entropy_loss_ratio,
+            config.codebook_l2_norm,
+            config.codebook_show_usage,
+        )
+        self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
+        self.post_quant_conv = nn.Conv2d(
+            config.codebook_embed_dim, config.z_channels, 1
+        )
+
+    def encode(self, x):
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+        quant, emb_loss, info = self.quantize(h)
+        return quant, emb_loss, info
+
+    def decode(self, quant):
+        quant = self.post_quant_conv(quant)
+        dec = self.decoder(quant)
+        return dec
+
+    def decode_code(self, code_b, shape=None, channel_first=True):
+        quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
+        dec = self.decode(quant_b)
+        return dec
+
+    def forward(self, input):
+        quant, diff, _ = self.encode(input)
+        dec = self.decode(quant)
+        return dec, diff
+
+
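Editor's note (not part of the package diff): a shape-level sketch of the encode → quantize → decode roundtrip through the VQModel defined above, assuming the class is in scope. The config values are illustrative (SimpleNamespace stands in for the real ModelArgs) and are not the ones shipped in the package.

# Editor's sketch: roundtrip shapes through VQModel with made-up config values.
from types import SimpleNamespace
import torch

cfg = SimpleNamespace(
    encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4],
    z_channels=256, dropout_p=0.0,
    codebook_size=16384, codebook_embed_dim=8,
    commit_loss_beta=0.25, entropy_loss_ratio=0.0,
    codebook_l2_norm=True, codebook_show_usage=True,
)
vq = VQModel(cfg).eval()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    quant, _, _ = vq.encode(x)   # 4 downsamples -> [1, 8, 2, 2]
    recon = vq.decode(quant)     # 4 upsamples  -> [1, 3, 32, 32]
assert quant.shape == (1, 8, 2, 2) and recon.shape == x.shape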
+class MultiModalityPreTrainedModel(PreTrainedModel):
+    config_class = MultiModalityConfig
+    base_model_prefix = "multi_modality"
+    _no_split_modules = []
+    _skip_keys_device_placement = "past_key_values"
+
+
+# Copied and adapted from:
+# https://github.com/deepseek-ai/Janus/tree/main/janus/models/modeling_vlm.py
+class MultiModalityCausalLM(MultiModalityPreTrainedModel):
+
+    def __init__(
+        self,
+        config: MultiModalityConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__(config)
+
+        vision_config = config.vision_config
+        vision_cls = model_name_to_cls(vision_config.cls)
+        self.vision_model = vision_cls(**vision_config.params)
+
+        aligner_config = config.aligner_config
+        aligner_cls = model_name_to_cls(aligner_config.cls)
+        self.aligner = aligner_cls(aligner_config.params)
+
+        gen_vision_config = config.gen_vision_config
+        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
+        self.gen_vision_model = gen_vision_cls()
+
+        gen_aligner_config = config.gen_aligner_config
+        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
+        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
+
+        gen_head_config = config.gen_head_config
+        gen_head_cls = model_name_to_cls(gen_head_config.cls)
+        self.gen_head = gen_head_cls(gen_head_config.params)
+
+        self.gen_embed = torch.nn.Embedding(
+            gen_vision_config.params["image_token_size"],
+            gen_vision_config.params["n_embed"],
+        )
+
+        language_config = config.language_config
+        self.language_model = LlamaForCausalLM(
+            language_config, quant_config=quant_config
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    def prepare_images_seq_mask(
+        self, input_ids: torch.Tensor, image_inputs: ImageInputs
+    ) -> Optional[torch.LongTensor]:
+        images_seq_mask = torch.isin(
+            input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
+        )
+        if images_seq_mask.sum() == 0:
+            # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
+            return None
+        else:
+            return images_seq_mask
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+
+        inputs_embeds = None
+        if (
+            forward_batch.image_inputs is not None
+            and len(forward_batch.image_inputs) != 0
+            and forward_batch.image_inputs[0] is not None
+        ):
+
+            image_inputs = forward_batch.image_inputs[0]
+
+            images_seq_mask = self.prepare_images_seq_mask(
+                input_ids=input_ids, image_inputs=image_inputs
+            )
+
+            if images_seq_mask is not None:
+                input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+                inputs_embeds = self.prepare_inputs_embeds(
+                    input_ids=input_ids,
+                    pixel_values=image_inputs.pixel_values,
+                    images_seq_mask=images_seq_mask,
+                    images_emb_mask=image_inputs.images_emb_mask,
+                )
+                input_ids = None
+
+        if input_ids is not None:
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+        return self.language_model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=inputs_embeds,
+            get_embedding=False,
+        )
+
+    def prepare_inputs_embeds(
+        self,
+        input_ids: torch.LongTensor,
+        pixel_values: torch.FloatTensor,
+        images_seq_mask: torch.LongTensor,
+        images_emb_mask: torch.BoolTensor,
+        **_kwargs,
+    ):
+        """
+
+        Args:
+            input_ids (torch.LongTensor): [b, T]
+            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
+            images_seq_mask (torch.BoolTensor): [b, T]
+            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
+
+            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
+
+        Returns:
+            input_embeds (torch.Tensor): [b, T, D]
+        """
+
+        bs, n = pixel_values.shape[0:2]
+        pixel_values = pixel_values.to(
+            device=self.vision_model.device, dtype=self.vision_model.dtype
+        )
+        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+
+        # [b x n, T2, D]
+        images_embeds = self.aligner(self.vision_model(images))
+
+        # [b x n, T2, D] -> [b, n x T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+        # [b, n, T2] -> [b, n x T2]
+        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
+
+        # [b, T, D]
+        # ignore the image embeddings
+        input_ids[input_ids < 0] = 0
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # replace with the image embeddings
+        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
+
+        return inputs_embeds
+
+    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
+        return self.gen_aligner(self.gen_embed(image_ids))
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        im_start_id = image_inputs.im_start_id
+        im_end_id = image_inputs.im_end_id
+        media_token_pairs = [(im_start_id, im_end_id)]
+
+        helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+
+        return helper.pad_input_tokens(input_ids, image_inputs)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            # skip generation sub model
+            if "gen" in name:
+                continue
+
+            # adapt to VisionAttention
+            name = name.replace(r"self_attn.out_proj", r"self_attn.proj")
+            if "vision_model.vision_tower" in name:
+                name = name.replace("attn.qkv", "attn.qkv_proj")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # replace the name and load with customized loader
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", None)
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+AutoModel.register(config_class=MultiModalityConfig, model_class=MultiModalityCausalLM)
+EntryClass = [MultiModalityCausalLM]
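Editor's note (not part of the package diff): the key step in prepare_inputs_embeds above is the masked scatter of vision-encoder embeddings into the token-embedding sequence at the positions flagged by images_seq_mask. A minimal standalone sketch of that mechanism with toy sizes:

# Editor's sketch: positions flagged by images_seq_mask are overwritten with image embeddings.
import torch

D = 4
inputs_embeds = torch.zeros(1, 6, D)             # [b, T, D] text embeddings (zeros for clarity)
images_embeds = torch.arange(3 * D, dtype=torch.float32).view(1, 3, D)  # [b, n*T2, D]

images_seq_mask = torch.tensor([[False, True, True, True, False, False]])  # [b, T]
images_emb_mask = torch.ones(1, 3, dtype=torch.bool)                       # [b, n*T2]

inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
assert torch.equal(inputs_embeds[0, 1:4], images_embeds[0])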