llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (155)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  57. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  58. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  59. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  60. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  61. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  62. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  63. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  64. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  65. llama_stack/providers/registry/agents.py +1 -0
  66. llama_stack/providers/registry/inference.py +1 -9
  67. llama_stack/providers/registry/vector_io.py +136 -16
  68. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  69. llama_stack/providers/remote/files/s3/config.py +5 -3
  70. llama_stack/providers/remote/files/s3/files.py +2 -2
  71. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  72. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  73. llama_stack/providers/remote/inference/together/together.py +4 -0
  74. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  75. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  76. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  77. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  78. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  79. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  80. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  81. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  82. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  83. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  84. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  85. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  86. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  87. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  88. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  89. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  90. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  91. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  92. llama_stack/providers/utils/bedrock/client.py +3 -3
  93. llama_stack/providers/utils/bedrock/config.py +7 -7
  94. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  95. llama_stack/providers/utils/inference/http_client.py +239 -0
  96. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  97. llama_stack/providers/utils/inference/model_registry.py +148 -2
  98. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  99. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  100. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  101. llama_stack/providers/utils/memory/vector_store.py +46 -19
  102. llama_stack/providers/utils/responses/responses_store.py +7 -7
  103. llama_stack/providers/utils/safety.py +114 -0
  104. llama_stack/providers/utils/tools/mcp.py +44 -3
  105. llama_stack/testing/api_recorder.py +9 -3
  106. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  107. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
  108. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  109. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  110. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  111. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  112. llama_stack/models/llama/hadamard_utils.py +0 -88
  113. llama_stack/models/llama/llama3/args.py +0 -74
  114. llama_stack/models/llama/llama3/dog.jpg +0 -0
  115. llama_stack/models/llama/llama3/generation.py +0 -378
  116. llama_stack/models/llama/llama3/model.py +0 -304
  117. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  118. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  119. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  120. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  121. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  122. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  123. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  124. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  125. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  126. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  127. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  128. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  129. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  130. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  131. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  132. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  133. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  134. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  135. llama_stack/models/llama/llama4/args.py +0 -107
  136. llama_stack/models/llama/llama4/ffn.py +0 -58
  137. llama_stack/models/llama/llama4/moe.py +0 -214
  138. llama_stack/models/llama/llama4/preprocess.py +0 -435
  139. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  140. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  141. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  142. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  143. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  144. llama_stack/models/llama/quantize_impls.py +0 -316
  145. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  146. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  147. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  148. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  149. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  150. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  151. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  152. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  153. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  154. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  155. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
llama_stack/models/llama/quantize_impls.py +0 -316
@@ -1,316 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- # type: ignore
- import collections
-
- from llama_stack.log import get_logger
-
- log = get_logger(name=__name__, category="models::llama")
-
- try:
-     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
-
-     log.info("Using efficient FP8 or INT4 operators in FBGEMM.")
- except ImportError:
-     log.error("No efficient FP8 or INT4 operators. Please install FBGEMM.")
-     raise
-
- import torch
- from torch import Tensor, nn
-
-
- class Fp8ScaledWeights:
-     # TODO: Ugly trick so torch allows us to replace parameters
-     # with our custom Fp8Weights instance. Do this properly.
-     @property
-     def __class__(self) -> type[nn.parameter.Parameter]:
-         return nn.Parameter
-
-     @property
-     def grad_fn(self) -> None:
-         return None
-
-
- # pyre-fixme[4]: Attribute annotation cannot be `Any`.
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
- class Fp8RowwiseWeights(
-     Fp8ScaledWeights,
-     collections.namedtuple(
-         "Fp8RowwiseWeights",
-         ["weight", "scale", "shape", "activation_scale_ub"],
-     ),
- ):
-     pass
-
-
- class Int4ScaledWeights:
-     # TODO: Ugly trick so torch allows us to replace parameters
-     # with our custom Int4Weights instance. Do this properly.
-     @property
-     def __class__(self) -> type[nn.parameter.Parameter]:
-         return nn.Parameter
-
-     @property
-     def grad_fn(self) -> None:
-         return None
-
-
- # pyre-fixme[4]: Attribute annotation cannot be `Any`.
- # pyre-fixme[2]: Parameter annotation cannot be `Any`.
- class Int4Weights(
-     Int4ScaledWeights,
-     collections.namedtuple(
-         "Int4Weights",
-         ["weight", "scale", "zero_point", "shape"],
-     ),
- ):
-     pass
-
-
- def int4_row_quantize(
-     x: torch.Tensor,
-     group_size: int = 128,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-     n_bit = 4  # Number of target bits.
-     to_quant = x.reshape(-1, group_size).to(torch.float)
-
-     max_val = to_quant.amax(dim=1, keepdim=True)
-     min_val = to_quant.amin(dim=1, keepdim=True)
-     max_int = 2**n_bit - 1
-     min_int = 0
-     scales = (max_val - min_val).clamp(min=1e-6) / max_int
-
-     zeros = min_val + scales * (2 ** (n_bit - 1))
-
-     out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
-
-     # Recenter output and move to int8.
-     out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
-
-     # Cutlass expects column major layout for scale and zero point,
-     # so we transpose here and make them contiguous.
-     scales = scales.view(x.shape[0], -1).t().contiguous()
-     zeros = zeros.view(x.shape[0], -1).t().contiguous()
-
-     return out, scales, zeros
-
-
- def pack_int4(x: torch.Tensor) -> torch.Tensor:
-     # Given int8 x, pack adjacent int4 values into a single int8.
-     low_x = x[:, ::2]
-     high_x = x[:, 1::2]
-
-     # High bits need to left shift, this also masks off extra bits.
-     high_x = torch.bitwise_left_shift(high_x, 4)
-     # Low bits need to have sign bits removed.
-     low_x = torch.bitwise_and(low_x, 0xF)
-
-     # Recombine into a single value with bitwise or.
-     return torch.bitwise_or(low_x, high_x).contiguous()
-
-
- def bmm_nt(
-     x: Tensor,
-     w: Fp8RowwiseWeights | Int4Weights,
-     num_tokens: Tensor | None = None,
- ) -> Tensor:
-     if isinstance(w, Fp8ScaledWeights):
-         xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, w.activation_scale_ub)
-         return torch.ops.fbgemm.f8f8bf16_rowwise_batched(xq, w.weight, x_scale, w.scale)
-     elif isinstance(w, Int4ScaledWeights):
-         return torch.ops.fbgemm.bf16i4bf16_rowwise_batched(x, w.weight, w.scale, w.zero_point)
-     else:
-         raise ValueError("Unsupported quantization type")
-
-
- def ffn_swiglu(
-     x: Tensor,
-     w1: Fp8RowwiseWeights | Int4Weights,
-     w3: Fp8RowwiseWeights | Int4Weights,
-     w2: Fp8RowwiseWeights | Int4Weights,
-     num_tokens: Tensor | None = None,
-     is_memory_bounded: bool = False,
- ) -> Tensor:
-     if (isinstance(w1, Fp8ScaledWeights) and isinstance(w3, Fp8ScaledWeights) and isinstance(w2, Fp8ScaledWeights)) or (
-         isinstance(w1, Int4ScaledWeights) and isinstance(w3, Int4ScaledWeights) and isinstance(w2, Int4ScaledWeights)
-     ):
-         return ffn_swiglu_dynamic(x, w1, w3, w2, w1.activation_scale_ub, num_tokens, is_memory_bounded)
-
-     (B, T, D) = x.shape  # noqa: N806
-     (HD_L, D_) = w1.shape  # noqa: N806
-     assert D_ == D
-
-     assert isinstance(w1, Tensor)
-     assert isinstance(w3, Tensor)
-     x1 = x.view(B * T, D) @ w1.T
-     x2 = x.view(B * T, D) @ w3.T
-     z = torch.nn.functional.silu(x1) * x2
-     del x1, x2
-     assert isinstance(w2, Tensor)
-     return (z @ w2.T).view(B, T, D)
-
-
- @torch.inference_mode()
- def quantize_fp8(
-     w: Tensor,
-     fp8_activation_scale_ub: float,
-     output_device: torch.device | None = None,
- ) -> Fp8RowwiseWeights:
-     """Quantize [n, k] weight tensor.
-
-     Args:
-         w (Tensor): [n, k] input high precision tensor to quantize.
-         fp8_activation_scale_ub (float): Upper bound for activation max.
-     """
-     activation_scale_ub = torch.tensor(
-         [fp8_activation_scale_ub],
-         dtype=torch.float,
-         device=output_device,
-     )
-     wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-     del w
-     return Fp8RowwiseWeights(
-         weight=wq,
-         scale=w_scale,
-         shape=wq.shape,
-         activation_scale_ub=activation_scale_ub,
-     )
-
-
- @torch.inference_mode()
- def quantize_int4(
-     w: Tensor,
-     output_device: torch.device | None = None,
- ) -> Int4Weights:
-     """Quantize [n, k/2] weight tensor.
-
-     Args:
-         w (Tensor): [n, k/2] input high precision tensor to quantize.
-     """
-     if w.ndim >= 3:
-         wq, scale, zero_point = zip(*[int4_row_quantize(i) for i in w], strict=False)
-         wq = torch.stack([pack_int4(i) for i in wq], dim=0)
-         scale = torch.stack(scale, dim=0)
-         zero_point = torch.stack(zero_point, dim=0)
-     else:
-         wq, scale, zero_point = int4_row_quantize(w)
-         wq = pack_int4(wq)
-     del w
-     return Int4Weights(
-         weight=wq.to(output_device),
-         scale=scale.to(output_device),
-         zero_point=zero_point.to(output_device),
-         shape=wq.shape,
-     )
-
-
- @torch.inference_mode()
- def load_fp8(
-     w: Tensor,
-     w_scale: Tensor,
-     fp8_activation_scale_ub: float,
-     output_device: torch.device | None = None,
- ) -> Fp8RowwiseWeights:
-     """Load FP8 [n, k] weight tensor.
-
-     Args:
-         w (Tensor): [n, k] input FP8.
-         fp8_activation_scale_ub (float): Upper bound for activation max.
-     """
-     activation_scale_ub = torch.tensor(
-         [fp8_activation_scale_ub],
-         dtype=torch.float,
-         device=output_device,
-     )
-     return Fp8RowwiseWeights(
-         weight=w.to(torch.float8_e4m3fn).to(device=output_device),
-         scale=w_scale.to(device=output_device),
-         shape=w.shape,
-         activation_scale_ub=activation_scale_ub,
-     )
-
-
- @torch.inference_mode()
- def load_int4(
-     w: Tensor,
-     scale: Tensor,
-     zero_point: Tensor,
-     output_device: torch.device | None = None,
- ) -> Int4Weights:
-     """Load INT4 [n, k/2] weight tensor.
-
-     Args:
-         w (Tensor): [n, k/2] input INT4.
-     """
-     return Int4Weights(
-         weight=w.to(torch.int8).to(device=output_device),
-         scale=scale.to(device=output_device),
-         zero_point=zero_point.to(device=output_device),
-         shape=w.shape,
-     )
-
-
- def fc_dynamic(
-     x: Tensor,
-     w: Fp8RowwiseWeights | Int4Weights,
-     activation_scale_ub: Tensor | None = None,
-     num_tokens: Tensor | None = None,
-     is_memory_bounded: bool = False,
- ) -> Tensor:
-     """
-     Single w8a8 fc layer with dynamic row-wise scaling, or w4a16 fc layer with dyanmic row-wise scaling
-     """
-     if isinstance(w, Int4Weights):
-         y = torch.ops.fbgemm.bf16i4bf16_rowwise(x, w.weight, w.scale, w.zero_point)
-     else:
-         xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, num_tokens, activation_scale_ub)
-         y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, w.weight, x_scale, w.scale, use_fast_accum=True)
-         del xq
-     return y
-
-
- def ffn_swiglu_dynamic(
-     x: Tensor,
-     w1: Fp8RowwiseWeights | Int4Weights,
-     w3: Fp8RowwiseWeights | Int4Weights,
-     w2: Fp8RowwiseWeights | Int4Weights,
-     activation_scale_ub: Tensor | None = None,
-     num_tokens: Tensor | None = None,
-     is_memory_bounded: bool = False,
- ) -> Tensor:
-     assert x.dim() == 3 or x.dim() == 2
-     if x.dim() == 3:
-         (B, T, D) = x.shape  # noqa: N806
-     else:
-         (T, D) = x.shape  # noqa: N806
-         B = 1  # noqa: N806
-
-     HD_L = w1.shape[0]  # noqa: N806
-     assert HD_L == w3.shape[0]
-     x1 = fc_dynamic(
-         x.view(B * T, D),
-         w1,
-         activation_scale_ub,
-         num_tokens,
-         is_memory_bounded,
-     )
-     x2 = fc_dynamic(
-         x.view(B * T, D),
-         w3,
-         activation_scale_ub,
-         num_tokens,
-         is_memory_bounded,
-     )
-     z = torch.nn.functional.silu(x1) * x2
-     del x1, x2
-
-     z_ = fc_dynamic(z, w2, activation_scale_ub, num_tokens, is_memory_bounded)
-
-     if x.dim() == 3:
-         return z_.view(B, T, D)
-     else:
-         return z_
llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
@@ -1,20 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from typing import Any
-
- from .config import MetaReferenceInferenceConfig
-
-
- async def get_provider_impl(
-     config: MetaReferenceInferenceConfig,
-     _deps: dict[str, Any],
- ):
-     from .inference import MetaReferenceInferenceImpl
-
-     impl = MetaReferenceInferenceImpl(config)
-     await impl.initialize()
-     return impl
llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
@@ -1,24 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from pathlib import Path
-
- from llama_stack.core.utils.model_utils import model_local_dir
-
-
- def model_checkpoint_dir(model_id) -> str:
-     checkpoint_dir = Path(model_local_dir(model_id))
-
-     paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
-     if not any(p.exists() for p in paths):
-         checkpoint_dir = checkpoint_dir / "original"
-
-     assert checkpoint_dir.exists(), (
-         f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-         f"If you try to use the native llama model, please download the model using `llama-model download --source meta --model-id {model_id}` (see https://github.com/meta-llama/llama-models). "
-         f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
-     )
-     return str(checkpoint_dir)
llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
@@ -1,68 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from typing import Any
-
- from pydantic import BaseModel, field_validator
-
- from llama_stack.providers.utils.inference import supported_inference_models
- from llama_stack_api import QuantizationConfig
-
-
- class MetaReferenceInferenceConfig(BaseModel):
-     # this is a placeholder to indicate inference model id
-     # the actual inference model id is dtermined by the moddel id in the request
-     # Note: you need to register the model before using it for inference
-     # models in the resouce list in the config.yaml config will be registered automatically
-     model: str | None = None
-     torch_seed: int | None = None
-     max_seq_len: int = 4096
-     max_batch_size: int = 1
-     model_parallel_size: int | None = None
-
-     # when this is False, we assume that the distributed process group is setup by someone
-     # outside of this code (e.g., when run inside `torchrun`). that is useful for clients
-     # (including our testing code) who might be using llama-stack as a library.
-     create_distributed_process_group: bool = True
-
-     # By default, the implementation will look at ~/.llama/checkpoints/<model> but you
-     # can override by specifying the directory explicitly
-     checkpoint_dir: str | None = None
-
-     quantization: QuantizationConfig | None = None
-
-     @field_validator("model")
-     @classmethod
-     def validate_model(cls, model: str) -> str:
-         permitted_models = supported_inference_models()
-         descriptors = [m.descriptor() for m in permitted_models]
-         repos = [m.huggingface_repo for m in permitted_models if m.huggingface_repo is not None]
-         if model not in (descriptors + repos):
-             model_list = "\n\t".join(repos)
-             raise ValueError(f"Unknown model: `{model}`. Choose from [\n\t{model_list}\n]")
-         return model
-
-     @classmethod
-     def sample_run_config(
-         cls,
-         model: str = "Llama3.2-3B-Instruct",
-         checkpoint_dir: str = "${env.CHECKPOINT_DIR:=null}",
-         quantization_type: str = "${env.QUANTIZATION_TYPE:=bf16}",
-         model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:=0}",
-         max_batch_size: str = "${env.MAX_BATCH_SIZE:=1}",
-         max_seq_len: str = "${env.MAX_SEQ_LEN:=4096}",
-         **kwargs,
-     ) -> dict[str, Any]:
-         return {
-             "model": model,
-             "checkpoint_dir": checkpoint_dir,
-             "quantization": {
-                 "type": quantization_type,
-             },
-             "model_parallel_size": model_parallel_size,
-             "max_batch_size": max_batch_size,
-             "max_seq_len": max_seq_len,
-         }
llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
@@ -1,201 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import math
- from typing import Optional
-
- import torch
- from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
-
- from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
- from llama_stack.models.llama.llama3.generation import Llama3
- from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
- from llama_stack.models.llama.llama4.generation import Llama4
- from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
- from llama_stack.models.llama.sku_types import Model, ModelFamily
- from llama_stack_api import (
-     GreedySamplingStrategy,
-     JsonSchemaResponseFormat,
-     OpenAIChatCompletionRequestWithExtraBody,
-     OpenAIResponseFormatJSONSchema,
-     ResponseFormat,
-     ResponseFormatType,
-     SamplingParams,
-     TopPSamplingStrategy,
- )
-
- from .common import model_checkpoint_dir
- from .config import MetaReferenceInferenceConfig
- from .inference import resolve_model
-
- Tokenizer = Llama4Tokenizer | Llama3Tokenizer
-
-
- class LogitsProcessor:
-     def __init__(self, token_enforcer: TokenEnforcer):
-         self.token_enforcer = token_enforcer
-         self.mask: torch.Tensor | None = None
-
-     def __call__(self, tokens: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
-         token_sequence = tokens[0, :].tolist()
-         allowed_tokens = self.token_enforcer.get_allowed_tokens(token_sequence)
-
-         if self.mask is not None:
-             self.mask.fill_(-math.inf)
-         else:
-             self.mask = torch.full_like(scores, -math.inf)
-
-         self.mask[:, :, allowed_tokens] = 0
-         scores = scores + self.mask
-         return scores
-
-
- def get_logits_processor(
-     tokenizer: Tokenizer,
-     vocab_size: int,
-     response_format: ResponseFormat | None,
- ) -> Optional["LogitsProcessor"]:
-     if response_format is None:
-         return None
-
-     if not isinstance(response_format, JsonSchemaResponseFormat):
-         raise ValueError(f"Unsupported response format type {response_format.type}")
-
-     parser = JsonSchemaParser(response_format.json_schema)
-     data = TokenEnforcerTokenizerData(
-         _build_regular_tokens_list(tokenizer, vocab_size),
-         tokenizer.decode,
-         tokenizer.stop_tokens,
-     )
-     token_enforcer = TokenEnforcer(data, parser)
-     return LogitsProcessor(token_enforcer)
-
-
- def _build_regular_tokens_list(tokenizer: Tokenizer, vocab_size: int) -> list[tuple[int, str, bool]]:
-     token_0 = tokenizer.encode("0", bos=False, eos=False)[-1]
-     regular_tokens = []
-
-     special_token_ids = set(tokenizer.special_tokens.values())
-     for token_idx in range(vocab_size):
-         if token_idx in special_token_ids:
-             continue
-
-         # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
-         decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
-         decoded_regular = tokenizer.decode([token_idx])
-         is_word_start_token = len(decoded_after_0) > len(decoded_regular)
-         regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
-     return regular_tokens
-
-
- def _infer_sampling_params(sampling_params: SamplingParams):
-     if isinstance(sampling_params.strategy, GreedySamplingStrategy):
-         temperature = 0.0
-         top_p = 1.0
-     elif isinstance(sampling_params.strategy, TopPSamplingStrategy):
-         temperature = sampling_params.strategy.temperature or 1.0
-         top_p = sampling_params.strategy.top_p or 1.0
-     else:
-         raise ValueError(f"Unsupported sampling strategy {sampling_params.strategy}")
-     return temperature, top_p
-
-
- class LlamaGenerator:
-     def __init__(
-         self,
-         config: MetaReferenceInferenceConfig,
-         model_id: str,
-         llama_model: Model,
-     ):
-         if config.checkpoint_dir and config.checkpoint_dir != "null":
-             ckpt_dir = config.checkpoint_dir
-         else:
-             resolved_model = resolve_model(model_id)
-             if resolved_model is None:
-                 # if the model is not a native llama model, get the default checkpoint_dir based on model id
-                 ckpt_dir = model_checkpoint_dir(model_id)
-             else:
-                 # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
-                 ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
-
-         if config.quantization:
-             if config.quantization.type == "fp8_mixed":
-                 quantization_mode = QuantizationMode.fp8_mixed
-             elif config.quantization.type == "int4_mixed":
-                 quantization_mode = QuantizationMode.int4_mixed
-             elif config.quantization.type == "bf16":
-                 quantization_mode = None
-             else:
-                 raise ValueError(f"Unsupported quantization mode {config.quantization}")
-         else:
-             quantization_mode = None
-
-         cls = Llama4 if llama_model.model_family == ModelFamily.llama4 else Llama3
-         self.inner_generator = cls.build(
-             ckpt_dir=ckpt_dir,
-             max_seq_len=config.max_seq_len,
-             max_batch_size=config.max_batch_size,
-             world_size=config.model_parallel_size or llama_model.pth_file_count,
-             quantization_mode=quantization_mode,
-         )
-
-         self.tokenizer = self.inner_generator.tokenizer
-         self.args = self.inner_generator.args
-         self.formatter = self.inner_generator.formatter
-
-     def chat_completion(
-         self,
-         request: OpenAIChatCompletionRequestWithExtraBody,
-         raw_messages: list,
-     ):
-         """Generate chat completion using OpenAI request format.
-
-         Args:
-             request: OpenAI chat completion request
-             raw_messages: Pre-converted list of RawMessage objects
-         """
-
-         # Determine tool prompt format
-         tool_prompt_format = ToolPromptFormat.json if request.tools else ToolPromptFormat.json
-
-         # Prepare sampling params
-         sampling_params = SamplingParams()
-         if request.temperature is not None or request.top_p is not None:
-             sampling_params.strategy = TopPSamplingStrategy(
-                 temperature=request.temperature if request.temperature is not None else 1.0,
-                 top_p=request.top_p if request.top_p is not None else 1.0,
-             )
-         if request.max_tokens:
-             sampling_params.max_tokens = request.max_tokens
-
-         max_gen_len = sampling_params.max_tokens
-         if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
-             max_gen_len = self.args.max_seq_len - 1
-
-         temperature, top_p = _infer_sampling_params(sampling_params)
-
-         # Get logits processor for response format
-         logits_processor = None
-         if request.response_format:
-             if isinstance(request.response_format, OpenAIResponseFormatJSONSchema):
-                 # Extract the actual schema from OpenAIJSONSchema TypedDict
-                 schema_dict = request.response_format.json_schema.get("schema") or {}
-                 json_schema_format = JsonSchemaResponseFormat(
-                     type=ResponseFormatType.json_schema,
-                     json_schema=schema_dict,
-                 )
-                 logits_processor = get_logits_processor(self.tokenizer, self.args.vocab_size, json_schema_format)
-
-         # Generate
-         yield from self.inner_generator.generate(
-             llm_inputs=[self.formatter.encode_dialog_prompt(raw_messages, tool_prompt_format)],
-             max_gen_len=max_gen_len,
-             temperature=temperature,
-             top_p=top_p,
-             logprobs=False,
-             echo=False,
-             logits_processor=logits_processor,
-         )