llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (159)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  57. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  58. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  59. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  60. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  61. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  62. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  63. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  64. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  65. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  66. llama_stack/providers/registry/agents.py +1 -0
  67. llama_stack/providers/registry/inference.py +1 -9
  68. llama_stack/providers/registry/vector_io.py +136 -16
  69. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  70. llama_stack/providers/remote/files/s3/config.py +5 -3
  71. llama_stack/providers/remote/files/s3/files.py +2 -2
  72. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  73. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  74. llama_stack/providers/remote/inference/together/together.py +4 -0
  75. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  76. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  77. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  78. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  79. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  80. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  81. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  82. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  83. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  84. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  85. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  86. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  87. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  88. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  89. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  90. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  91. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  92. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  93. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  94. llama_stack/providers/utils/bedrock/client.py +3 -3
  95. llama_stack/providers/utils/bedrock/config.py +7 -7
  96. llama_stack/providers/utils/inference/__init__.py +0 -25
  97. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  98. llama_stack/providers/utils/inference/http_client.py +239 -0
  99. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  100. llama_stack/providers/utils/inference/model_registry.py +148 -2
  101. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  102. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  103. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  104. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  105. llama_stack/providers/utils/memory/vector_store.py +46 -19
  106. llama_stack/providers/utils/responses/responses_store.py +7 -7
  107. llama_stack/providers/utils/safety.py +114 -0
  108. llama_stack/providers/utils/tools/mcp.py +44 -3
  109. llama_stack/testing/api_recorder.py +9 -3
  110. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  111. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
  112. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  113. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  114. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  115. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  116. llama_stack/models/llama/hadamard_utils.py +0 -88
  117. llama_stack/models/llama/llama3/args.py +0 -74
  118. llama_stack/models/llama/llama3/dog.jpg +0 -0
  119. llama_stack/models/llama/llama3/generation.py +0 -378
  120. llama_stack/models/llama/llama3/model.py +0 -304
  121. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  122. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  123. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  124. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  125. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  126. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  127. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  128. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  129. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  130. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  131. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  132. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  133. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  134. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  135. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  136. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  137. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  138. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  139. llama_stack/models/llama/llama4/args.py +0 -107
  140. llama_stack/models/llama/llama4/ffn.py +0 -58
  141. llama_stack/models/llama/llama4/moe.py +0 -214
  142. llama_stack/models/llama/llama4/preprocess.py +0 -435
  143. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  144. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  145. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  146. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  147. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  148. llama_stack/models/llama/quantize_impls.py +0 -316
  149. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  150. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  151. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  152. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  153. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  154. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  155. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  156. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  157. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  158. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
  159. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
llama_stack/models/llama/llama3/model.py (deleted)
@@ -1,304 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import math
-
- import fairscale.nn.model_parallel.initialize as fs_init
- import torch
- import torch.nn.functional as F
- from fairscale.nn.model_parallel.layers import (
-     ColumnParallelLinear,
-     RowParallelLinear,
-     VocabParallelEmbedding,
- )
- from torch import nn
-
- from .args import ModelArgs
-
- # **NOTE**: This code is not runnable without installing `torch` and `fairscale`
- # dependencies. These dependencies are not part of the default dependencies
- # (requirements.txt) of the `llama-models` package.
-
-
- class RMSNorm(torch.nn.Module):
-     def __init__(self, dim: int, eps: float = 1e-6):
-         super().__init__()
-         self.eps = eps
-         self.weight = nn.Parameter(torch.ones(dim))
-
-     def _norm(self, x):
-         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-     def forward(self, x):
-         output = self._norm(x.float()).type_as(x)
-         return output * self.weight
-
-
- def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
-     # Values obtained from grid search
-     scale_factor = 8
-     low_freq_factor = 1
-     high_freq_factor = 4
-     old_context_len = 8192  # original llama3 length
-
-     low_freq_wavelen = old_context_len / low_freq_factor
-     high_freq_wavelen = old_context_len / high_freq_factor
-
-     wavelen = 2 * torch.pi / freqs
-     new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
-     smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
-     return torch.where(
-         (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
-         (1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
-         new_freqs,
-     )
-
-
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
-     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
-     if use_scaled:
-         freqs = apply_scaling(freqs)
-     freqs = torch.outer(t, freqs)
-     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-     return freqs_cis
-
-
- def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-     ndim = x.ndim
-     assert 0 <= 1 < ndim
-     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-     return freqs_cis.view(*shape)
-
-
- def apply_rotary_emb(
-     xq: torch.Tensor,
-     xk: torch.Tensor,
-     freqs_cis: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
-     return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
- def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
-     bs, slen, n_kv_heads, head_dim = x.shape
-     if n_rep == 1:
-         return x
-     return (
-         x[:, :, :, None, :]
-         .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-     )
-
-
- class Attention(nn.Module):
-     def __init__(self, args: ModelArgs):
-         super().__init__()
-         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-         world_size = fs_init.get_model_parallel_world_size()
-         self.n_local_heads = args.n_heads // world_size
-         self.n_local_kv_heads = self.n_kv_heads // world_size
-         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-         self.head_dim = args.dim // args.n_heads
-
-         self.wq = ColumnParallelLinear(
-             args.dim,
-             args.n_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wk = ColumnParallelLinear(
-             args.dim,
-             self.n_kv_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wv = ColumnParallelLinear(
-             args.dim,
-             self.n_kv_heads * self.head_dim,
-             bias=False,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.wo = RowParallelLinear(
-             args.n_heads * self.head_dim,
-             args.dim,
-             bias=False,
-             input_is_parallel=True,
-             init_method=lambda x: x,
-         )
-
-         self.cache_k = torch.zeros(
-             (
-                 args.max_batch_size,
-                 args.max_seq_len,
-                 self.n_local_kv_heads,
-                 self.head_dim,
-             )
-         )
-         self.cache_v = torch.zeros(
-             (
-                 args.max_batch_size,
-                 args.max_seq_len,
-                 self.n_local_kv_heads,
-                 self.head_dim,
-             )
-         )
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         start_pos: int,
-         freqs_cis: torch.Tensor,
-         mask: torch.Tensor | None,
-     ):
-         bsz, seqlen, _ = x.shape
-         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
-
-         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-
-         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
-
-         self.cache_k = self.cache_k.to(xq)
-         self.cache_v = self.cache_v.to(xq)
-
-         self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-         self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
-
-         keys = self.cache_k[:bsz, : start_pos + seqlen]
-         values = self.cache_v[:bsz, : start_pos + seqlen]
-
-         # repeat k/v heads if n_kv_heads < n_heads
-         keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
-         values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
-
-         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-         keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
-         values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
-         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-         if mask is not None:
-             scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
-         scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-         output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
-         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-         return self.wo(output)
-
-
- class FeedForward(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         hidden_dim: int,
-         multiple_of: int,
-         ffn_dim_multiplier: float | None,
-     ):
-         super().__init__()
-         hidden_dim = int(2 * hidden_dim / 3)
-         # custom dim factor multiplier
-         if ffn_dim_multiplier is not None:
-             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
-         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-         self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
-         self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
-         self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
-
-     def forward(self, x):
-         return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
- class TransformerBlock(nn.Module):
-     def __init__(self, layer_id: int, args: ModelArgs):
-         super().__init__()
-         self.n_heads = args.n_heads
-         self.dim = args.dim
-         self.head_dim = args.dim // args.n_heads
-         self.attention = Attention(args)
-         self.feed_forward = FeedForward(
-             dim=args.dim,
-             hidden_dim=4 * args.dim,
-             multiple_of=args.multiple_of,
-             ffn_dim_multiplier=args.ffn_dim_multiplier,
-         )
-         self.layer_id = layer_id
-         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         start_pos: int,
-         freqs_cis: torch.Tensor,
-         mask: torch.Tensor | None,
-     ):
-         h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
-         out = h + self.feed_forward(self.ffn_norm(h))
-         return out
-
-
- class Transformer(nn.Module):
-     def __init__(self, params: ModelArgs):
-         super().__init__()
-         self.params = params
-         self.vocab_size = params.vocab_size
-         self.n_layers = params.n_layers
-
-         self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
-
-         self.layers = torch.nn.ModuleList()
-         for layer_id in range(params.n_layers):
-             self.layers.append(TransformerBlock(layer_id, params))
-
-         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-         self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
-
-         self.freqs_cis = precompute_freqs_cis(
-             params.dim // params.n_heads,
-             params.max_seq_len * 2,
-             params.rope_theta,
-             params.use_scaled_rope,
-         )
-
-     @torch.inference_mode()
-     def forward(self, tokens: torch.Tensor, start_pos: int):
-         _bsz, seqlen = tokens.shape
-         h = self.tok_embeddings(tokens)
-         self.freqs_cis = self.freqs_cis.to(h.device)
-         freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
-
-         mask = None
-         if seqlen > 1:
-             mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
-
-             mask = torch.triu(mask, diagonal=1)
-
-             # https://github.com/pytorch/pytorch/issues/100005
-             # torch.triu is buggy when the device is mps: filled values are
-             # nan instead of 0.
-             if mask.device.type == torch.device("mps").type:
-                 mask = torch.nan_to_num(mask, nan=0.0)
-
-             # When performing key-value caching, we compute the attention scores
-             # only for the new sequence. Thus, the matrix of scores is of size
-             # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
-             # j > cache_len + i, since row i corresponds to token cache_len + i.
-             mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
-
-         for layer in self.layers:
-             h = layer(h, start_pos, freqs_cis, mask)
-         h = self.norm(h)
-         output = self.output(h).float()
-         return output
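
The rotary-embedding helpers removed above can be exercised on their own with plain PyTorch, without fairscale or the rest of the deleted module. Below is a minimal sketch, not part of the package: the function bodies are adapted from the deleted model.py (reshape_for_broadcast is inlined as a single view), and the driver shapes at the bottom are made up for illustration.

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # complex64 table of shape (end, dim // 2), as built in the deleted model.py
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # view the last dim as complex pairs, rotate, then view back as real
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])  # reshape_for_broadcast, inlined
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


# Made-up shapes: (batch, seqlen, heads, head_dim); head_dim must be even.
bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)
freqs_cis = precompute_freqs_cis(head_dim, seqlen)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape

In the deleted Transformer, the table is precomputed for max_seq_len * 2 positions and sliced to freqs_cis[start_pos : start_pos + seqlen] before being passed to each Attention block, so apply_rotary_emb only ever sees the rows for the positions currently being decoded.
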
llama_stack/models/llama/llama3/multimodal/__init__.py (deleted)
@@ -1,12 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
llama_stack/models/llama/llama3/multimodal/encoder_utils.py (deleted)
@@ -1,180 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # top-level folder for each specific model found within the models/ directory at
- # the top-level of this source tree.
-
- # Copyright (c) Meta Platforms, Inc. and its affiliates.
- import math
-
- import torch
- import torch.nn.functional as F
-
- from llama_stack.log import get_logger
-
- from .utils import get_negative_inf_value, to_2tuple
-
- logger = get_logger(name=__name__, category="models::llama")
-
-
- def resize_local_position_embedding(orig_pos_embed, grid_size):
-     """
-     Resize position embedding for vision encoder.
-     Original position embedding is [n_tiles * n_tiles + 1, dim]
-     New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
-     """
-     new_grid_size = to_2tuple(grid_size)
-     orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
-
-     new_pos_emb_tok, new_pos_emb_img = (
-         orig_pos_embed[:1],
-         orig_pos_embed[1:],
-     )
-     logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
-
-     new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
-
-     new_pos_emb_img = F.interpolate(
-         new_pos_emb_img,
-         size=new_grid_size,
-         mode="bilinear",
-         align_corners=True,
-     )
-     new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
-     new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
-     return new_pos_embed
-
-
- def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
-     """
-     Takes a local position embedding for vision encoder and uses it
-     to initialize the global position embedding.
-     Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
-     Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
-     Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
-     """
-     pos_embed = pos_and_cls_embed[1:]
-     cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
-     grid_size = to_2tuple(grid_size)
-     new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
-     new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
-     new_pos_emb_img = F.interpolate(
-         new_pos_emb_img,
-         size=new_grid_size,
-         mode="bilinear",
-         align_corners=True,
-     )
-     new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
-     new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
-     new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
-     new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
-     cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
-     pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
-     return pos_and_cls_embed
-
-
- def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
-     """
-     Takes a global position embedding for vision encoder and resizes it to new size.
-     Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
-     Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
-     Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
-     """
-     # first remove cls token
-     pos_embed = pos_and_cls_embed[:, :, 1:]
-     cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
-
-     xs_old, ys_old, ntok, dim = pos_embed.shape
-     old_grid_size = int(math.sqrt(ntok))
-
-     # move to correct form for interpolation
-     pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
-     pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
-     pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
-     pos_embed = pos_embed.unsqueeze(0)
-
-     # interpolate
-     new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
-     pos_embed = pos_embed.permute(0, 3, 1, 2)
-     pos_embed_resized = F.interpolate(
-         pos_embed,
-         size=new_size,
-         mode="bilinear",
-         align_corners=True,
-     )
-     pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
-
-     # move it back in place
-     pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
-     pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
-     pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
-
-     # interpolate cls token
-     cls_embed = cls_embed.permute(2, 3, 0, 1)
-     cls_embed_resized = F.interpolate(
-         cls_embed,
-         size=(x_scale, y_scale),
-         mode="bilinear",
-         align_corners=True,
-     )
-     cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
-     # add cls token back in
-     pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
-
-     return pos_and_cls_embed
-
-
- def build_encoder_attention_mask(
-     x: torch.Tensor,
-     ar: torch.Tensor,
-     ntok: int,
-     num_chunks: int,
-     n_heads: int,
- ):
-     """
-     Build vision encoder attention mask that omits padding tokens.
-     """
-     masks_list: list[torch.Tensor] = []
-     for arx in ar:
-         mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
-         mask_i[: arx[0] * arx[1], :ntok] = 0
-         mask_i = mask_i.view(num_chunks * x.shape[2], -1)
-         mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
-         mask_i = mask_i.unsqueeze(0)
-         masks_list.append(mask_i)
-     masks = torch.stack(masks_list).to(x.device).expand(-1, n_heads, -1, -1)
-     return masks
-
-
- def expand_num_tokens_to_mult8(x):
-     num_pad_tokens = 8 - (x.shape[-2] % 8)
-     if num_pad_tokens == 0:
-         return x, 0
-     else:
-         return (
-             torch.cat(
-                 [
-                     x,
-                     torch.zeros(
-                         (x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
-                         dtype=x.dtype,
-                         device=x.device,
-                     ),
-                 ],
-                 dim=-2,
-             ),
-             num_pad_tokens,
-         )
-
-
- def contract_num_tokens_from_mult8(x, num_pad_tokens):
-     if num_pad_tokens == 0:
-         return x
-     return x[:, :, :-num_pad_tokens]
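
The pad/unpad pair at the end of the removed encoder_utils.py can also be exercised standalone: the vision encoder pads the token axis (dim -2) with zeros up to a multiple of 8 and strips the padding afterwards. A minimal sketch follows, with the function bodies lightly restructured from the deleted file and made-up tensor shapes.

import torch


def expand_num_tokens_to_mult8(x: torch.Tensor):
    # pad the token axis (dim -2) with zeros up to the next multiple of 8
    num_pad_tokens = 8 - (x.shape[-2] % 8)
    if num_pad_tokens == 0:
        return x, 0
    pad = torch.zeros(
        (x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
        dtype=x.dtype,
        device=x.device,
    )
    return torch.cat([x, pad], dim=-2), num_pad_tokens


def contract_num_tokens_from_mult8(x: torch.Tensor, num_pad_tokens: int):
    # strip the padding added above
    if num_pad_tokens == 0:
        return x
    return x[:, :, :-num_pad_tokens]


x = torch.randn(1, 4, 1025, 1280)  # made-up (batch, chunks, tokens, dim)
padded, n_pad = expand_num_tokens_to_mult8(x)
assert padded.shape[-2] % 8 == 0   # 1025 tokens are padded to 1032
restored = contract_num_tokens_from_mult8(padded, n_pad)
assert restored.shape == x.shape

One observation about the deleted code as written: since num_pad_tokens is computed as 8 - (x.shape[-2] % 8), it is never 0, so an input whose token count is already a multiple of 8 still receives a full extra block of 8 padding tokens, which contract_num_tokens_from_mult8 then removes.
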