llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155) hide show
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  57. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  58. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  59. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  60. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  61. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  62. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  63. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  64. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  65. llama_stack/providers/registry/agents.py +1 -0
  66. llama_stack/providers/registry/inference.py +1 -9
  67. llama_stack/providers/registry/vector_io.py +136 -16
  68. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  69. llama_stack/providers/remote/files/s3/config.py +5 -3
  70. llama_stack/providers/remote/files/s3/files.py +2 -2
  71. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  72. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  73. llama_stack/providers/remote/inference/together/together.py +4 -0
  74. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  75. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  76. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  77. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  78. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  79. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  80. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  81. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  82. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  83. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  84. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  85. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  86. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  87. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  88. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  89. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  90. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  91. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  92. llama_stack/providers/utils/bedrock/client.py +3 -3
  93. llama_stack/providers/utils/bedrock/config.py +7 -7
  94. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  95. llama_stack/providers/utils/inference/http_client.py +239 -0
  96. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  97. llama_stack/providers/utils/inference/model_registry.py +148 -2
  98. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  99. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  100. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  101. llama_stack/providers/utils/memory/vector_store.py +46 -19
  102. llama_stack/providers/utils/responses/responses_store.py +7 -7
  103. llama_stack/providers/utils/safety.py +114 -0
  104. llama_stack/providers/utils/tools/mcp.py +44 -3
  105. llama_stack/testing/api_recorder.py +9 -3
  106. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  107. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
  108. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  109. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  110. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  111. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  112. llama_stack/models/llama/hadamard_utils.py +0 -88
  113. llama_stack/models/llama/llama3/args.py +0 -74
  114. llama_stack/models/llama/llama3/dog.jpg +0 -0
  115. llama_stack/models/llama/llama3/generation.py +0 -378
  116. llama_stack/models/llama/llama3/model.py +0 -304
  117. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  118. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  119. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  120. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  121. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  122. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  123. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  124. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  125. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  126. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  127. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  128. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  129. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  130. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  131. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  132. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  133. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  134. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  135. llama_stack/models/llama/llama4/args.py +0 -107
  136. llama_stack/models/llama/llama4/ffn.py +0 -58
  137. llama_stack/models/llama/llama4/moe.py +0 -214
  138. llama_stack/models/llama/llama4/preprocess.py +0 -435
  139. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  140. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  141. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  142. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  143. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  144. llama_stack/models/llama/quantize_impls.py +0 -316
  145. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  146. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  147. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  148. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  149. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  150. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  151. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  152. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  153. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  154. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  155. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -1,1430 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the terms described in the LICENSE file in
5
- # the root directory of this source tree.
6
- import math
7
- from collections.abc import Callable
8
- from functools import partial
9
- from typing import Any
10
-
11
- import fairscale.nn.model_parallel.initialize as fs_init
12
- import torch
13
- import torch.nn.functional as F
14
- from fairscale.nn.model_parallel.layers import (
15
- ColumnParallelLinear,
16
- RowParallelLinear,
17
- VocabParallelEmbedding,
18
- )
19
- from PIL import Image as PIL_Image
20
- from torch import Tensor, nn
21
- from torch.distributed import _functional_collectives as funcol
22
-
23
- from llama_stack.log import get_logger
24
-
25
- from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
26
- from .encoder_utils import (
27
- build_encoder_attention_mask,
28
- contract_num_tokens_from_mult8,
29
- expand_num_tokens_to_mult8,
30
- initialize_global_position_embedding_from_local,
31
- resize_global_position_embedding,
32
- resize_local_position_embedding,
33
- )
34
- from .image_transform import VariableSizeImageTransform
35
- from .utils import get_negative_inf_value, to_2tuple
36
-
37
- MP_SCALE = 8
38
-
39
- logger = get_logger(name=__name__, category="models::llama")
40
-
41
-
42
- def reduce_from_tensor_model_parallel_region(input_):
43
- """All-reduce the input tensor across model parallel group."""
44
- output = funcol.all_reduce(input_, "sum", group=fs_init.get_model_parallel_group())
45
- output = funcol.wait_tensor(output)
46
- return output
47
-
48
-
49
- def gather_from_tensor_model_parallel_region(input_):
50
- """Gather tensors and concatenate along the last dimension."""
51
-
52
- world_size = fs_init.get_model_parallel_world_size()
53
- # Size and dimension.
54
- last_dim = input_.dim() - 1
55
- rank = fs_init.get_model_parallel_rank()
56
-
57
- tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
58
- tensor_list[rank] = input_
59
- output = funcol.all_gather_tensor(
60
- input_,
61
- gather_dim=last_dim,
62
- group=fs_init.get_model_parallel_group(),
63
- )
64
- output = funcol.wait_tensor(output)
65
- return output
66
-
67
-
68
- def _get_full_row_masked_out_mask(
69
- attn_bias,
70
- negative_inf_value,
71
- ):
72
- """
73
- attn_bias should be a 4D tensor of shape [B, H, S1, S2]
74
- where B is the batch size, H is the number of heads,
75
- and S1/S2 are the sequence lengths. This returns
76
- a 4D tensor of shape [B, H, S1, 1] which stores boolean
77
- values which are 0 if the a full row in the last dimension
78
- contains negative infinity values, otherwise it's 1.
79
- """
80
- return (attn_bias != negative_inf_value).any(dim=-1).type_as(attn_bias)[..., None]
81
-
82
-
83
- # Image encoder for inference
84
- class LayerNorm(nn.LayerNorm):
85
- """Subclass torch's LayerNorm to handle fp16."""
86
-
87
- def forward(self, x: torch.Tensor):
88
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
89
- return x
90
-
91
-
92
- class ColumnParallelConv2dPatch(torch.nn.Module):
93
- """Conv2D Patching layer with model parallelism.
94
- Column parallel over unfolded input.
95
- Arguments:
96
- in_channels: Input channels.
97
- out_channels: Output channels.
98
- kernel_size: Size of convolution kernel.
99
- stride (default 1): Stride for convolution.
100
- bias (default False): Use bias in Conv2d.
101
- Input: (bsz, in_channels, width, height)
102
- Output: (bsz, num_tokens, out_channels)
103
- """
104
-
105
- def __init__(
106
- self,
107
- in_channels: int,
108
- out_channels: int,
109
- kernel_size: int | tuple[int, int],
110
- stride: int | tuple[int, int],
111
- bias: bool | None = False,
112
- ) -> None:
113
- super().__init__()
114
- if isinstance(kernel_size, int):
115
- kernel_size = (kernel_size, kernel_size)
116
- self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
117
- self._linear = ColumnParallelLinear(
118
- in_channels * kernel_size[0] * kernel_size[1],
119
- out_channels,
120
- bias=bias,
121
- )
122
-
123
- def forward(self, x: torch.Tensor) -> torch.Tensor:
124
- x = self._unfold(x)
125
- x = x.permute(0, 2, 1)
126
- x = F.linear(x, self._linear.weight)
127
- x = gather_from_tensor_model_parallel_region(x)
128
- return x
129
-
130
-
131
- class ImageFeedForward(torch.nn.Module):
132
- def __init__(
133
- self,
134
- dim: int,
135
- hidden_dim: int,
136
- dropout: float,
137
- act_layer: Callable = nn.GELU,
138
- ):
139
- super().__init__()
140
- # layers
141
- self.c_fc = ColumnParallelLinear(
142
- dim,
143
- hidden_dim,
144
- bias=True,
145
- gather_output=False,
146
- init_method=lambda x: x,
147
- )
148
- self.c_proj = RowParallelLinear(
149
- hidden_dim,
150
- dim,
151
- bias=True,
152
- input_is_parallel=True,
153
- init_method=lambda x: x,
154
- )
155
- self.non_linearity = act_layer()
156
- self.dropout = dropout
157
-
158
- def forward(self, x):
159
- hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias)
160
- hidden = self.non_linearity(hidden)
161
- hidden = F.linear(hidden, self.c_proj.weight)
162
- hidden = reduce_from_tensor_model_parallel_region(hidden)
163
- hidden += self.c_proj.bias
164
- return hidden
165
-
166
-
167
- class ImageAttention(nn.Module):
168
- def __init__(
169
- self,
170
- dim,
171
- head_dim,
172
- n_heads,
173
- ):
174
- super().__init__()
175
- world_size = fs_init.get_model_parallel_world_size()
176
- qkvo_replication = 1
177
- if world_size > 16:
178
- qkvo_replication = world_size // 8
179
-
180
- self.n_kv_heads = n_heads
181
- self.n_local_heads = n_heads * qkvo_replication // world_size
182
- self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
183
- self.n_rep = self.n_local_heads // self.n_local_kv_heads
184
- self.head_dim = dim // n_heads
185
-
186
- self.wq = ColumnParallelLinear(
187
- dim,
188
- qkvo_replication * n_heads * self.head_dim,
189
- bias=False,
190
- gather_output=False,
191
- init_method=lambda x: x,
192
- )
193
- self.wk = ColumnParallelLinear(
194
- dim,
195
- qkvo_replication * self.n_kv_heads * self.head_dim,
196
- bias=False,
197
- gather_output=False,
198
- init_method=lambda x: x,
199
- )
200
- self.wv = ColumnParallelLinear(
201
- dim,
202
- qkvo_replication * self.n_kv_heads * self.head_dim,
203
- bias=False,
204
- gather_output=False,
205
- init_method=lambda x: x,
206
- )
207
- self.wo = RowParallelLinear(
208
- qkvo_replication * n_heads * self.head_dim,
209
- dim,
210
- bias=False,
211
- input_is_parallel=True,
212
- init_method=lambda x: x,
213
- )
214
- self.qkvo_replication = qkvo_replication
215
-
216
- def forward(
217
- self,
218
- x: torch.Tensor,
219
- mask: torch.Tensor = None,
220
- ):
221
- xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
222
-
223
- bs, slen, _ = xq.shape
224
-
225
- xq = xq.view(bs, slen, self.n_local_heads, self.head_dim)
226
- xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim)
227
- xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim)
228
-
229
- xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)]
230
-
231
- xk = xk.repeat_interleave(self.n_rep, dim=1)
232
- xv = xv.repeat_interleave(self.n_rep, dim=1)
233
-
234
- attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
235
-
236
- attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1)
237
-
238
- out = F.linear(attn_output, self.wo.weight)
239
- out = reduce_from_tensor_model_parallel_region(out)
240
- out = out / self.qkvo_replication
241
- return out
242
-
243
-
244
- class ImageTransformerBlock(nn.Module):
245
- def __init__(
246
- self,
247
- d_model: int,
248
- n_head: int,
249
- mlp_ratio: float = 4.0,
250
- act_layer: Callable = nn.GELU,
251
- gated: bool = False,
252
- ):
253
- super().__init__()
254
- assert d_model % n_head == 0
255
- self.n_heads = n_head
256
- self.head_dim = d_model // self.n_heads
257
- self.attn = ImageAttention(
258
- dim=d_model,
259
- head_dim=self.head_dim,
260
- n_heads=self.n_heads,
261
- )
262
- self.ln_1 = LayerNorm(d_model)
263
- self.mlp = ImageFeedForward(
264
- dim=d_model,
265
- hidden_dim=int(mlp_ratio * d_model),
266
- dropout=0.0,
267
- act_layer=act_layer,
268
- )
269
- self.ln_2 = LayerNorm(d_model)
270
- self.gated = gated
271
- if gated:
272
- self.gate_attn = nn.Parameter(torch.zeros(1))
273
- self.gate_ffn = nn.Parameter(torch.zeros(1))
274
-
275
- def forward(
276
- self,
277
- x: torch.Tensor,
278
- mask: torch.Tensor = None,
279
- ):
280
- _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
281
- _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
282
- x = x + _gate_attn * self.attn(self.ln_1(x), mask=mask)
283
- x = x + _gate_ffn * self.mlp(self.ln_2(x))
284
- return x
285
-
286
-
287
- class ImageTransformer(nn.Module):
288
- def __init__(
289
- self,
290
- width: int,
291
- layers: int,
292
- heads: int,
293
- mlp_ratio: float = 4.0,
294
- act_layer: Callable = nn.GELU,
295
- gated: bool = False,
296
- ):
297
- super().__init__()
298
- self.width = width
299
- self.layers = layers
300
- self.resblocks = nn.ModuleList(
301
- [
302
- ImageTransformerBlock(
303
- d_model=width,
304
- n_head=heads,
305
- mlp_ratio=mlp_ratio,
306
- act_layer=act_layer,
307
- gated=gated,
308
- )
309
- for _ in range(self.layers)
310
- ]
311
- )
312
-
313
- def forward(self, x: torch.Tensor, return_intermediate=None, mask=None):
314
- out = []
315
- for idx, r in enumerate(self.resblocks):
316
- if return_intermediate is not None and idx in return_intermediate:
317
- out.append(x)
318
- x = r(x, mask=mask)
319
- if return_intermediate is not None:
320
- return x, torch.stack(out, dim=-1)
321
- return x
322
-
323
-
324
- class VisionEncoder(nn.Module):
325
- def __init__(
326
- self,
327
- max_num_tiles: int,
328
- ckpt_path: str = None,
329
- image_size: int = 224,
330
- patch_size: int = 14,
331
- width: int = 1280,
332
- layers: int = 32,
333
- heads: int = 16,
334
- mlp_ratio: float = 4.0,
335
- act_layer: Callable = nn.GELU,
336
- in_channels: int = 3,
337
- load_ckpt: bool = False,
338
- n_global_layers: int = 2,
339
- global_model: bool = False,
340
- return_intermediate=None,
341
- ):
342
- super().__init__()
343
- self.global_model = global_model
344
- self.return_intermediate = return_intermediate
345
- self.max_num_tiles = max_num_tiles
346
- self.image_size = to_2tuple(image_size)
347
- self.patch_size = to_2tuple(patch_size)
348
- self.grid_size = (
349
- self.image_size[0] // self.patch_size[0],
350
- self.image_size[1] // self.patch_size[1],
351
- )
352
- self.conv1 = ColumnParallelConv2dPatch(
353
- in_channels=in_channels,
354
- out_channels=width,
355
- kernel_size=patch_size,
356
- stride=patch_size,
357
- bias=False,
358
- )
359
- scale = width**-0.5
360
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
361
- self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
362
- self.ln_post = LayerNorm(width)
363
- self.ln_pre = LayerNorm(width)
364
- self.transformer = ImageTransformer(width, layers, heads, mlp_ratio, act_layer=act_layer)
365
- # pre and post tile position embedding
366
- self.global_transformer = ImageTransformer(
367
- width, n_global_layers, heads, mlp_ratio, act_layer=act_layer, gated=True
368
- )
369
- # pre and post tile position embedding
370
- self.pre_tile_pos_embed = TilePositionEmbedding(
371
- num_tiles=max_num_tiles,
372
- width=width,
373
- gated=True,
374
- )
375
- self.post_tile_pos_embed = TilePositionEmbedding(
376
- num_tiles=max_num_tiles,
377
- width=width,
378
- gated=True,
379
- )
380
- self.gated_positional_embedding = nn.Parameter(
381
- scale
382
- * torch.randn(
383
- max_num_tiles,
384
- max_num_tiles,
385
- self.grid_size[0] * self.grid_size[1] + 1,
386
- width,
387
- )
388
- )
389
- self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1))
390
-
391
- self._register_load_state_dict_pre_hook(self.load_hook)
392
-
393
- def load_hook(
394
- self,
395
- state_dict: dict[str, Any],
396
- prefix: str,
397
- local_metadata: dict[str, Any],
398
- strict: bool = True,
399
- missing_keys: list[str] = None,
400
- unexpected_keys: list[str] = None,
401
- error_msgs: list[str] = None,
402
- return_state_dict: bool = False,
403
- ) -> None:
404
- orig_pos_embed = state_dict.get(prefix + "positional_embedding")
405
- if orig_pos_embed is not None:
406
- new_pos_embed = resize_local_position_embedding(orig_pos_embed, self.grid_size)
407
- state_dict[prefix + "positional_embedding"] = new_pos_embed
408
- if hasattr(self, "gated_positional_embedding"):
409
- if prefix + "gated_positional_embedding" not in state_dict:
410
- # resize positional_embedding to fit the new grid size
411
- global_pos_embed = initialize_global_position_embedding_from_local(
412
- new_pos_embed,
413
- self.grid_size,
414
- self.max_num_tiles,
415
- self.max_num_tiles,
416
- )
417
- state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
418
- state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype)
419
- logger.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
420
- else:
421
- global_pos_embed = resize_global_position_embedding(
422
- state_dict[prefix + "gated_positional_embedding"],
423
- self.grid_size,
424
- self.max_num_tiles,
425
- self.max_num_tiles,
426
- )
427
- logger.info(
428
- f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}"
429
- )
430
- state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
431
- if return_state_dict:
432
- return state_dict
433
-
434
- def apply_positional_embedding(self, x, ar):
435
- # apply regular position embedding
436
- bsz, num_chunks, num_tokens, dim = x.shape
437
- x = x.view(bsz * num_chunks, num_tokens, dim)
438
- x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh())
439
- x = x.view(bsz, num_chunks, num_tokens, dim)
440
- for idx, arx in enumerate(ar):
441
- _pos_embed = self.gated_positional_embedding[: arx[0], : arx[1]]
442
- _pos_embed = _pos_embed.reshape(arx[0] * arx[1], *_pos_embed.shape[2:])
443
- x[idx, : arx[0] * arx[1]] += _pos_embed * self.gated_positional_embedding_gate.tanh()
444
- return x
445
-
446
- def apply_class_embedding(self, x):
447
- x = torch.cat(
448
- [
449
- self.class_embedding.to(x.dtype)
450
- + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
451
- x,
452
- ],
453
- dim=1,
454
- ) # shape = [*, grid ** 2 + 1, width]
455
- return x
456
-
457
- def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor:
458
- if images.ndim == 5:
459
- num_concurrent_media = 1
460
- bsz, num_chunks, nch, w, h = images.shape
461
- else:
462
- bsz, num_concurrent_media, num_chunks, nch, w, h = images.shape
463
-
464
- images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h)
465
- ar = ar.reshape(bsz * num_concurrent_media, 2)
466
-
467
- # patch embedding
468
- x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h)
469
- x = self.conv1(x) # shape = [*, width, grid ** 2]
470
- _, ntok, dim = x.shape
471
- x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
472
-
473
- # tile embeddings
474
- x = self.pre_tile_pos_embed(x, ar)
475
- x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
476
-
477
- # apply cls token
478
- x = self.apply_class_embedding(x)
479
- ntok += 1
480
-
481
- # apply position embeddings
482
- x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
483
- x = self.apply_positional_embedding(x, ar)
484
-
485
- x = self.ln_pre(x)
486
- npad, attn_mask = 0, None
487
- x, npad = expand_num_tokens_to_mult8(x)
488
- attn_mask = build_encoder_attention_mask(x, ar, ntok, num_chunks, 1)
489
- x = x.view(bsz * num_concurrent_media, -1, dim)
490
- x, int_x = self.transformer(x, return_intermediate=self.return_intermediate, mask=attn_mask)
491
-
492
- x = self.ln_post(x)
493
- x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)
494
- x = self.post_tile_pos_embed(x, ar)
495
- x = x.reshape(bsz * num_concurrent_media, num_chunks * (ntok + npad), dim)
496
- x = self.global_transformer(x, mask=attn_mask)
497
- x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)
498
- x = contract_num_tokens_from_mult8(x, npad)
499
-
500
- # adding back intermediate layer outputs
501
- x = x.reshape(bsz, num_concurrent_media, num_chunks, ntok, dim)
502
- int_x = int_x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, -1)
503
- int_x = contract_num_tokens_from_mult8(int_x, npad)
504
- int_x = int_x.reshape(bsz, num_concurrent_media, num_chunks, ntok, -1)
505
- x = torch.cat([x, int_x], dim=-1)
506
- return x
507
-
508
-
509
- class Attention(nn.Module):
510
- """Multi-head attention module."""
511
-
512
- def __init__(self, args: ModelArgs):
513
- """
514
- Initialize the Attention module.
515
- Args:
516
- args (ModelArgs): Model configuration parameters.
517
- Attributes:
518
- n_kv_heads (int): Number of key and value heads.
519
- n_local_heads (int): Number of local query heads.
520
- n_local_kv_heads (int): Number of local key and value heads.
521
- n_rep (int): Number of repetitions for local heads.
522
- head_dim (int): Dimension size of each attention head.
523
- wq (ColumnParallelLinear): Linear transformation for queries.
524
- wk (ColumnParallelLinear): Linear transformation for keys.
525
- wv (ColumnParallelLinear): Linear transformation for values.
526
- wo (RowParallelLinear): Linear transformation for output.
527
- cache_k (torch.Tensor): Cached keys for attention.
528
- cache_v (torch.Tensor): Cached values for attention.
529
- """
530
- super().__init__()
531
- world_size = fs_init.get_model_parallel_world_size()
532
- replication_factor = 1
533
- if world_size > 8:
534
- replication_factor = world_size // MP_SCALE
535
-
536
- self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
537
- self.n_kv_heads *= replication_factor
538
-
539
- self.n_local_heads = args.n_heads // world_size
540
- self.n_local_kv_heads = self.n_kv_heads // world_size
541
- self.n_rep = self.n_local_heads // self.n_local_kv_heads
542
- self.head_dim = args.dim // args.n_heads
543
- self.max_seq_len = args.max_seq_len
544
-
545
- self.wq = ColumnParallelLinear(
546
- args.dim,
547
- args.n_heads * self.head_dim,
548
- bias=False,
549
- gather_output=False,
550
- init_method=lambda x: x,
551
- )
552
- self.wk = ColumnParallelLinear(
553
- args.dim,
554
- self.n_kv_heads * self.head_dim,
555
- bias=False,
556
- gather_output=False,
557
- init_method=lambda x: x,
558
- )
559
- self.wv = ColumnParallelLinear(
560
- args.dim,
561
- self.n_kv_heads * self.head_dim,
562
- bias=False,
563
- gather_output=False,
564
- init_method=lambda x: x,
565
- )
566
- self.wo = RowParallelLinear(
567
- args.n_heads * self.head_dim,
568
- args.dim,
569
- bias=False,
570
- input_is_parallel=True,
571
- init_method=lambda x: x,
572
- )
573
- self.n_heads = args.n_heads
574
-
575
- def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
576
- cache_shape = (
577
- max_batch_size,
578
- self.max_seq_len,
579
- self.n_local_kv_heads,
580
- self.head_dim,
581
- )
582
- self.register_buffer(
583
- "key_cache",
584
- torch.zeros(
585
- cache_shape,
586
- dtype=dtype,
587
- ),
588
- persistent=False,
589
- )
590
- self.register_buffer(
591
- "value_cache",
592
- torch.zeros(
593
- cache_shape,
594
- dtype=dtype,
595
- ),
596
- persistent=False,
597
- )
598
-
599
- def forward(
600
- self,
601
- x: torch.Tensor,
602
- mask: torch.Tensor,
603
- freqs_cis: torch.Tensor,
604
- position_ids: torch.LongTensor,
605
- ):
606
- self.key_cache = self.key_cache.to(x.device)
607
- self.value_cache = self.value_cache.to(x.device)
608
-
609
- xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
610
-
611
- bs, slen, _ = xq.shape
612
-
613
- xq = xq.view(bs, slen, self.n_local_heads, self.head_dim)
614
- xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim)
615
- xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim)
616
-
617
- xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
618
-
619
- self.key_cache[:bs, position_ids, ...] = xk
620
- self.value_cache[:bs, position_ids, ...] = xv
621
-
622
- # TODO: we can avoid slicing on first dimension by always padding to max_batch_size()
623
- xk = self.key_cache[:bs, ...]
624
- xv = self.value_cache[:bs, ...]
625
-
626
- xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)]
627
-
628
- xk = xk.repeat_interleave(self.n_rep, dim=1)
629
- xv = xv.repeat_interleave(self.n_rep, dim=1)
630
-
631
- attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
632
-
633
- attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1)
634
-
635
- out = F.linear(attn_output, self.wo.weight)
636
- out = reduce_from_tensor_model_parallel_region(out)
637
- return out
638
-
639
-
640
class FeedForward(nn.Module):
    """Gated (SiLU) feed-forward layer built on model-parallel linear weights."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
    ):
        """Size the hidden layer and create the three projections.

        Args:
            dim: Input and output width.
            hidden_dim: Nominal hidden width; it is first scaled by 2/3.
            multiple_of: The final hidden width is rounded up to a multiple of this.
            ffn_dim_multiplier: Optional extra scale applied to the hidden width
                before rounding; ``None`` skips it.
        """
        super().__init__()
        inner = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            # custom dim factor multiplier
            inner = int(ffn_dim_multiplier * inner)
        # Ceil-divide, then scale back up: the next multiple of `multiple_of`.
        inner = multiple_of * ((inner + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(dim, inner, bias=False, gather_output=False, init_method=lambda x: x)
        self.w2 = RowParallelLinear(inner, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
        self.w3 = ColumnParallelLinear(dim, inner, bias=False, gather_output=False, init_method=lambda x: x)

    def forward(self, x):
        """Return ``w2(silu(w1 x) * (w3 x))`` after the model-parallel reduce."""
        gate = F.linear(x, self.w1.weight)
        up = F.linear(x, self.w3.weight)
        hidden = F.silu(gate) * up
        down = F.linear(hidden, self.w2.weight)
        return reduce_from_tensor_model_parallel_region(down)
678
-
679
-
680
class TransformerBlock(nn.Module):
    """Self-attention block: normalized attention and feed-forward, each with a residual add."""

    def __init__(self, layer_id: int, args: ModelArgs):
        """Build the attention, feed-forward and norm submodules.

        Args:
            layer_id: Identifier stored on the block as ``layer_id``.
            args: Model configuration; reads ``n_heads``, ``dim``,
                ``multiple_of``, ``ffn_dim_multiplier`` and ``norm_eps``.
        """
        super().__init__()
        self.layer_id = layer_id
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        # Both norms are applied to the *input* of their sub-layer (see forward).
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
        """Delegate cache allocation to the attention submodule."""
        self.attention.setup_cache(max_batch_size, dtype)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        mask: torch.Tensor,
        position_ids: torch.LongTensor,
    ) -> torch.Tensor:
        """Apply attention and feed-forward, each followed by a residual connection.

        Args:
            x: Input tensor.
            freqs_cis: Rotary frequencies forwarded to the attention module.
            mask: Attention mask forwarded to the attention module.
            position_ids: Sequence positions forwarded to the attention module.

        Returns:
            The block output, ``h + feed_forward(ffn_norm(h))`` where
            ``h = attention(attention_norm(x)) + x``.
        """
        attended = self.attention.forward(
            x=self.attention_norm(x),
            freqs_cis=freqs_cis,
            mask=mask,
            position_ids=position_ids,
        )
        residual = attended + x
        return residual + self.feed_forward.forward(self.ffn_norm(residual))
741
-
742
-
743
- class TilePositionEmbedding(nn.Module):
744
- def __init__(
745
- self,
746
- num_tiles: int,
747
- width: int,
748
- gated: bool = False,
749
- ):
750
- super().__init__()
751
- self.num_tiles = num_tiles
752
- self.width = width
753
- self.embedding = nn.Parameter(torch.randn(num_tiles, num_tiles, 1, width) / math.sqrt(width))
754
- self.gated = gated
755
- if gated:
756
- self.gate = nn.Parameter(torch.zeros(1))
757
-
758
- self._register_load_state_dict_pre_hook(self.load_hook)
759
-
760
- def load_hook(
761
- self,
762
- state_dict,
763
- prefix,
764
- local_metadata,
765
- strict,
766
- missing_keys,
767
- unexpected_keys,
768
- error_msgs,
769
- ):
770
- # load the weights from the checkpoint
771
- embed = state_dict.get(prefix + "embedding")
772
- if embed is not None:
773
- # reshape the weights to the correct shape
774
- nt_old, nt_old, _, w = embed.shape
775
- logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
776
- embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
777
- # assign the weights to the module
778
- state_dict[prefix + "embedding"] = embed_new
779
-
780
- @staticmethod
781
- def _dynamic_resize(embed: torch.Tensor, num_tiles: int):
782
- nt_old, nt_old, _, w = embed.shape
783
- embed = embed.permute(2, 3, 0, 1)
784
-
785
- embed_new = F.interpolate(
786
- embed,
787
- size=(num_tiles, num_tiles),
788
- mode="bilinear",
789
- align_corners=True,
790
- )
791
- # reshape the weights to the correct shape
792
- embed_new = embed_new.permute(2, 3, 0, 1)
793
- return embed_new
794
-
795
- def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None):
796
- embed = self.embedding
797
- if num_tiles is None:
798
- num_tiles = self.num_tiles
799
- elif num_tiles > self.num_tiles:
800
- embed = TilePositionEmbedding._dynamic_resize(self.embedding, num_tiles)
801
- out_pos_embed = torch.zeros(x.shape[0], num_tiles, 1, self.width, device=x.device, dtype=x.dtype)
802
- for idx, arx in enumerate(ar):
803
- h, w = arx
804
- out_pos_embed[idx, : w * h] = embed[:h, :w].reshape(w * h, 1, self.width)
805
- if self.gated:
806
- out_pos_embed = out_pos_embed * self.gate.tanh()
807
- x = x + out_pos_embed
808
- return x
809
-
810
-
811
- def _noinit(x):
812
- return x
813
-
814
-
815
class CrossAttention(torch.nn.Module):
    """Cross attention layer with model-parallel attention layers."""

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
    ):
        """Create the q/k/v/o projections and the per-head q/k norms.

        Args:
            dim: Model width (input of q/k/v, output of o).
            head_dim: Width of a single attention head.
            n_heads: Total number of query heads across all ranks.
            n_kv_heads: Total number of key/value heads; scaled up when the
                model-parallel world size exceeds 8.
            norm_eps: Epsilon for the q and k ``RMSNorm`` layers.
        """
        super().__init__()
        self.world_size = fs_init.get_model_parallel_world_size()
        if self.world_size > 8:
            # Replicate KV heads by world_size // MP_SCALE on large world sizes.
            n_kv_heads *= self.world_size // MP_SCALE

        assert n_heads % n_kv_heads == 0

        column_opts = {"bias": False, "gather_output": False, "init_method": _noinit}
        self.wq = ColumnParallelLinear(dim, n_heads * head_dim, **column_opts)
        self.wk = ColumnParallelLinear(dim, n_kv_heads * head_dim, **column_opts)
        self.wv = ColumnParallelLinear(dim, n_kv_heads * head_dim, **column_opts)
        self.wo = RowParallelLinear(
            n_heads * head_dim,
            dim,
            bias=False,
            input_is_parallel=True,
            init_method=_noinit,
        )

        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_heads = n_kv_heads

        self.q_norm = RMSNorm(self.head_dim, eps=norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=norm_eps)

        # cross-attention heads are model parallel similar to
        # self-attention, and we also use the identical KV head
        # combination to ensure parity with the corresponding
        # trunk LLM (i.e., group query attention) -- @dubeya
        # local heads
        assert self.n_heads % self.n_kv_heads == 0
        assert self.n_heads % self.world_size == 0
        assert self.n_kv_heads % self.world_size == 0
        self.n_local_heads = self.n_heads // self.world_size
        self.n_local_kv_heads = self.n_kv_heads // self.world_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads

    def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
        """Project ``xattn_tokens`` to keys/values and stack them as ``[keys, values]``.

        Keys and values are split into local KV heads, repeated ``n_rep`` times
        along the head axis, and only the keys are passed through ``k_norm``.
        """
        bsz = xattn_tokens.shape[0]
        keys = self.wk(xattn_tokens)
        values = self.wv(xattn_tokens)

        _, seqlen_y, _ = keys.shape
        head_shape = (bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)

        # repeat k/v heads if n_kv_heads < n_heads
        keys = keys.view(*head_shape).transpose(1, 2).repeat_interleave(self.n_rep, dim=1)
        values = values.view(*head_shape).transpose(1, 2).repeat_interleave(self.n_rep, dim=1)

        keys = self.k_norm(keys)
        return torch.stack([keys, values])

    def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
        """Public entry point; forwards to ``_compute_xattn_kv_cache``."""
        return self._compute_xattn_kv_cache(xattn_tokens)

    def forward(
        self,
        x: torch.Tensor,
        xattn_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        xattn_cache: torch.Tensor,
    ) -> torch.Tensor:
        """Attend from ``x`` to the cached keys/values and project the result.

        Args:
            x: Query-side activations, unpacked as ``(bsz, seqlen, _)``.
            xattn_mask: Mask passed to ``F.scaled_dot_product_attention``.
            full_text_row_masked_out_mask: Multiplied into the attention output
                before the heads are merged.
            xattn_cache: Unpacked as ``(keys, values)``, as produced by
                ``compute_xattn_kv_cache``.

        Returns:
            The ``wo`` projection after ``reduce_from_tensor_model_parallel_region``.
        """
        queries = F.linear(x, self.wq.weight)
        bsz, seqlen, _ = x.shape

        queries = queries.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        queries = self.q_norm(queries).transpose(1, 2)

        keys, values = xattn_cache

        attended = F.scaled_dot_product_attention(queries, keys, values, attn_mask=xattn_mask, dropout_p=0.0)
        attended = attended * full_text_row_masked_out_mask
        attended = attended.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1)

        projected = F.linear(attended, self.wo.weight)
        return reduce_from_tensor_model_parallel_region(projected)
936
-
937
-
938
class CrossAttentionTransformerBlock(torch.nn.Module):
    """Cross-attention transformer block with tanh-gated attention and feedforward."""

    def __init__(
        self,
        args: ModelArgs,
        layer_id: int,
        no_ffn: bool = False,
    ) -> None:
        """Build the cross-attention, feed-forward, norms and the two scalar gates.

        Args:
            args: Model configuration; reads ``n_heads``, ``n_kv_heads``,
                ``dim``, ``norm_eps``, ``ffn_dim_multiplier`` and ``multiple_of``.
            layer_id: Identifier stored on the block as ``layer_id``.
            no_ffn: When true, the feed-forward contribution is multiplied by
                zero in ``forward`` (the feed-forward module is still built).
        """
        super().__init__()
        self.layer_id = layer_id
        self.n_heads = args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = CrossAttention(
            dim=args.dim,
            head_dim=self.head_dim,
            n_heads=self.n_heads,
            n_kv_heads=self.n_kv_heads,
            norm_eps=args.norm_eps,
        )

        self.attention_norm = RMSNorm(
            args.dim,
            eps=args.norm_eps,
        )
        # Gates start at zero, so tanh(gate) == 0 at initialization.
        self.gate_attn = torch.nn.Parameter(torch.zeros(1))

        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
            multiple_of=args.multiple_of,
        )
        self.ffn_norm = RMSNorm(
            args.dim,
            eps=args.norm_eps,
        )
        self.gate_ffwd = torch.nn.Parameter(torch.zeros(1))

        self.no_ffn = no_ffn

    def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
        """Delegate key/value cache computation to the attention submodule."""
        return self.attention.compute_xattn_kv_cache(xattn_tokens)

    def forward(
        self,
        x: torch.Tensor,
        xattn_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        xattn_cache: torch.Tensor,
    ) -> torch.Tensor:
        """Apply gated cross-attention, then the gated (and row-masked) feed-forward.

        Args:
            x: Input activations.
            xattn_mask: Cross-attention mask forwarded to the attention module.
            full_text_row_masked_out_mask: A tensor (it is indexed with
                ``[:, 0]`` below and multiplied into the attention output by
                ``CrossAttention.forward``), matching that method's annotation.
            xattn_cache: Cached keys/values forwarded to the attention module.

        Returns:
            ``x`` plus the tanh-gated attention output plus the tanh-gated,
            row-masked feed-forward output.
        """
        _attn_out = self.attention(
            x=self.attention_norm(x),
            xattn_mask=xattn_mask,
            xattn_cache=xattn_cache,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
        )
        h = x + self.gate_attn.tanh() * _attn_out
        _ffn = self.feed_forward(self.ffn_norm(h))
        # Drop axis 1 of the mask before applying it to the feed-forward output.
        _ffn = full_text_row_masked_out_mask[:, 0] * _ffn
        # float(not no_ffn) is 0.0 when the feed-forward path is disabled.
        h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn)
        return h
1002
-
1003
-
1004
class DummyCrossAttentionTransformerBlock:
    """Pass-through stand-in for a cross-attention block.

    Calling it returns ``x`` untouched; every other argument is accepted and ignored.
    """

    def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x
1014
-
1015
-
1016
class DummySelfAttentionTransformerBlock:
    """Pass-through stand-in for a self-attention block.

    Calling it returns ``x`` untouched; every other argument is accepted and ignored.
    """

    def __call__(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x
1026
-
1027
-
1028
class CrossAttentionTransformerVision(torch.nn.Module):
    """Vision encoder followed by a projection of its tokens to the model width."""

    def __init__(self, args: ModelArgs) -> None:
        """Build the vision encoder and the vision-token projection.

        Args:
            args: Model configuration; reads ``vision_chunk_size``,
                ``vision_max_num_chunks`` and ``dim``.
        """
        super().__init__()
        # Encoder layers passed to VisionEncoder as return_intermediate.
        intermediate_layers = [int(layer) for layer in "3,7,15,23,30".split(",")]
        encoder_width = 1280

        self.image_res = args.vision_chunk_size
        self.max_num_chunks = args.vision_max_num_chunks
        # The projection input width is scaled by (number of intermediate layers + 1).
        self.vision_input_dim = (len(intermediate_layers) + 1) * encoder_width
        self.patch_size = 14
        self.vision_encoder = VisionEncoder(
            max_num_tiles=4,
            image_size=args.vision_chunk_size,
            patch_size=self.patch_size,
            n_global_layers=8,
            global_model=True,
            return_intermediate=intermediate_layers,
        )
        # vision token projection
        self.vision_projection = ColumnParallelLinear(
            self.vision_input_dim,
            args.dim,
            bias=True,
            init_method=lambda x: x,
        )

    def forward(self, images: torch.Tensor, aspect_ratios: torch.Tensor) -> torch.Tensor:
        """Encode ``images`` and project the resulting tokens.

        The images are cast to the current default dtype before encoding; the
        projected tokens are gathered across the model-parallel region.
        """
        # vision_tokens: (B, T, D)
        # aspect_ratios: (B, T)
        # h: (B, T, D)
        encoded = self.vision_encoder(images.to(dtype=torch.get_default_dtype()), aspect_ratios)
        projected = F.linear(encoded, self.vision_projection.weight, self.vision_projection.bias)
        return gather_from_tensor_model_parallel_region(projected)
1064
-
1065
-
1066
class CrossAttentionTransformerText(torch.nn.Module):
    """Text decoder that interleaves self-attention blocks with cross-attention blocks.

    Cross-attention blocks are attached in front of the self-attention layers
    listed in ``fusion_schedule``; all other layers are paired with a
    pass-through ``DummyCrossAttentionTransformerBlock``.
    """

    INFERENCE_IMAGE_TOKEN_ID = 128010

    def __init__(self, args: ModelArgs) -> None:
        """Build embeddings, output head, self-/cross-attention blocks and rotary table.

        Args:
            args: Model configuration (vocabulary size, layer count, widths,
                rope settings, number of cross-attention layers, ...).
        """
        super().__init__()
        self.world_size = fs_init.get_model_parallel_world_size()
        assert args.vocab_size > 0
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_local_kv_heads = self.n_kv_heads // self.world_size
        assert self.vocab_size % self.world_size == 0
        self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
        self.pos_embeddings = None
        # final norm layer (not necessary for post-norm)
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        # output layer
        self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)

        # n_llama_layers must be set before _init_fusion_schedule, which reads it.
        self.n_llama_layers = args.n_layers
        self.model_dim = args.dim

        # BLOCKS

        self.fusion_schedule = self._init_fusion_schedule(args.vision_num_cross_attention_layers)
        # Second embedding table for token ids at or above the size of
        # tok_embeddings (see get_partially_trainable_embedding).
        self.learnable_embedding = VocabParallelEmbedding(
            max(fs_init.get_model_parallel_world_size(), 8),
            args.dim,
            init_method=lambda x: x,
        )
        self.num_frozen_embeddings = self.tok_embeddings.num_embeddings
        # Largest token id that is served by tok_embeddings.
        self._thresh = self.num_frozen_embeddings - 1

        # transformer blocks
        self.layers = torch.nn.ModuleList()
        self.cross_attention_layers = torch.nn.ModuleList()
        for i in range(args.n_layers):
            layer_id = i
            block = TransformerBlock(args=args, layer_id=layer_id)
            self.layers.append(block)
            if layer_id in self.fusion_schedule:
                # Cross-attention layer ids continue after the self-attention ids.
                xa_layer_id = self.fusion_schedule.index(layer_id) + args.n_layers
                block = CrossAttentionTransformerBlock(
                    args,
                    layer_id=xa_layer_id,
                )
                self.cross_attention_layers.append(block)

        # add xattn and dummy layers to avoid conditionals in forward()
        # Plain Python list of (self-attn layer, xattn layer or dummy, xattn cache index).
        self.text_and_xattn_layers = []

        for idx, layer in enumerate(self.layers):
            if idx in self.fusion_schedule:
                xattn_layer_idx = self.fusion_schedule.index(idx)
                xattn_layer = self.cross_attention_layers[xattn_layer_idx]
            else:
                # The dummy ignores its cache argument, so index 0 is only a placeholder.
                xattn_layer_idx = 0
                xattn_layer = DummyCrossAttentionTransformerBlock()

            self.text_and_xattn_layers.append(
                (
                    layer,
                    xattn_layer,
                    xattn_layer_idx,
                )
            )
        # Rotary table is precomputed for twice max_seq_len positions.
        self.freqs_cis = precompute_freqs_cis(
            args.dim // args.n_heads,
            args.max_seq_len * 2,
            args.rope_theta,
            args.use_scaled_rope,
        )

        self.args = args
        self.cache_is_setup = False
        self.max_seq_len = args.max_seq_len

    def _init_fusion_schedule(
        self,
        num_layers: int,
    ) -> list[int]:
        """Choose which self-attention layer indices get a cross-attention block.

        Walks the layer indices from the last one backwards in steps of
        ``ceil(n_llama_layers / num_layers)``, keeps at most ``num_layers`` of
        them, and returns them in ascending order.
        """
        llama_layers = list(range(self.n_llama_layers))

        # uniformly spread the layers
        k = math.ceil(len(llama_layers) / num_layers)
        return llama_layers[::-1][::k][:num_layers][::-1]

    def get_partially_trainable_embedding(self, x):
        """Embed token ids using two tables selected by id range.

        Ids below ``num_frozen_embeddings`` are looked up in ``tok_embeddings``;
        ids at or above it are looked up in ``learnable_embedding`` after
        subtracting ``num_frozen_embeddings``.
        """
        xz = torch.zeros_like(x, device=x.device)
        oz = torch.ones_like(x, device=x.device)
        # Clamp so that each table only ever receives in-range indices; the
        # masks below discard the rows produced by clamped (foreign) ids.
        x_orig = torch.minimum(x, torch.tensor(self._thresh, device=x.device))
        x_new = torch.maximum(x, torch.tensor(self._thresh + 1, device=x.device)) - self.num_frozen_embeddings

        mask_orig = torch.where(x >= self.num_frozen_embeddings, xz, oz).unsqueeze(-1)
        mask_new = torch.where(x < self.num_frozen_embeddings, xz, oz).unsqueeze(-1)

        x_orig = self.tok_embeddings(x_orig)
        x_new = self.learnable_embedding(x_new).type_as(x_orig)
        return x_orig * mask_orig.type_as(x_orig) + x_new * mask_new.type_as(x_new)

    def forward(
        self,
        position_ids: torch.LongTensor,
        h: torch.Tensor,
        xattn_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        xattn_caches: torch.Tensor,
        text_only_inference: bool = False,
    ):
        """Run all (cross-attention, self-attention) layer pairs and produce logits.

        Args:
            position_ids: Positions used to select rows of the causal mask and
                of the rotary table, and forwarded to every self-attention layer.
            h: Input embeddings.
            xattn_mask: Cross-attention mask forwarded to the cross-attention layers.
            full_text_row_masked_out_mask: Forwarded to the cross-attention layers.
            xattn_caches: Indexed per layer with the stored cross-attention index.
            text_only_inference: When true, the cross-attention layers are skipped.

        Returns:
            The output projection, gathered across the model-parallel region
            and converted to float.

        Raises:
            AssertionError: If ``setup_cache`` has not been called.
        """
        assert self.cache_is_setup, "Please set up cache before calling forward"
        self.mask_cache = self.mask_cache.to(h.device)
        self.freqs_cis = self.freqs_cis.to(h.device)
        # Select the causal-mask rows (dim 2) and rotary rows (dim 0) for these positions.
        mask = self.mask_cache.index_select(2, position_ids)
        freqs_cis = self.freqs_cis.index_select(0, position_ids)

        for (
            layer,
            xattn_layer,
            xattn_layer_idx,
        ) in self.text_and_xattn_layers:
            if not text_only_inference:
                h = xattn_layer(
                    x=h,
                    xattn_mask=xattn_mask,
                    xattn_cache=xattn_caches[xattn_layer_idx],
                    full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                )
            h = layer(
                x=h,
                mask=mask,
                freqs_cis=freqs_cis,
                position_ids=position_ids,
            )

        h = self.norm(h)

        output = F.linear(h, self.output.weight)
        output = gather_from_tensor_model_parallel_region(output)
        return output.float()

    def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
        """Register the causal mask buffer and allocate every layer's KV cache.

        Args:
            max_batch_size: Forwarded to each layer's ``setup_cache``.
            device: Device on which the causal mask is created.
            dtype: Forwarded to each layer's ``setup_cache``.
        """
        # Set up the text kv caches
        ones = torch.ones(
            (self.max_seq_len, self.max_seq_len),
            dtype=torch.bool,
            device=device,
        )
        # Lower-triangular boolean mask with two leading singleton dimensions.
        self.register_buffer(
            "mask_cache",
            torch.tril(
                ones,
            )
            .unsqueeze(0)
            .unsqueeze(0),
            persistent=False,
        )
        for layer in self.layers:
            layer.setup_cache(max_batch_size, dtype=dtype)
        self.cache_is_setup = True

    def _get_xattn_mask(
        self,
        num_tokens,
        text_device,
        text_dtype,
        vision_tokens,
        cross_attention_masks,
    ) -> tuple[Tensor, Tensor]:
        """Expand per-chunk cross-attention masks to per-vision-token masks.

        Args:
            num_tokens: Expected size of dimension 1 of ``cross_attention_masks``.
            text_device: Device both returned masks are moved to.
            text_dtype: Dtype the returned cross-attention mask is cast to.
            vision_tokens: 5-D tensor; dimension 3 is the number of tokens per chunk.
            cross_attention_masks: 4-D tensor unpacked as ``(bsz, ntext, nimg, nchunks)``.

        Returns:
            ``(cross_attention_masks, full_text_row_masked_out_mask)``.

        Raises:
            AssertionError: If the vision-token and mask shapes disagree.
        """
        assert vision_tokens is not None, "Vision tokens must be provided"
        vision_seqlen = vision_tokens.shape[3]
        assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (
            f"Mismatch in number of images given and number of masks given {vision_tokens.shape} {cross_attention_masks.shape}"
        )
        assert vision_tokens.shape[2] == cross_attention_masks.shape[3], (
            f"Vision tokens shape {vision_tokens.shape} mismatch with xattn shape {cross_attention_masks.shape}"
        )
        assert num_tokens == cross_attention_masks.shape[1], (
            f"Mismatch in text sequence length and cross attention mask sequence length {num_tokens} {cross_attention_masks.shape}"
        )
        # Unpacking doubles as a rank check; the unpacked vision sizes are not used further.
        _, _, _, num_image_tokens, image_token_dim = tuple(vision_tokens.shape)
        bsz, ntext, nimg, nchunks = cross_attention_masks.shape
        # Repeat every chunk entry once per vision token, flatten the image and
        # chunk axes together, and insert a singleton axis at position 1.
        cross_attention_masks = (
            cross_attention_masks.repeat_interleave(vision_seqlen, dim=3).view(bsz, ntext, -1).unsqueeze(1)
        )
        # NOTE(review): presumably flags text rows that are masked out for every
        # vision token -- confirm against _get_full_row_masked_out_mask.
        full_text_row_masked_out_mask = _get_full_row_masked_out_mask(
            cross_attention_masks,
            get_negative_inf_value(cross_attention_masks.dtype),
        )
        # In-place multiply on the expanded mask.
        cross_attention_masks *= full_text_row_masked_out_mask

        return (
            cross_attention_masks.to(device=text_device, dtype=text_dtype),
            full_text_row_masked_out_mask.to(device=text_device),
        )
1263
-
1264
-
1265
class CrossAttentionTransformer(torch.nn.Module):
    """Top-level model that pairs the vision model with the cross-attention text model."""

    def __init__(self, args: ModelArgs) -> None:
        """Build the vision model, the text model and the image transform.

        Args:
            args: Model configuration; ``dim``, ``vision_chunk_size`` and
                ``vision_max_num_chunks`` are read here and ``args`` is also
                handed to both sub-models.
        """
        super().__init__()
        self.params = args

        self.model_dim = args.dim
        self.vision_model = CrossAttentionTransformerVision(args)
        self.text_model = CrossAttentionTransformerText(args)
        self.image_res = args.vision_chunk_size
        self.max_num_chunks = args.vision_max_num_chunks
        self.image_transform = partial(
            VariableSizeImageTransform(size=args.vision_chunk_size),
            max_num_chunks=args.vision_max_num_chunks,
        )

    def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
        """Delegate cache setup to the text model."""
        self.text_model.setup_cache(max_batch_size, device, dtype)

    def compute_vision_tokens_masks(
        self,
        batch_images: list[list[PIL_Image.Image]],
        batch_masks: list[list[list[int]]],
        total_len: int,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute cross-attention KV caches and masks for a batch.

        Args:
            batch_images: One list of images per sample (lists may be empty).
            batch_masks: One list of masks per sample, aligned with ``batch_images``.
            total_len: Text sequence length the masks are padded to.
            device: Device on which the vision tokens are placed.

        Returns:
            ``(xattn_caches, cross_attention_masks, full_text_row_masked_out_mask)``.

        Raises:
            AssertionError: If ``batch_images`` and ``batch_masks`` differ in length.
        """
        skip_vision_encoder = False

        assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"

        max_num_images = max(len(x) for x in batch_images)
        bsz = len(batch_images)

        if max_num_images == 0:
            # No sample has an image: skip the encoder entirely.
            num_chunks = [[self.max_num_chunks] for _ in batch_images]
            skip_vision_encoder = True
        else:
            images_and_aspect_ratios = [[self.image_transform(im) for im in row] for row in batch_images]
            transformed_images = [[x[0] for x in row] for row in images_and_aspect_ratios]

            # Samples with fewer images keep the default aspect ratio of ones.
            aspect_ratios = torch.ones(bsz, max_num_images, 2, dtype=torch.int64)
            for i, row in enumerate(images_and_aspect_ratios):
                if len(row) > 0:
                    aspect_ratios[i, : len(row)] = torch.stack([torch.tensor(x[1]) for x in row])

            stacked_images, num_chunks = _stack_images(
                transformed_images,
                max_num_chunks=self.max_num_chunks,
                image_res=self.params.vision_chunk_size,
                max_num_images=max_num_images,
            )
            stacked_images = stacked_images.to(device=device)

        if skip_vision_encoder:
            # Allocate the placeholder on `device`, as the encoder branch does;
            # otherwise it is created on the default device and then fed to
            # the cross-attention projections below.
            vision_tokens = torch.zeros(
                (
                    bsz,
                    max_num_images,
                    self.max_num_chunks,
                    int((self.vision_model.image_res / self.vision_model.patch_size) ** 2 + 1),
                    self.model_dim,
                ),
                device=device,
            )
        else:
            vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)

        bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
        # One KV cache per cross-attention layer, computed over all vision
        # tokens of a sample flattened into a single sequence.
        xattn_caches = torch.stack(
            [
                layer.compute_xattn_kv_cache(vision_tokens.view(bsz, -1, image_token_dim))
                for layer in self.text_model.cross_attention_layers
            ]
        )
        padded_masks = _pad_masks(
            batch_masks,
            num_chunks,
            total_len,
            self.max_num_chunks,
        )

        cross_attention_masks, full_text_row_masked_out_mask = self.text_model._get_xattn_mask(
            num_tokens=total_len,
            text_device=vision_tokens.device.type,
            text_dtype=next(self.text_model.parameters()).dtype,
            vision_tokens=vision_tokens,
            cross_attention_masks=padded_masks,
        )

        return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask)

    def forward(
        self,
        position_ids: torch.Tensor,
        tokens: torch.Tensor,
        cross_attention_masks: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        xattn_caches: torch.Tensor,
        text_only_inference: bool = False,
    ) -> torch.Tensor:
        """Embed the tokens at ``position_ids`` and run the text model.

        Both masks are sliced along dimension 2 with ``position_ids`` before
        being handed to the text model.

        Returns:
            The logits produced by the text model.
        """
        h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids])
        logits = self.text_model.forward(
            position_ids=position_ids,
            h=h,
            xattn_mask=cross_attention_masks[:, :, position_ids],
            full_text_row_masked_out_mask=full_text_row_masked_out_mask[:, :, position_ids],
            xattn_caches=xattn_caches,
            text_only_inference=text_only_inference,
        )
        return logits
1373
-
1374
-
1375
- def _stack_images(
1376
- images: list[list[PIL_Image.Image]],
1377
- max_num_chunks: int,
1378
- image_res: int,
1379
- max_num_images: int,
1380
- ) -> tuple[torch.Tensor, list[int]]:
1381
- """
1382
- Takes a list of list of images and stacks them into a tensor.
1383
- This function is needed since images can be of completely
1384
- different resolutions and aspect ratios.
1385
- """
1386
- out_images, out_num_chunks = [], []
1387
- for imgs_sample in images:
1388
- out_images_i = torch.zeros(
1389
- max_num_images,
1390
- max_num_chunks,
1391
- 3,
1392
- image_res,
1393
- image_res,
1394
- )
1395
- _num_chunks = []
1396
- for j, chunks_image in enumerate(imgs_sample):
1397
- out_images_i[j, : chunks_image.shape[0]] = chunks_image
1398
- _num_chunks.append(chunks_image.shape[0])
1399
- out_images.append(out_images_i)
1400
- out_num_chunks.append(_num_chunks)
1401
- return torch.stack(out_images), out_num_chunks
1402
-
1403
-
1404
def _pad_masks(
    all_masks: list[list[list[int]]],
    all_num_chunks: list[list[int]],
    total_len: int,
    max_num_chunks: int,
) -> torch.Tensor:
    """Build a dense additive mask tensor from per-image ``[start, end]`` ranges.

    The result is filled with the negative-infinity value for the default
    dtype and set to ``0.0`` wherever text positions ``start:end`` of a sample
    cover the chunks of an image. The input lists are not modified.

    Args:
        all_masks: For each sample, one ``[start, end]`` pair per image; an
            ``end`` of ``-1`` means "up to ``total_len``". Entries that do not
            have exactly two elements are ignored.
        all_num_chunks: For each sample, the number of chunks of each image.
        total_len: Text sequence length (size of dimension 1 of the result).
        max_num_chunks: Chunk capacity per image (size of dimension 3).

    Returns:
        A tensor of shape ``(batch, total_len, max_num_media, max_num_chunks)``
        in the default dtype.

    Raises:
        ValueError: If ``all_masks`` is empty.
    """
    dtype = torch.get_default_dtype()
    inf_value = get_negative_inf_value(dtype)

    bsz = len(all_masks)
    max_num_media = max(len(m) for m in all_masks)

    out_masks = torch.full(
        (bsz, total_len, max_num_media, max_num_chunks),
        inf_value,
        dtype=dtype,
    )

    for idx, (mask, num_chunks) in enumerate(zip(all_masks, all_num_chunks, strict=False)):
        for mask_idx, (mask_elem, mask_num_chunks) in enumerate(zip(mask, num_chunks, strict=False)):
            if len(mask_elem) != 2:
                continue
            # Work on local copies so the caller's mask lists stay untouched.
            start, end = mask_elem
            end = min(end, total_len)
            if end == -1:
                end = total_len
            out_masks[idx, start:end, mask_idx, :mask_num_chunks].fill_(0.0)

    return out_masks