llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (155)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  57. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  58. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  59. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  60. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  61. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  62. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  63. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  64. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  65. llama_stack/providers/registry/agents.py +1 -0
  66. llama_stack/providers/registry/inference.py +1 -9
  67. llama_stack/providers/registry/vector_io.py +136 -16
  68. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  69. llama_stack/providers/remote/files/s3/config.py +5 -3
  70. llama_stack/providers/remote/files/s3/files.py +2 -2
  71. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  72. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  73. llama_stack/providers/remote/inference/together/together.py +4 -0
  74. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  75. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  76. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  77. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  78. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  79. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  80. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  81. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  82. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  83. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  84. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  85. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  86. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  87. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  88. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  89. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  90. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  91. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  92. llama_stack/providers/utils/bedrock/client.py +3 -3
  93. llama_stack/providers/utils/bedrock/config.py +7 -7
  94. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  95. llama_stack/providers/utils/inference/http_client.py +239 -0
  96. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  97. llama_stack/providers/utils/inference/model_registry.py +148 -2
  98. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  99. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  100. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  101. llama_stack/providers/utils/memory/vector_store.py +46 -19
  102. llama_stack/providers/utils/responses/responses_store.py +7 -7
  103. llama_stack/providers/utils/safety.py +114 -0
  104. llama_stack/providers/utils/tools/mcp.py +44 -3
  105. llama_stack/testing/api_recorder.py +9 -3
  106. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  107. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
  108. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  109. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  110. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  111. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  112. llama_stack/models/llama/hadamard_utils.py +0 -88
  113. llama_stack/models/llama/llama3/args.py +0 -74
  114. llama_stack/models/llama/llama3/dog.jpg +0 -0
  115. llama_stack/models/llama/llama3/generation.py +0 -378
  116. llama_stack/models/llama/llama3/model.py +0 -304
  117. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  118. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  119. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  120. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  121. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  122. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  123. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  124. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  125. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  126. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  127. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  128. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  129. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  130. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  131. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  132. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  133. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  134. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  135. llama_stack/models/llama/llama4/args.py +0 -107
  136. llama_stack/models/llama/llama4/ffn.py +0 -58
  137. llama_stack/models/llama/llama4/moe.py +0 -214
  138. llama_stack/models/llama/llama4/preprocess.py +0 -435
  139. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  140. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  141. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  142. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  143. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  144. llama_stack/models/llama/quantize_impls.py +0 -316
  145. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  146. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  147. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  148. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  149. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  150. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  151. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  152. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  153. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  154. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  155. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
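
The listing above comes from comparing the contents of the two wheels. As a rough illustration only (not part of either package; the local wheel paths are assumptions), the same file-level comparison can be reproduced with the Python standard library, since a wheel is a zip archive:

import zipfile

OLD = "llama_stack-0.4.4-py3-none-any.whl"     # hypothetical local path
NEW = "llama_stack-0.5.0rc1-py3-none-any.whl"  # hypothetical local path

with zipfile.ZipFile(OLD) as old_whl, zipfile.ZipFile(NEW) as new_whl:
    old_files = set(old_whl.namelist())
    new_files = set(new_whl.namelist())

    # files present only in one wheel (note: the *.dist-info directory is
    # renamed between versions, so its members show up as removed/added)
    for path in sorted(old_files - new_files):
        print(f"removed: {path}")
    for path in sorted(new_files - old_files):
        print(f"added:   {path}")

    # files present in both wheels whose bytes differ
    for path in sorted(old_files & new_files):
        if old_whl.read(path) != new_whl.read(path):
            print(f"changed: {path}")
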
@@ -1,304 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import math
-
- import fairscale.nn.model_parallel.initialize as fs_init
- import torch
- import torch.nn.functional as F
- from fairscale.nn.model_parallel.layers import (
-     ColumnParallelLinear,
-     RowParallelLinear,
-     VocabParallelEmbedding,
- )
- from torch import nn
-
- from .args import ModelArgs
-
- # **NOTE**: This code is not runnable without installing `torch` and `fairscale`
- # dependencies. These dependencies are not part of the default dependencies
- # (requirements.txt) of the `llama-models` package.
-
-
- class RMSNorm(torch.nn.Module):
-     def __init__(self, dim: int, eps: float = 1e-6):
-         super().__init__()
-         self.eps = eps
-         self.weight = nn.Parameter(torch.ones(dim))
-
-     def _norm(self, x):
-         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-     def forward(self, x):
-         output = self._norm(x.float()).type_as(x)
-         return output * self.weight
-
-
- def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
-     # Values obtained from grid search
-     scale_factor = 8
-     low_freq_factor = 1
-     high_freq_factor = 4
-     old_context_len = 8192  # original llama3 length
-
-     low_freq_wavelen = old_context_len / low_freq_factor
-     high_freq_wavelen = old_context_len / high_freq_factor
-
-     wavelen = 2 * torch.pi / freqs
-     new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
-     smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-     return torch.where(
-         (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
-         (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
-         new_freqs,
-     )
-
-
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
-     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
-     if use_scaled:
-         freqs = apply_scaling(freqs)
-     freqs = torch.outer(t, freqs)
-     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-     return freqs_cis
-
-
- def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-     ndim = x.ndim
-     assert 0 <= 1 < ndim
-     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-     return freqs_cis.view(*shape)
-
-
- def apply_rotary_emb(
-     xq: torch.Tensor,
-     xk: torch.Tensor,
-     freqs_cis: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
-     return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
- def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
-     bs, slen, n_kv_heads, head_dim = x.shape
-     if n_rep == 1:
-         return x
-     return (
-         x[:, :, :, None, :]
-         .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-     )
-
-
- class Attention(nn.Module):
-     def __init__(self, args: ModelArgs):
-         super().__init__()
-         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-         world_size = fs_init.get_model_parallel_world_size()
-         self.n_local_heads = args.n_heads // world_size
-         self.n_local_kv_heads = self.n_kv_heads // world_size
-         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-         self.head_dim = args.dim // args.n_heads
-
-         self.wq = ColumnParallelLinear(
-             args.dim,
-             args.n_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wk = ColumnParallelLinear(
-             args.dim,
-             self.n_kv_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wv = ColumnParallelLinear(
-             args.dim,
-             self.n_kv_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wo = RowParallelLinear(
-             args.n_heads * self.head_dim,
-             args.dim,
-             bias=False,
-             input_is_parallel=True,
-             init_method=lambda x: x,
-         )
-
-         self.cache_k = torch.zeros(
-             (
-                 args.max_batch_size,
-                 args.max_seq_len,
-                 self.n_local_kv_heads,
-                 self.head_dim,
-             )
-         )
-         self.cache_v = torch.zeros(
-             (
-                 args.max_batch_size,
-                 args.max_seq_len,
-                 self.n_local_kv_heads,
-                 self.head_dim,
-             )
-         )
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         start_pos: int,
-         freqs_cis: torch.Tensor,
-         mask: torch.Tensor | None,
-     ):
-         bsz, seqlen, _ = x.shape
-         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-
-         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-
-         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-
-         self.cache_k = self.cache_k.to(xq)
-         self.cache_v = self.cache_v.to(xq)
-
-         self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-         self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
-
-         keys = self.cache_k[:bsz, : start_pos + seqlen]
-         values = self.cache_v[:bsz, : start_pos + seqlen]
-
-         # repeat k/v heads if n_kv_heads < n_heads
-         keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
-         values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
-
-         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-         keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
-         values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
-         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-         if mask is not None:
-             scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
-         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-         output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
-         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-         return self.wo(output)
-
-
- class FeedForward(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         hidden_dim: int,
-         multiple_of: int,
-         ffn_dim_multiplier: float | None,
-     ):
-         super().__init__()
-         hidden_dim = int(2 * hidden_dim / 3)
-         # custom dim factor multiplier
-         if ffn_dim_multiplier is not None:
-             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-         self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
-         self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
-         self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
-
-     def forward(self, x):
-         return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
- class TransformerBlock(nn.Module):
-     def __init__(self, layer_id: int, args: ModelArgs):
-         super().__init__()
-         self.n_heads = args.n_heads
-         self.dim = args.dim
-         self.head_dim = args.dim // args.n_heads
-         self.attention = Attention(args)
-         self.feed_forward = FeedForward(
-             dim=args.dim,
-             hidden_dim=4 * args.dim,
-             multiple_of=args.multiple_of,
-             ffn_dim_multiplier=args.ffn_dim_multiplier,
-         )
-         self.layer_id = layer_id
-         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         start_pos: int,
-         freqs_cis: torch.Tensor,
-         mask: torch.Tensor | None,
-     ):
-         h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-         out = h + self.feed_forward(self.ffn_norm(h))
-         return out
-
-
- class Transformer(nn.Module):
-     def __init__(self, params: ModelArgs):
-         super().__init__()
-         self.params = params
-         self.vocab_size = params.vocab_size
-         self.n_layers = params.n_layers
-
-         self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
-
-         self.layers = torch.nn.ModuleList()
-         for layer_id in range(params.n_layers):
-             self.layers.append(TransformerBlock(layer_id, params))
-
-         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-         self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
-
-         self.freqs_cis = precompute_freqs_cis(
-             params.dim // params.n_heads,
-             params.max_seq_len * 2,
-             params.rope_theta,
-             params.use_scaled_rope,
-         )
-
-     @torch.inference_mode()
-     def forward(self, tokens: torch.Tensor, start_pos: int):
-         _bsz, seqlen = tokens.shape
-         h = self.tok_embeddings(tokens)
-         self.freqs_cis = self.freqs_cis.to(h.device)
-         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
-
-         mask = None
-         if seqlen > 1:
-             mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
-
-             mask = torch.triu(mask, diagonal=1)
-
-             # https://github.com/pytorch/pytorch/issues/100005
-             # torch.triu is buggy when the device is mps: filled values are
-             # nan instead of 0.
-             if mask.device.type == torch.device("mps").type:
-                 mask = torch.nan_to_num(mask, nan=0.0)
-
-             # When performing key-value caching, we compute the attention scores
-             # only for the new sequence. Thus, the matrix of scores is of size
-             # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
-             # j > cache_len + i, since row i corresponds to token cache_len + i.
-             mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
-
-         for layer in self.layers:
-             h = layer(h, start_pos, freqs_cis, mask)
-         h = self.norm(h)
-         output = self.output(h).float()
-         return output
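
The hunk above removes the reference Llama 3 model definition (RMSNorm, rotary position embeddings, grouped-query attention with a KV cache, SwiGLU feed-forward). As an editor's illustration only, and not code from either wheel, the sketch below exercises two of the removed helpers with plain torch (no fairscale) on made-up toy shapes, and checks the equivalence stated in the repeat_kv docstring:

import torch

def precompute_freqs_cis(dim, end, theta=10000.0):
    # complex rotation factors, one per (position, channel pair)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    return torch.polar(torch.ones((end, dim // 2)), torch.outer(t, freqs))

def apply_rotary_emb(xq, xk, freqs_cis):
    # rotate channel pairs by multiplying them as complex numbers
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])
    return (
        torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq),
        torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk),
    )

bs, seqlen, n_heads, n_kv_heads, head_dim = 2, 5, 8, 2, 16  # toy values
xq = torch.randn(bs, seqlen, n_heads, head_dim)
xk = torch.randn(bs, seqlen, n_kv_heads, head_dim)
xq, xk = apply_rotary_emb(xq, xk, precompute_freqs_cis(head_dim, seqlen))

# the repeat_kv docstring claims equivalence with torch.repeat_interleave(x, dim=2, ...)
n_rep = n_heads // n_kv_heads
expanded = (
    xk[:, :, :, None, :]
    .expand(bs, seqlen, n_kv_heads, n_rep, head_dim)
    .reshape(bs, seqlen, n_kv_heads * n_rep, head_dim)
)
assert torch.equal(expanded, torch.repeat_interleave(xk, repeats=n_rep, dim=2))
print(xq.shape, expanded.shape)  # torch.Size([2, 5, 8, 16]) torch.Size([2, 5, 8, 16])
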
@@ -1,12 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
@@ -1,180 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and its affiliates.
- import math
-
- import torch
- import torch.nn.functional as F
-
- from llama_stack.log import get_logger
-
- from .utils import get_negative_inf_value, to_2tuple
-
- logger = get_logger(name=__name__, category="models::llama")
-
-
- def resize_local_position_embedding(orig_pos_embed, grid_size):
-     """
-     Resize position embedding for vision encoder.
-     Original position embedding is [n_tiles * n_tiles + 1, dim]
-     New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
-     """
-     new_grid_size = to_2tuple(grid_size)
-     orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
-
-     new_pos_emb_tok, new_pos_emb_img = (
-         orig_pos_embed[:1],
-         orig_pos_embed[1:],
-     )
-     logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
-
-     new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
-
-     new_pos_emb_img = F.interpolate(
-         new_pos_emb_img,
-         size=new_grid_size,
-         mode="bilinear",
-         align_corners=True,
-     )
-     new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
-     new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
-     return new_pos_embed
-
-
- def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
-     """
-     Takes a local position embedding for vision encoder and uses it
-     to initialize the global position embedding.
-     Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
-     Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
-     Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
-     """
-     pos_embed = pos_and_cls_embed[1:]
-     cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
-     grid_size = to_2tuple(grid_size)
-     new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
-     new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
-     new_pos_emb_img = F.interpolate(
-         new_pos_emb_img,
-         size=new_grid_size,
-         mode="bilinear",
-         align_corners=True,
-     )
-     new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
-     new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
-     new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
-     new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
-     cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
-     pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
-     return pos_and_cls_embed
-
-
- def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
-     """
-     Takes a global position embedding for vision encoder and resizes it to new size.
-     Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
-     Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
-     Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
-     """
-     # first remove cls token
-     pos_embed = pos_and_cls_embed[:, :, 1:]
-     cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
-
-     xs_old, ys_old, ntok, dim = pos_embed.shape
-     old_grid_size = int(math.sqrt(ntok))
-
-     # move to correct form for interpolation
-     pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
-     pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
-     pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
-     pos_embed = pos_embed.unsqueeze(0)
-
-     # interpolate
-     new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
-     pos_embed = pos_embed.permute(0, 3, 1, 2)
-     pos_embed_resized = F.interpolate(
-         pos_embed,
-         size=new_size,
-         mode="bilinear",
-         align_corners=True,
-     )
-     pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
-
-     # move it back in place
-     pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
-     pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
-     pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
-
-     # interpolate cls token
-     cls_embed = cls_embed.permute(2, 3, 0, 1)
-     cls_embed_resized = F.interpolate(
-         cls_embed,
-         size=(x_scale, y_scale),
-         mode="bilinear",
-         align_corners=True,
-     )
-     cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
-     # add cls token back in
-     pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
-
-     return pos_and_cls_embed
-
-
- def build_encoder_attention_mask(
-     x: torch.Tensor,
-     ar: torch.Tensor,
-     ntok: int,
-     num_chunks: int,
-     n_heads: int,
- ):
-     """
-     Build vision encoder attention mask that omits padding tokens.
-     """
-     masks_list: list[torch.Tensor] = []
-     for arx in ar:
-         mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
-         mask_i[: arx[0] * arx[1], :ntok] = 0
-         mask_i = mask_i.view(num_chunks * x.shape[2], -1)
-         mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
-         mask_i = mask_i.unsqueeze(0)
-         masks_list.append(mask_i)
-     masks = torch.stack(masks_list).to(x.device).expand(-1, n_heads, -1, -1)
-     return masks
-
-
- def expand_num_tokens_to_mult8(x):
-     num_pad_tokens = 8 - (x.shape[-2] % 8)
-     if num_pad_tokens == 0:
-         return x, 0
-     else:
-         return (
-             torch.cat(
-                 [
-                     x,
-                     torch.zeros(
-                         (x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
-                         dtype=x.dtype,
-                         device=x.device,
-                     ),
-                 ],
-                 dim=-2,
-             ),
-             num_pad_tokens,
-         )
-
-
- def contract_num_tokens_from_mult8(x, num_pad_tokens):
-     if num_pad_tokens == 0:
-         return x
-     return x[:, :, :-num_pad_tokens]
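
The hunk above removes the vision-encoder utilities, whose docstrings describe resizing a position embedding by bilinearly interpolating its grid portion while carrying the CLS slot over unchanged. A toy run of that idea, as an editor's illustration with made-up sizes rather than code from either wheel:

import math
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed: torch.Tensor, new_grid: tuple[int, int]) -> torch.Tensor:
    # pos_embed: [old_grid * old_grid + 1, dim], with the CLS embedding first
    old_grid = int(math.sqrt(pos_embed.shape[0] - 1))
    cls_tok, grid_tok = pos_embed[:1], pos_embed[1:]
    # reshape to an image-like [1, dim, H, W] layout, interpolate, flatten back
    grid_tok = grid_tok.reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
    grid_tok = F.interpolate(grid_tok, size=new_grid, mode="bilinear", align_corners=True)
    grid_tok = grid_tok.permute(0, 2, 3, 1).reshape(new_grid[0] * new_grid[1], -1)
    return torch.cat([cls_tok, grid_tok], dim=0)

old = torch.randn(16 * 16 + 1, 1280)  # 16x16 grid plus CLS, dim 1280 (made-up sizes)
new = resize_pos_embed(old, (24, 24))
print(new.shape)  # torch.Size([577, 1280])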