llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
@@ -1,226 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import os
-from collections.abc import Callable
-
-import torch
-from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
-from torch import Tensor, nn
-from torch.nn import functional as F
-
-from llama_stack.log import get_logger
-
-from ...datatypes import QuantizationMode
-from ..model import Transformer, TransformerBlock
-from ..moe import MoE
-
-log = get_logger(name=__name__, category="models::llama")
-
-
-def swiglu_wrapper_no_reduce(
-    self,
-    x: Tensor,
-):
-    from ...quantize_impls import ffn_swiglu
-
-    return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
-
-
-def experts_batched_swiglu_wrapper(
-    self,
-    x: Tensor,  # (e, g, D)
-    w1: Tensor,  # (e, D, F)
-    w3: Tensor,  # (e, D, F)
-    w2: Tensor,  # (e, F, D)
-) -> torch.Tensor:
-    from ...quantize_impls import bmm_nt
-
-    middle_out_egF = F.silu(bmm_nt(x, w1)) * bmm_nt(x, w3)  # noqa: N806
-    return bmm_nt(middle_out_egF, w2)
-
-
-def convert_to_quantized_model(
-    model: Transformer,
-    checkpoint_dir: str,
-    quantization_mode: str | None = None,
-    fp8_activation_scale_ub: float | None = 1200.0,
-    use_rich_progress: bool = True,
-) -> Transformer:
-    from ...quantize_impls import (
-        Fp8ScaledWeights,
-        Int4ScaledWeights,
-        load_fp8,
-        load_int4,
-        quantize_fp8,
-        quantize_int4,
-    )
-
-    rank = get_model_parallel_rank()
-
-    def should_quantize_block(block: nn.Module) -> bool:
-        if not isinstance(block, TransformerBlock):
-            return False
-
-        is_moe = isinstance(block.feed_forward, MoE)
-        if quantization_mode == QuantizationMode.fp8_mixed:
-            # skip quantization on first and last layers
-            return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
-
-        return is_moe
-
-    use_rich_progress = use_rich_progress and rank == 0
-    progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
-    if quantization_mode == QuantizationMode.int4_mixed:
-        int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
-        if os.path.isfile(int4_scales_path):
-            log_status(f"Rank {rank}: Loading int4 scales")
-            int4_scales = torch.load(int4_scales_path, weights_only=True)
-
-            def apply_quantization(key, weight):
-                scale = int4_scales[key]
-                return load_int4(
-                    weight,
-                    scale,
-                    output_device=torch.device("cuda"),
-                )
-
-        else:
-            log_status(f"Rank {rank}: Quantizing int4 weights from bf16")
-
-            def apply_quantization(_, weight):
-                return quantize_int4(weight, output_device=torch.device("cuda"))
-
-    else:
-        fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
-        if os.path.isfile(fp8_scales_path):
-            log_status(f"Rank {rank}: Loading fp8 scales")
-            fp8_scales = torch.load(fp8_scales_path, weights_only=True)
-
-            def apply_quantization(key, weight):
-                scale = fp8_scales[key]
-                return load_fp8(
-                    weight,
-                    scale,
-                    fp8_activation_scale_ub,
-                    output_device=torch.device("cuda"),
-                )
-
-        else:
-            log_status(f"Rank {rank}: Quantizing fp8 weights from bf16")
-
-            def apply_quantization(_, weight):
-                return quantize_fp8(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
-
-    processed_blocks = 0
-    try:
-        if use_rich_progress:
-            progress.start()
-
-        for _, block in model.named_modules():
-            if not should_quantize_block(block):
-                continue
-
-            update_status(f"Rank {rank} - Layer {block.layer_id}")
-
-            # Quantize only routed experts, not shared
-            prefix = f"layers.{block.layer_id}.feed_forward"
-            moe = block.feed_forward
-            moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
-
-            for key in ("w1", "w3", "w2"):
-                param = getattr(moe.experts, key)
-                update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
-                setattr(
-                    moe.experts,
-                    key,
-                    apply_quantization(
-                        f"{prefix}.experts.{key}",
-                        param.transpose(1, 2).contiguous(),
-                    ),
-                )
-
-            if quantization_mode == QuantizationMode.int4_mixed:
-                # Quantize shared experts
-                moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
-                for key in ("w1", "w3", "w2"):
-                    param = getattr(moe.shared_expert, key)
-                    update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
-                    param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
-
-            processed_blocks += 1
-            update_status(message=None, completed=processed_blocks)
-
-        update_status(f"Rank {rank} - Moving parameters to CUDA")
-
-        param_count = 0
-        for _, parameter in model.named_parameters():
-            if not isinstance(parameter, Fp8ScaledWeights) and not isinstance(parameter, Int4ScaledWeights):
-                parameter.data = parameter.to(device="cuda")
-                param_count += 1
-
-        update_status(f"Rank {rank} - Completed - moved {param_count} parameters to CUDA")
-    finally:
-        if use_rich_progress:
-            progress.stop()
-
-    return model
-
-
-# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
-def logging_callbacks(
-    use_rich_progress: bool,
-    rank: int,
-    model: Transformer,
-    should_quantize_block: Callable[[nn.Module], bool],
-):
-    console = None
-    if use_rich_progress:
-        from rich.console import Console
-
-        console = Console(highlight=False)
-
-    def log_status(message: str) -> None:
-        if use_rich_progress:
-            console.print(message)
-        elif rank == 0:  # Only log from rank 0 for non-rich logging
-            log.info(message)
-
-    total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
-    progress = None
-    if use_rich_progress:
-        from rich.progress import (
-            BarColumn,
-            Progress,
-            SpinnerColumn,
-            TextColumn,
-            TimeElapsedColumn,
-            TimeRemainingColumn,
-        )
-
-        progress = Progress(
-            SpinnerColumn(),
-            BarColumn(complete_style="green", finished_style="bright_green"),
-            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
-            TimeElapsedColumn(),
-            TextColumn("ETA:"),
-            TimeRemainingColumn(),
-            TextColumn("[bold]{task.fields[status]}"),
-            console=console,
-            expand=True,
-        )
-        task_id = progress.add_task("[blue]Converting layers...", total=total_blocks, status="Starting")
-
-    def update_status(message: str | None, completed: int | None = None) -> None:
-        if use_rich_progress:
-            if message is not None:
-                progress.update(task_id, status=message)
-            if completed is not None:
-                progress.update(task_id, completed=completed)
-        elif rank == 0 and completed and completed % 10 == 0:
-            log.info(f"Rank {rank}: {completed}/{total_blocks} blocks completed")
-
-    return progress, log_status, update_status
@@ -1,210 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import math
-from collections.abc import Callable
-from typing import Any
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
-
-from ..args import VisionArgs
-from .encoder import VisionEncoder
-
-
-class PixelShuffle(nn.Module):
-    def __init__(self, ps_ratio):
-        super().__init__()
-        self.ps_ratio = ps_ratio
-
-    def forward(self, x):
-        # x: [B, N, C], N = number of patches
-        assert self.ps_ratio is not None, "ps_ratio is required for pixel shuffle"
-        assert x.dim() == 3, "pixel shuffle requires encoded patches [B, N, C]"
-        hh = ww = int(math.sqrt(x.shape[1]))
-        x = x.reshape(x.shape[0], hh, ww, -1)
-        x = pixel_shuffle_op(x, ps_ratio=self.ps_ratio)
-        pixel_shuffle_patches = x.reshape(x.shape[0], -1, x.shape[-1])
-        return pixel_shuffle_patches
-
-
-def pixel_shuffle_op(input_x, ps_ratio):
-    n, w, h, c = input_x.size()
-    input_x = input_x.view(n, w, int(h * ps_ratio), int(c / ps_ratio))
-    input_x = input_x.permute(0, 2, 1, 3).contiguous()
-    input_x = input_x.view(
-        n,
-        int(h * ps_ratio),
-        int(w * ps_ratio),
-        int(c / (ps_ratio * ps_ratio)),
-    )
-    input_x = input_x.permute(0, 2, 1, 3).contiguous()
-    return input_x
-
-
-class SimpleMLP(torch.nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        hidden_dim: int,
-        bias: bool = True,
-        dropout: float = 0.0,
-        act_layer: Callable = nn.GELU,
-    ):
-        super().__init__()
-        # layers
-        self.c_fc = ColumnParallelLinear(
-            dim,
-            hidden_dim,
-            bias=bias,
-            gather_output=False,
-        )
-        self.c_proj = RowParallelLinear(
-            hidden_dim,
-            hidden_dim,
-            bias=bias,
-            input_is_parallel=True,
-        )
-        self.non_linearity = act_layer()
-        self.dropout = dropout
-
-    def forward(self, x):
-        hidden = self.c_fc(x)
-        hidden = self.non_linearity(hidden)
-        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
-        return self.non_linearity(self.c_proj(hidden))
-
-
-class PixelShuffleMLP(torch.nn.Module):
-    def __init__(
-        self,
-        ps_ratio: float,
-        input_dim: int,
-        output_dim: int = 4096,
-        add_fc: bool = False,
-    ):
-        super().__init__()
-        self.pixel_shuffle = PixelShuffle(ps_ratio)
-        self.mlp = SimpleMLP(
-            int(input_dim // (ps_ratio**2)),
-            output_dim,
-            bias=False,
-            dropout=0.0,
-            act_layer=nn.GELU,
-        )
-        self.fc = nn.Identity()
-        if add_fc:
-            self.fc = ColumnParallelLinear(
-                output_dim,
-                output_dim,
-                bias=False,
-            )
-
-    def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor:
-        encoded_patches = self.pixel_shuffle(encoded_patches)
-        return self.fc(self.mlp(encoded_patches))
-
-
-class VisionEmbeddings(torch.nn.Module):
-    def __init__(self, args: VisionArgs):
-        super().__init__()
-        self.args = args
-
-        image_size = args.image_size
-        patch_size = args.patch_size
-        self.vision_encoder = VisionEncoder(
-            image_size=(image_size.height, image_size.width),
-            patch_size=(patch_size.height, patch_size.width),
-            dim=args.dim,
-            layers=args.n_layers,
-            heads=args.n_heads,
-            mlp_ratio=args.mlp_ratio,
-        )
-        self.vision_encoder = self.vision_encoder.to(torch.bfloat16)
-        self.vision_adapter = PixelShuffleMLP(
-            ps_ratio=args.pixel_shuffle_ratio,
-            input_dim=args.dim,
-            output_dim=args.output_dim,
-        )
-
-        self.output_dim = args.output_dim
-        self._register_load_state_dict_pre_hook(self.load_hook)
-
-    def load_hook(
-        self,
-        state_dict: dict[str, Any],
-        prefix: str,
-        local_metadata: dict[str, Any],
-        strict: bool = True,
-        missing_keys: list[str] = None,
-        unexpected_keys: list[str] = None,
-        error_msgs: list[str] = None,
-        return_state_dict: bool = False,
-    ) -> None:
-        original_sd = self.state_dict()
-        for k in state_dict:
-            if k.startswith(prefix) and len(state_dict[k].shape) == 1 and state_dict[k].shape[0] == 0:
-                state_dict[k] = state_dict[k].reshape(original_sd[k[len(prefix) :]].shape)
-
-    def _get_empty_sequence(self, h):
-        return torch.zeros(
-            h.shape[0],
-            h.shape[1],
-            self.output_dim,
-            device=h.device,
-            dtype=h.dtype,
-        )
-
-    # x_images is batched; each batch sample contains a list of images. so this is List[List[torch.Tensor]]
-    # each image is a tensor of shape [num_tiles, C, H, W]
-    def forward(
-        self,
-        image_batch: list[list[torch.Tensor]],
-        image_mask: torch.Tensor,
-        h_ref: torch.Tensor,
-    ) -> torch.Tensor:
-        images_flattened = [image for sample in image_batch for image in sample]
-        images_flattened = torch.vstack(images_flattened).unsqueeze(1).to(h_ref.dtype).to(h_ref.device)
-        embedding = self.vision_encoder(images_flattened)
-        projected_embedding = self.vision_adapter(embedding)
-
-        h_image = self._get_empty_sequence(h_ref)
-        return scatter_embeddings(image_batch, image_mask, h_image, projected_embedding)
-
-
-def scatter_embeddings(image_batch, image_mask, h_image, encoded_patches_proj):
-    # If dynamic transform is used and the batch contains 2 images (where image_1 has 2 chunks and image_2 has 3 chunks),
-    # `num_images_per_sequence` now records the number of chunks per image as `[2, 3]`.
-    # `encoded_patches_proj.split` will then split the image chunks into 2 groups: `[image_1_chunks, image_2_chunks]`.
-    num_images_per_sequence = [sum(image.size(0) for image in sample_images) for sample_images in image_batch]
-
-    assert not torch.isnan(encoded_patches_proj).any()
-    assert sum(num_images_per_sequence) == encoded_patches_proj.size(0), (
-        f"{sum(num_images_per_sequence)=} != {encoded_patches_proj.shape=}"
-    )
-
-    encoded_patches_list = encoded_patches_proj.split(num_images_per_sequence, dim=0)
-    for index in range(h_image.size(0)):
-        encoded_patches_per_sample = encoded_patches_list[index]
-        sample_image_mask = image_mask[index]
-
-        if encoded_patches_per_sample.numel() == 0:
-            continue
-        encoded_patches_per_sample = encoded_patches_per_sample.contiguous().view(
-            -1, encoded_patches_per_sample.size(-1)
-        )
-
-        n_tokens_to_fill = sample_image_mask.sum()
-        assert n_tokens_to_fill <= encoded_patches_per_sample.size(0)
-
-        h_image[index].masked_scatter_(
-            sample_image_mask.expand(-1, h_image.size(-1)),
-            encoded_patches_per_sample[:n_tokens_to_fill],
-        )
-
-    return h_image