llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff shows the changes between publicly available package versions as they were released to a supported public registry. It is provided for informational purposes only.
Files changed (159)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  57. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  58. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  59. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  60. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  61. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  62. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  63. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  64. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  65. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  66. llama_stack/providers/registry/agents.py +1 -0
  67. llama_stack/providers/registry/inference.py +1 -9
  68. llama_stack/providers/registry/vector_io.py +136 -16
  69. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  70. llama_stack/providers/remote/files/s3/config.py +5 -3
  71. llama_stack/providers/remote/files/s3/files.py +2 -2
  72. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  73. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  74. llama_stack/providers/remote/inference/together/together.py +4 -0
  75. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  76. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  77. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  78. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  79. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  80. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  81. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  82. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  83. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  84. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  85. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  86. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  87. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  88. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  89. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  90. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  91. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  92. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  93. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  94. llama_stack/providers/utils/bedrock/client.py +3 -3
  95. llama_stack/providers/utils/bedrock/config.py +7 -7
  96. llama_stack/providers/utils/inference/__init__.py +0 -25
  97. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  98. llama_stack/providers/utils/inference/http_client.py +239 -0
  99. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  100. llama_stack/providers/utils/inference/model_registry.py +148 -2
  101. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  102. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  103. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  104. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  105. llama_stack/providers/utils/memory/vector_store.py +46 -19
  106. llama_stack/providers/utils/responses/responses_store.py +7 -7
  107. llama_stack/providers/utils/safety.py +114 -0
  108. llama_stack/providers/utils/tools/mcp.py +44 -3
  109. llama_stack/testing/api_recorder.py +9 -3
  110. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  111. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
  112. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  113. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  114. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  115. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  116. llama_stack/models/llama/hadamard_utils.py +0 -88
  117. llama_stack/models/llama/llama3/args.py +0 -74
  118. llama_stack/models/llama/llama3/dog.jpg +0 -0
  119. llama_stack/models/llama/llama3/generation.py +0 -378
  120. llama_stack/models/llama/llama3/model.py +0 -304
  121. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  122. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  123. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  124. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  125. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  126. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  127. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  128. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  129. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  130. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  131. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  132. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  133. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  134. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  135. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  136. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  137. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  138. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  139. llama_stack/models/llama/llama4/args.py +0 -107
  140. llama_stack/models/llama/llama4/ffn.py +0 -58
  141. llama_stack/models/llama/llama4/moe.py +0 -214
  142. llama_stack/models/llama/llama4/preprocess.py +0 -435
  143. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  144. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  145. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  146. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  147. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  148. llama_stack/models/llama/quantize_impls.py +0 -316
  149. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  150. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  151. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  152. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  153. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  154. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  155. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  156. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  157. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  158. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
  159. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
llama_stack/providers/inline/inference/meta_reference/inference.py (deleted)
@@ -1,542 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- import asyncio
- import time
- import uuid
- from collections.abc import AsyncIterator
-
- from llama_stack.log import get_logger
- from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
- from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
- from llama_stack.models.llama.llama3.prompt_templates import (
-     JsonCustomToolGenerator,
-     SystemDefaultGenerator,
- )
- from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
- from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
- from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
-     PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
- )
- from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
- from llama_stack.models.llama.sku_list import resolve_model
- from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
- from llama_stack.providers.utils.inference.embedding_mixin import (
-     SentenceTransformerEmbeddingMixin,
- )
- from llama_stack.providers.utils.inference.model_registry import (
-     ModelRegistryHelper,
-     build_hf_repo_model_entry,
- )
- from llama_stack_api import (
-     InferenceProvider,
-     Model,
-     ModelsProtocolPrivate,
-     ModelType,
-     OpenAIAssistantMessageParam,
-     OpenAIChatCompletion,
-     OpenAIChatCompletionChunk,
-     OpenAIChatCompletionRequestWithExtraBody,
-     OpenAIChatCompletionUsage,
-     OpenAIChoice,
-     OpenAICompletion,
-     OpenAICompletionRequestWithExtraBody,
-     OpenAIUserMessageParam,
-     ToolChoice,
- )
-
- from .config import MetaReferenceInferenceConfig
- from .generators import LlamaGenerator
- from .model_parallel import LlamaModelParallelGenerator
-
- log = get_logger(__name__, category="inference")
- # there's a single model parallel process running serving the model. for now,
- # we don't support multiple concurrent requests to this process.
- SEMAPHORE = asyncio.Semaphore(1)
-
-
- def _convert_openai_tool_to_tool_definition(tool) -> ToolDefinition:
-     """Convert OpenAI tool format to ToolDefinition format."""
-     # OpenAI tools have function.name and function.parameters
-     return ToolDefinition(
-         tool_name=tool.function.name,
-         description=tool.function.description or "",
-         parameters=tool.function.parameters or {},
-     )
-
-
- def _get_tool_choice_prompt(tool_choice, tools) -> str:
-     """Generate prompt text for tool_choice behavior."""
-     if not tool_choice or tool_choice == ToolChoice.auto or tool_choice == "auto":
-         return ""
-     elif tool_choice == ToolChoice.required or tool_choice == "required":
-         return "You MUST use one of the provided functions/tools to answer the user query."
-     elif tool_choice == ToolChoice.none or tool_choice == "none":
-         return ""
-     else:
-         # Specific tool specified
-         return f"You MUST use the tool `{tool_choice}` to answer the user query."
-
-
- def _raw_content_as_str(content) -> str:
-     """Convert RawContent to string for system messages."""
-     if isinstance(content, str):
-         return content
-     elif isinstance(content, RawTextItem):
-         return content.text
-     elif isinstance(content, list):
-         return "\n".join(_raw_content_as_str(c) for c in content)
-     else:
-         return "<media>"
-
-
- def _augment_raw_messages_for_tools_llama_3_1(
-     raw_messages: list[RawMessage],
-     tools: list,
-     tool_choice,
- ) -> list[RawMessage]:
-     """Augment raw messages with tool definitions for Llama 3.1 style models."""
-     messages = raw_messages.copy()
-     existing_system_message = None
-     if messages and messages[0].role == "system":
-         existing_system_message = messages.pop(0)
-
-     sys_content = ""
-
-     # Add tool definitions first (if present)
-     if tools:
-         # Convert OpenAI tools to ToolDefinitions
-         tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
-
-         # For OpenAI format, all tools are custom (have string names)
-         tool_gen = JsonCustomToolGenerator()
-         tool_template = tool_gen.gen(tool_definitions)
-         sys_content += tool_template.render()
-         sys_content += "\n"
-
-     # Add default system prompt
-     default_gen = SystemDefaultGenerator()
-     default_template = default_gen.gen()
-     sys_content += default_template.render()
-
-     # Add existing system message if present
-     if existing_system_message:
-         sys_content += "\n" + _raw_content_as_str(existing_system_message.content)
-
-     # Add tool choice prompt if needed
-     if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
-         sys_content += "\n" + tool_choice_prompt
-
-     # Create new system message
-     new_system_message = RawMessage(
-         role="system",
-         content=[RawTextItem(text=sys_content.strip())],
-     )
-
-     return [new_system_message] + messages
-
-
- def _augment_raw_messages_for_tools_llama_4(
-     raw_messages: list[RawMessage],
-     tools: list,
-     tool_choice,
- ) -> list[RawMessage]:
-     """Augment raw messages with tool definitions for Llama 4/3.2/3.3 style models."""
-     messages = raw_messages.copy()
-     existing_system_message = None
-     if messages and messages[0].role == "system":
-         existing_system_message = messages.pop(0)
-
-     sys_content = ""
-
-     # Add tool definitions if present
-     if tools:
-         # Convert OpenAI tools to ToolDefinitions
-         tool_definitions = [_convert_openai_tool_to_tool_definition(t) for t in tools]
-
-         # Use python_list format for Llama 4
-         tool_gen = PythonListCustomToolGeneratorLlama4()
-         system_prompt = None
-         if existing_system_message:
-             system_prompt = _raw_content_as_str(existing_system_message.content)
-
-         tool_template = tool_gen.gen(tool_definitions, system_prompt)
-         sys_content = tool_template.render()
-     elif existing_system_message:
-         # No tools, just use existing system message
-         sys_content = _raw_content_as_str(existing_system_message.content)
-
-     # Add tool choice prompt if needed
-     if tool_choice_prompt := _get_tool_choice_prompt(tool_choice, tools):
-         sys_content += "\n" + tool_choice_prompt
-
-     if sys_content:
-         new_system_message = RawMessage(
-             role="system",
-             content=[RawTextItem(text=sys_content.strip())],
-         )
-         return [new_system_message] + messages
-
-     return messages
-
-
- def augment_raw_messages_for_tools(
-     raw_messages: list[RawMessage],
-     params: OpenAIChatCompletionRequestWithExtraBody,
-     llama_model,
- ) -> list[RawMessage]:
-     """Augment raw messages with tool definitions based on model family."""
-     if not params.tools:
-         return raw_messages
-
-     # Determine augmentation strategy based on model family
-     if llama_model.model_family == ModelFamily.llama3_1 or (
-         llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
-     ):
-         # Llama 3.1 and Llama 3.2 multimodal use JSON format
-         return _augment_raw_messages_for_tools_llama_3_1(
-             raw_messages,
-             params.tools,
-             params.tool_choice,
-         )
-     elif llama_model.model_family in (
-         ModelFamily.llama3_2,
-         ModelFamily.llama3_3,
-         ModelFamily.llama4,
-     ):
-         # Llama 3.2/3.3/4 use python_list format
-         return _augment_raw_messages_for_tools_llama_4(
-             raw_messages,
-             params.tools,
-             params.tool_choice,
-         )
-     else:
-         # Default to Llama 3.1 style
-         return _augment_raw_messages_for_tools_llama_3_1(
-             raw_messages,
-             params.tools,
-             params.tool_choice,
-         )
-
-
- def llama_builder_fn(config: MetaReferenceInferenceConfig, model_id: str, llama_model: Model) -> LlamaGenerator:
-     return LlamaGenerator(config, model_id, llama_model)
-
-
- class MetaReferenceInferenceImpl(
-     SentenceTransformerEmbeddingMixin,
-     InferenceProvider,
-     ModelsProtocolPrivate,
- ):
-     def __init__(self, config: MetaReferenceInferenceConfig) -> None:
-         self.config = config
-         self.model_id = None
-         self.llama_model = None
-
-     async def initialize(self) -> None:
-         pass
-
-     async def shutdown(self) -> None:
-         if self.config.create_distributed_process_group:
-             self.generator.stop()
-
-     async def openai_completion(
-         self,
-         params: OpenAICompletionRequestWithExtraBody,
-     ) -> OpenAICompletion | AsyncIterator[OpenAICompletion]:
-         raise NotImplementedError("OpenAI completion not supported by meta reference provider")
-
-     async def should_refresh_models(self) -> bool:
-         return False
-
-     async def list_models(self) -> list[Model] | None:
-         return None
-
-     async def unregister_model(self, model_id: str) -> None:
-         pass
-
-     async def register_model(self, model: Model) -> Model:
-         llama_model = (
-             resolve_model(model.metadata["llama_model"])
-             if "llama_model" in model.metadata
-             else resolve_model(model.identifier)
-         )
-         if llama_model is None:
-             raise ValueError(
-                 "Please make sure your llama_model in model metadata or model identifier is in Llama SKU list"
-             )
-
-         self.model_registry_helper = ModelRegistryHelper(
-             [
-                 build_hf_repo_model_entry(
-                     llama_model.descriptor(),
-                     llama_model.core_model_id.value,
-                 )
-             ],
-         )
-         model = await self.model_registry_helper.register_model(model)
-
-         if model.model_type == ModelType.embedding:
-             self._load_sentence_transformer_model(model.provider_resource_id)
-
-         # TODO: what is this?! you can't really specify skipping via model metadata
-         # kill this madness
-         if "skip_load" in model.metadata and model.metadata["skip_load"]:
-             return model
-
-         await self.load_model(model.identifier, llama_model)
-         return model
-
-     async def load_model(self, model_id, llama_model) -> None:
-         log.info(f"Loading model `{model_id}`")
-
-         builder_params = [self.config, model_id, llama_model]
-
-         if self.config.create_distributed_process_group:
-             self.generator = LlamaModelParallelGenerator(
-                 model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
-                 builder_fn=llama_builder_fn,
-                 builder_params=builder_params,
-                 formatter=(
-                     Llama4ChatFormat(Llama4Tokenizer.get_instance())
-                     if llama_model.model_family == ModelFamily.llama4
-                     else Llama3ChatFormat(Llama3Tokenizer.get_instance())
-                 ),
-             )
-             self.generator.start()
-         else:
-             self.generator = llama_builder_fn(*builder_params)
-
-         self.model_id = model_id
-         self.llama_model = llama_model
-
-         log.info("Warming up...")
-
-         await self.openai_chat_completion(
-             params=OpenAIChatCompletionRequestWithExtraBody(
-                 model=model_id,
-                 messages=[OpenAIUserMessageParam(role="user", content="Hi how are you?")],
-                 max_tokens=20,
-             )
-         )
-         log.info("Warmed up!")
-
-     def check_model(self, request) -> None:
-         if self.model_id is None or self.llama_model is None:
-             raise RuntimeError(
-                 "No available model yet, please register your requested model or add your model in the resources first"
-             )
-         elif request.model != self.model_id:
-             raise RuntimeError(f"Model mismatch: request model: {request.model} != loaded model: {self.model_id}")
-
-     async def openai_chat_completion(
-         self,
-         params: OpenAIChatCompletionRequestWithExtraBody,
-     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-         self.check_model(params)
-
-         # Convert OpenAI messages to RawMessages
-         from llama_stack.models.llama.datatypes import StopReason
-         from llama_stack.providers.utils.inference.prompt_adapter import (
-             convert_openai_message_to_raw_message,
-             decode_assistant_message,
-         )
-
-         raw_messages = [await convert_openai_message_to_raw_message(msg) for msg in params.messages]
-
-         # Augment messages with tool definitions if tools are present
-         raw_messages = augment_raw_messages_for_tools(raw_messages, params, self.llama_model)
-
-         # Call generator's chat_completion method (works for both single-GPU and model-parallel)
-         if isinstance(self.generator, LlamaGenerator):
-             generator = self.generator.chat_completion(params, raw_messages)
-         else:
-             # Model parallel: submit task to process group
-             generator = self.generator.group.run_inference(("chat_completion", [params, raw_messages]))
-
-         # Check if streaming is requested
-         if params.stream:
-             return self._stream_chat_completion(generator, params)
-
-         # Non-streaming: collect all generated text
-         generated_text = ""
-         for result_batch in generator:
-             for result in result_batch:
-                 if not result.ignore_token and result.source == "output":
-                     generated_text += result.text
-
-         # Decode assistant message to extract tool calls and determine stop_reason
-         # Default to end_of_turn if generation completed normally
-         decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
-
-         # Convert tool calls to OpenAI format
-         openai_tool_calls = None
-         if decoded_message.tool_calls:
-             from llama_stack_api import (
-                 OpenAIChatCompletionToolCall,
-                 OpenAIChatCompletionToolCallFunction,
-             )
-
-             openai_tool_calls = [
-                 OpenAIChatCompletionToolCall(
-                     # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
-                     id=f"call_{uuid.uuid4().hex[:24]}",
-                     type="function",
-                     function=OpenAIChatCompletionToolCallFunction(
-                         name=str(tc.tool_name),
-                         arguments=tc.arguments,
-                     ),
-                 )
-                 for tc in decoded_message.tool_calls
-             ]
-
-         # Determine finish_reason based on whether tool calls are present
-         finish_reason = "tool_calls" if openai_tool_calls else "stop"
-
-         # Extract content from decoded message
-         content = ""
-         if isinstance(decoded_message.content, str):
-             content = decoded_message.content
-         elif isinstance(decoded_message.content, list):
-             for item in decoded_message.content:
-                 if isinstance(item, RawTextItem):
-                     content += item.text
-
-         # Create OpenAI response
-         # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
-         response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
-         created = int(time.time())
-
-         return OpenAIChatCompletion(
-             id=response_id,
-             object="chat.completion",
-             created=created,
-             model=params.model,
-             choices=[
-                 OpenAIChoice(
-                     index=0,
-                     message=OpenAIAssistantMessageParam(
-                         role="assistant",
-                         content=content,
-                         tool_calls=openai_tool_calls,
-                     ),
-                     finish_reason=finish_reason,
-                     logprobs=None,
-                 )
-             ],
-             usage=OpenAIChatCompletionUsage(
-                 prompt_tokens=0, # TODO: calculate properly
-                 completion_tokens=0, # TODO: calculate properly
-                 total_tokens=0, # TODO: calculate properly
-             ),
-         )
-
-     async def _stream_chat_completion(
-         self,
-         generator,
-         params: OpenAIChatCompletionRequestWithExtraBody,
-     ) -> AsyncIterator[OpenAIChatCompletionChunk]:
-         """Stream chat completion chunks as they're generated."""
-         from llama_stack.models.llama.datatypes import StopReason
-         from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
-         from llama_stack_api import (
-             OpenAIChatCompletionChunk,
-             OpenAIChatCompletionToolCall,
-             OpenAIChatCompletionToolCallFunction,
-             OpenAIChoiceDelta,
-             OpenAIChunkChoice,
-         )
-
-         response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
-         created = int(time.time())
-         generated_text = ""
-
-         # Yield chunks as tokens are generated
-         for result_batch in generator:
-             for result in result_batch:
-                 if result.ignore_token or result.source != "output":
-                     continue
-
-                 generated_text += result.text
-
-                 # Yield delta chunk with the new text
-                 chunk = OpenAIChatCompletionChunk(
-                     id=response_id,
-                     object="chat.completion.chunk",
-                     created=created,
-                     model=params.model,
-                     choices=[
-                         OpenAIChunkChoice(
-                             index=0,
-                             delta=OpenAIChoiceDelta(
-                                 role="assistant",
-                                 content=result.text,
-                             ),
-                             finish_reason="",
-                             logprobs=None,
-                         )
-                     ],
-                 )
-                 yield chunk
-
-         # After generation completes, decode the full message to extract tool calls
-         decoded_message = decode_assistant_message(generated_text, StopReason.end_of_turn)
-
-         # If tool calls are present, yield a final chunk with tool_calls
-         if decoded_message.tool_calls:
-             openai_tool_calls = [
-                 OpenAIChatCompletionToolCall(
-                     # generate a uuid for the call id. This is the only inline provider that does this, so need to get creative.
-                     id=f"call_{uuid.uuid4().hex[:24]}",
-                     type="function",
-                     function=OpenAIChatCompletionToolCallFunction(
-                         name=str(tc.tool_name),
-                         arguments=tc.arguments,
-                     ),
-                 )
-                 for tc in decoded_message.tool_calls
-             ]
-
-             # Yield chunk with tool_calls
-             chunk = OpenAIChatCompletionChunk(
-                 id=response_id,
-                 object="chat.completion.chunk",
-                 created=created,
-                 model=params.model,
-                 choices=[
-                     OpenAIChunkChoice(
-                         index=0,
-                         delta=OpenAIChoiceDelta(
-                             role="assistant",
-                             tool_calls=openai_tool_calls,
-                         ),
-                         finish_reason="",
-                         logprobs=None,
-                     )
-                 ],
-             )
-             yield chunk
-
-             finish_reason = "tool_calls"
-         else:
-             finish_reason = "stop"
-
-         # Yield final chunk with finish_reason
-         final_chunk = OpenAIChatCompletionChunk(
-             id=response_id,
-             object="chat.completion.chunk",
-             created=created,
-             model=params.model,
-             choices=[
-                 OpenAIChunkChoice(
-                     index=0,
-                     delta=OpenAIChoiceDelta(),
-                     finish_reason=finish_reason,
-                     logprobs=None,
-                 )
-             ],
-         )
-         yield final_chunk
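
Note (not part of the diff): the removed _stream_chat_completion above emits one chunk per generated token with the text in choices[0].delta.content, then, if the decoded message contains tool calls, a single chunk whose delta carries tool_calls, and finally an empty delta with the finish_reason. Below is a minimal consumer sketch of that chunk protocol; the plain dicts are hypothetical stand-ins for the OpenAIChatCompletionChunk objects so the sketch runs without llama-stack installed.

def accumulate_stream(chunks):
    """Collapse a stream of chat-completion chunks into (text, tool_calls, finish_reason)."""
    text, tool_calls, finish_reason = "", [], None
    for chunk in chunks:
        choice = chunk["choices"][0]
        delta = choice["delta"]
        text += delta.get("content") or ""                 # token-by-token content deltas
        tool_calls.extend(delta.get("tool_calls") or [])   # optional tool_calls chunk
        if choice.get("finish_reason"):                    # final chunk carries "stop" or "tool_calls"
            finish_reason = choice["finish_reason"]
    return text, tool_calls, finish_reason


example = [
    {"choices": [{"delta": {"role": "assistant", "content": "Hello"}, "finish_reason": ""}]},
    {"choices": [{"delta": {"role": "assistant", "content": " there"}, "finish_reason": ""}]},
    {"choices": [{"delta": {}, "finish_reason": "stop"}]},
]
print(accumulate_stream(example))  # ('Hello there', [], 'stop')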
llama_stack/providers/inline/inference/meta_reference/model_parallel.py (deleted)
@@ -1,77 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from collections.abc import Callable
- from functools import partial
- from typing import Any
-
- from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
- from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
-
- from .parallel_utils import ModelParallelProcessGroup
-
-
- class ModelRunner:
-     def __init__(self, llama):
-         self.llama = llama
-
-     def __call__(self, task: Any):
-         task_type = task[0]
-         if task_type == "chat_completion":
-             # task[1] is [params, raw_messages]
-             params, raw_messages = task[1]
-             return self.llama.chat_completion(params, raw_messages)
-         else:
-             raise ValueError(f"Unexpected task type {task_type}")
-
-
- def init_model_cb(
-     builder_fn: Callable,
-     params: list[Any],
- ):
-     llama = builder_fn(*params)
-     return ModelRunner(llama)
-
-
- class LlamaModelParallelGenerator:
-     """
-     This abstraction exists so
-     - we can run model parallel code without needing to run the CLIs via torchrun
-     - this also enables use model parallel code within a notebook context.
-
-     A Context Manager is used to ensure that the model parallel process is started and stopped
-     correctly. This does make the ergonomics a little awkward, because it isn't immediately
-     clear at the callsite why we need to use a context manager.
-     """
-
-     def __init__(
-         self,
-         model_parallel_size: int,
-         builder_fn: Callable,
-         builder_params: list[Any],
-         formatter: Llama3ChatFormat | Llama4ChatFormat,
-     ):
-         self.model_parallel_size = model_parallel_size
-         self.builder_fn = builder_fn
-         self.builder_params = builder_params
-         self.formatter = formatter
-
-     def start(self):
-         self.__enter__()
-
-     def stop(self):
-         self.__exit__(None, None, None)
-
-     def __enter__(self):
-         self.group = ModelParallelProcessGroup(
-             self.model_parallel_size,
-             init_model_cb=partial(init_model_cb, self.builder_fn, self.builder_params),
-         )
-         self.group.start()
-         return self
-
-     def __exit__(self, exc_type, exc_value, exc_traceback):
-         self.group.stop()
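
Note (not part of the diff): the docstring of the removed LlamaModelParallelGenerator explains that it wraps a context manager behind start()/stop() so the model-parallel worker group can be driven outside a with-block, which is how load_model() and shutdown() in the removed inference.py used it. The toy below only illustrates that pattern; Worker is a hypothetical stand-in for ModelParallelProcessGroup, not llama-stack code.

class Worker:
    """Hypothetical stand-in for ModelParallelProcessGroup."""

    def start(self):
        print("worker group started")

    def stop(self):
        print("worker group stopped")


class ParallelGenerator:
    """Same ergonomics as the removed class: a context manager, also drivable via start()/stop()."""

    def start(self):
        self.__enter__()

    def stop(self):
        self.__exit__(None, None, None)

    def __enter__(self):
        self.group = Worker()
        self.group.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.group.stop()


gen = ParallelGenerator()
gen.start()  # how load_model() drove the removed class
gen.stop()   # how shutdown() tore it down

with ParallelGenerator():  # equivalent, with cleanup guaranteed on exceptions
    pass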