llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (155)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  57. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  58. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  59. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  60. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  61. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  62. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  63. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  64. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  65. llama_stack/providers/registry/agents.py +1 -0
  66. llama_stack/providers/registry/inference.py +1 -9
  67. llama_stack/providers/registry/vector_io.py +136 -16
  68. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  69. llama_stack/providers/remote/files/s3/config.py +5 -3
  70. llama_stack/providers/remote/files/s3/files.py +2 -2
  71. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  72. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  73. llama_stack/providers/remote/inference/together/together.py +4 -0
  74. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  75. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  76. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  77. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  78. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  79. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  80. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  81. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  82. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  83. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  84. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  85. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  86. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  87. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  88. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  89. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  90. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  91. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  92. llama_stack/providers/utils/bedrock/client.py +3 -3
  93. llama_stack/providers/utils/bedrock/config.py +7 -7
  94. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  95. llama_stack/providers/utils/inference/http_client.py +239 -0
  96. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  97. llama_stack/providers/utils/inference/model_registry.py +148 -2
  98. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  99. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  100. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  101. llama_stack/providers/utils/memory/vector_store.py +46 -19
  102. llama_stack/providers/utils/responses/responses_store.py +7 -7
  103. llama_stack/providers/utils/safety.py +114 -0
  104. llama_stack/providers/utils/tools/mcp.py +44 -3
  105. llama_stack/testing/api_recorder.py +9 -3
  106. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  107. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
  108. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  109. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  110. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  111. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  112. llama_stack/models/llama/hadamard_utils.py +0 -88
  113. llama_stack/models/llama/llama3/args.py +0 -74
  114. llama_stack/models/llama/llama3/dog.jpg +0 -0
  115. llama_stack/models/llama/llama3/generation.py +0 -378
  116. llama_stack/models/llama/llama3/model.py +0 -304
  117. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  118. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  119. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  120. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  121. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  122. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  123. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  124. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  125. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  126. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  127. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  128. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  129. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  130. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  131. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  132. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  133. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  134. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  135. llama_stack/models/llama/llama4/args.py +0 -107
  136. llama_stack/models/llama/llama4/ffn.py +0 -58
  137. llama_stack/models/llama/llama4/moe.py +0 -214
  138. llama_stack/models/llama/llama4/preprocess.py +0 -435
  139. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  140. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  141. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  142. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  143. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  144. llama_stack/models/llama/quantize_impls.py +0 -316
  145. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  146. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  147. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  148. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  149. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  150. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  151. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  152. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  153. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  154. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  155. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0

llama_stack/core/routers/eval_scoring.py

@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
 from typing import Any
 
 from llama_stack.log import get_logger
@@ -11,12 +10,23 @@ from llama_stack_api import (
     BenchmarkConfig,
     Eval,
     EvaluateResponse,
+    EvaluateRowsRequest,
     Job,
+    JobCancelRequest,
+    JobResultRequest,
+    JobStatusRequest,
     RoutingTable,
+    RunEvalRequest,
+    ScoreBatchRequest,
     ScoreBatchResponse,
+    ScoreRequest,
     ScoreResponse,
     Scoring,
-    ScoringFnParams,
+    resolve_evaluate_rows_request,
+    resolve_job_cancel_request,
+    resolve_job_result_request,
+    resolve_job_status_request,
+    resolve_run_eval_request,
 )
 
 logger = get_logger(name=__name__, category="core::routers")
@@ -40,21 +50,22 @@ class ScoringRouter(Scoring):
 
     async def score_batch(
         self,
-        dataset_id: str,
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
-        save_results_dataset: bool = False,
+        request: ScoreBatchRequest,
     ) -> ScoreBatchResponse:
-        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
+        logger.debug(f"ScoringRouter.score_batch: {request.dataset_id}")
         res = {}
-        for fn_identifier in scoring_functions.keys():
+        for fn_identifier in request.scoring_functions.keys():
             provider = await self.routing_table.get_provider_impl(fn_identifier)
-            score_response = await provider.score_batch(
-                dataset_id=dataset_id,
-                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
+            # Create a request for this specific scoring function
+            single_fn_request = ScoreBatchRequest(
+                dataset_id=request.dataset_id,
+                scoring_functions={fn_identifier: request.scoring_functions[fn_identifier]},
+                save_results_dataset=request.save_results_dataset,
             )
+            score_response = await provider.score_batch(single_fn_request)
             res.update(score_response.results)
 
-        if save_results_dataset:
+        if request.save_results_dataset:
             raise NotImplementedError("Save results dataset not implemented yet")
 
         return ScoreBatchResponse(
@@ -63,18 +74,19 @@ class ScoringRouter(Scoring):
 
     async def score(
         self,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
+        request: ScoreRequest,
     ) -> ScoreResponse:
-        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        logger.debug(f"ScoringRouter.score: {len(request.input_rows)} rows, {len(request.scoring_functions)} functions")
         res = {}
         # look up and map each scoring function to its provider impl
-        for fn_identifier in scoring_functions.keys():
+        for fn_identifier in request.scoring_functions.keys():
             provider = await self.routing_table.get_provider_impl(fn_identifier)
-            score_response = await provider.score(
-                input_rows=input_rows,
-                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
+            # Create a request for this specific scoring function
+            single_fn_request = ScoreRequest(
+                input_rows=request.input_rows,
+                scoring_functions={fn_identifier: request.scoring_functions[fn_identifier]},
             )
+            score_response = await provider.score(single_fn_request)
             res.update(score_response.results)
 
         return ScoreResponse(results=res)
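
In the hunks above, ScoringRouter.score and score_batch drop their keyword parameters in favor of single request objects, while the router still fans the call out to one provider per scoring-function identifier. Below is a minimal caller-side sketch of the migration, assuming ScoreRequest and ScoreBatchRequest expose exactly the fields the router reads here (input_rows, scoring_functions, dataset_id, save_results_dataset) and that omitted fields have sensible defaults; the scoring-function and dataset identifiers are placeholders. The EvalRouter hunks for the same file continue after the sketch.

# Caller-side sketch (hypothetical): old keyword style vs. new request objects.
# ScoreRequest/ScoreBatchRequest field names are taken from the attribute
# accesses in the diff; defaults for omitted fields are assumed.
from llama_stack_api import ScoreBatchRequest, ScoreRequest


async def score_examples(scoring_router) -> None:
    # 0.4.x style: await scoring_router.score(input_rows=[...], scoring_functions={...})
    # 0.5.x style: a single request object.
    response = await scoring_router.score(
        ScoreRequest(
            input_rows=[{"generated_answer": "4", "expected_answer": "4"}],
            scoring_functions={"basic::equality": None},  # placeholder scoring fn id
        )
    )
    print(response.results)

    batch = await scoring_router.score_batch(
        ScoreBatchRequest(
            dataset_id="my-eval-dataset",  # placeholder dataset id
            scoring_functions={"basic::equality": None},
        )
    )
    print(batch.results)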

@@ -98,61 +110,139 @@ class EvalRouter(Eval):
 
     async def run_eval(
         self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkConfig,
+        request: RunEvalRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        benchmark_config: BenchmarkConfig | None = None,
     ) -> Job:
-        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        return await provider.run_eval(
-            benchmark_id=benchmark_id,
-            benchmark_config=benchmark_config,
+        """Run an evaluation on a benchmark.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            benchmark_config: (Deprecated) The benchmark configuration
+
+        Returns:
+            Job object representing the evaluation job
+        """
+        resolved_request = resolve_run_eval_request(
+            request, benchmark_id=benchmark_id, benchmark_config=benchmark_config
         )
+        logger.debug(f"EvalRouter.run_eval: {resolved_request.benchmark_id}")
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        return await provider.run_eval(resolved_request)
 
     async def evaluate_rows(
         self,
-        benchmark_id: str,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: list[str],
-        benchmark_config: BenchmarkConfig,
+        request: EvaluateRowsRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        input_rows: list[dict[str, Any]] | None = None,
+        scoring_functions: list[str] | None = None,
+        benchmark_config: BenchmarkConfig | None = None,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        return await provider.evaluate_rows(
+        """Evaluate a list of rows on a benchmark.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            input_rows: (Deprecated) The rows to evaluate
+            scoring_functions: (Deprecated) The scoring functions to use
+            benchmark_config: (Deprecated) The benchmark configuration
+
+        Returns:
+            EvaluateResponse object containing generations and scores
+        """
+        resolved_request = resolve_evaluate_rows_request(
+            request,
             benchmark_id=benchmark_id,
             input_rows=input_rows,
             scoring_functions=scoring_functions,
             benchmark_config=benchmark_config,
         )
+        logger.debug(
+            f"EvalRouter.evaluate_rows: {resolved_request.benchmark_id}, {len(resolved_request.input_rows)} rows"
+        )
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        return await provider.evaluate_rows(resolved_request)
 
     async def job_status(
         self,
-        benchmark_id: str,
-        job_id: str,
+        request: JobStatusRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        job_id: str | None = None,
     ) -> Job:
-        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        return await provider.job_status(benchmark_id, job_id)
+        """Get the status of a job.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            job_id: (Deprecated) The job ID
+
+        Returns:
+            Job object with the current status
+        """
+        resolved_request = resolve_job_status_request(request, benchmark_id=benchmark_id, job_id=job_id)
+        logger.debug(f"EvalRouter.job_status: {resolved_request.benchmark_id}, {resolved_request.job_id}")
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        return await provider.job_status(resolved_request)
 
     async def job_cancel(
         self,
-        benchmark_id: str,
-        job_id: str,
+        request: JobCancelRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        job_id: str | None = None,
     ) -> None:
-        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        await provider.job_cancel(
-            benchmark_id,
-            job_id,
-        )
+        """Cancel a job.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            job_id: (Deprecated) The job ID
+
+        Returns:
+            None
+        """
+        resolved_request = resolve_job_cancel_request(request, benchmark_id=benchmark_id, job_id=job_id)
+        logger.debug(f"EvalRouter.job_cancel: {resolved_request.benchmark_id}, {resolved_request.job_id}")
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        await provider.job_cancel(resolved_request)
 
     async def job_result(
         self,
-        benchmark_id: str,
-        job_id: str,
+        request: JobResultRequest | None = None,
+        *,
+        benchmark_id: str | None = None,
+        job_id: str | None = None,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
-        provider = await self.routing_table.get_provider_impl(benchmark_id)
-        return await provider.job_result(
-            benchmark_id,
-            job_id,
-        )
+        """Get the result of a job.
+
+        Supports both new-style (request object) and old-style (individual parameters).
+        Old-style usage is deprecated and will emit a DeprecationWarning.
+
+        Args:
+            request: The new-style request object (preferred)
+            benchmark_id: (Deprecated) The benchmark ID
+            job_id: (Deprecated) The job ID
+
+        Returns:
+            EvaluateResponse object with the job results
+        """
+        resolved_request = resolve_job_result_request(request, benchmark_id=benchmark_id, job_id=job_id)
+        logger.debug(f"EvalRouter.job_result: {resolved_request.benchmark_id}, {resolved_request.job_id}")
+        provider = await self.routing_table.get_provider_impl(resolved_request.benchmark_id)
+        return await provider.job_result(resolved_request)
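
EvalRouter now accepts either a request object or the deprecated keyword form, normalizing both through the resolve_*_request helpers before routing on benchmark_id. A minimal caller-side sketch of the two shapes, assuming RunEvalRequest/JobStatusRequest carry the fields the router reads above and that Job exposes a job_id attribute; the benchmark id is a placeholder.

# Caller-side sketch (hypothetical) of the deprecated keyword path and the new
# request-object path; field names follow the attribute accesses in the diff.
from llama_stack_api import JobStatusRequest, RunEvalRequest


async def run_benchmark(eval_router, benchmark_config) -> None:
    # Deprecated in this release: bare keyword arguments (emits DeprecationWarning).
    job = await eval_router.run_eval(
        benchmark_id="my-benchmark",  # placeholder benchmark id
        benchmark_config=benchmark_config,
    )

    # Preferred: a single request object, routed by its benchmark_id.
    job = await eval_router.run_eval(
        RunEvalRequest(benchmark_id="my-benchmark", benchmark_config=benchmark_config)
    )

    # Job polling follows the same pattern.
    status = await eval_router.job_status(
        JobStatusRequest(benchmark_id="my-benchmark", job_id=job.job_id)  # job_id attribute assumed
    )
    print(status)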

llama_stack/core/routers/inference.py

@@ -20,9 +20,11 @@ from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack_api import (
+    GetChatCompletionRequest,
     HealthResponse,
     HealthStatus,
     Inference,
+    ListChatCompletionsRequest,
     ListOpenAIChatCompletionResponse,
     ModelNotFoundError,
     ModelType,
@@ -45,7 +47,7 @@ from llama_stack_api import (
     OpenAIMessageParam,
     OpenAITokenLogProb,
     OpenAITopLogProb,
-    Order,
+    RegisterModelRequest,
     RerankResponse,
     RoutingTable,
 )
@@ -87,7 +89,14 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        request = RegisterModelRequest(
+            model_id=model_id,
+            provider_model_id=provider_model_id,
+            provider_id=provider_id,
+            metadata=metadata,
+            model_type=model_type,
+        )
+        await self.routing_table.register_model(request)
 
     async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
         model = await self.routing_table.get_object_by_identifier("model", model_id)
@@ -229,18 +238,20 @@ class InferenceRouter(Inference):
 
     async def list_chat_completions(
         self,
-        after: str | None = None,
-        limit: int | None = 20,
-        model: str | None = None,
-        order: Order | None = Order.desc,
+        request: ListChatCompletionsRequest,
     ) -> ListOpenAIChatCompletionResponse:
         if self.store:
-            return await self.store.list_chat_completions(after, limit, model, order)
+            return await self.store.list_chat_completions(
+                after=request.after,
+                limit=request.limit,
+                model=request.model,
+                order=request.order,
+            )
         raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
 
-    async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
+    async def get_chat_completion(self, request: GetChatCompletionRequest) -> OpenAICompletionWithInputMessages:
        if self.store:
-            return await self.store.get_chat_completion(completion_id)
+            return await self.store.get_chat_completion(request.completion_id)
        raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
 
     async def _nonstream_openai_chat_completion(
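
The InferenceRouter hunks above make the same shift: register_model wraps its arguments into a RegisterModelRequest, and the chat-completion listing/lookup methods take request objects instead of bare parameters. A minimal caller-side sketch, assuming the request models expose the fields forwarded to the store in the diff (after, limit, model, order, completion_id) and that the list response carries a data attribute of completions with id fields.

# Caller-side sketch (hypothetical); response shapes (page.data, completion.id)
# are assumptions, and omitted request fields are assumed to keep their old defaults.
from llama_stack_api import GetChatCompletionRequest, ListChatCompletionsRequest


async def show_recent_completions(inference_router) -> None:
    # 0.4.x style: list_chat_completions(after=None, limit=20, model=None, order=Order.desc)
    # 0.5.x style: one request object.
    page = await inference_router.list_chat_completions(ListChatCompletionsRequest(limit=5))
    for completion in page.data:
        detail = await inference_router.get_chat_completion(
            GetChatCompletionRequest(completion_id=completion.id)
        )
        print(detail.id)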

llama_stack/core/routers/safety.py

@@ -4,14 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Any
-
 from opentelemetry import trace
 
 from llama_stack.core.datatypes import SafetyConfig
 from llama_stack.log import get_logger
 from llama_stack.telemetry.helpers import safety_request_span_attributes, safety_span_name
-from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
+from llama_stack_api import (
+    ModerationObject,
+    RegisterShieldRequest,
+    RoutingTable,
+    RunModerationRequest,
+    RunShieldRequest,
+    RunShieldResponse,
+    Safety,
+    Shield,
+    UnregisterShieldRequest,
+)
 
 logger = get_logger(name=__name__, category="core::routers")
 tracer = trace.get_tracer(__name__)
@@ -35,54 +43,38 @@ class SafetyRouter(Safety):
         logger.debug("SafetyRouter.shutdown")
         pass
 
-    async def register_shield(
-        self,
-        shield_id: str,
-        provider_shield_id: str | None = None,
-        provider_id: str | None = None,
-        params: dict[str, Any] | None = None,
-    ) -> Shield:
-        logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+    async def register_shield(self, request: RegisterShieldRequest) -> Shield:
+        logger.debug(f"SafetyRouter.register_shield: {request.shield_id}")
+        return await self.routing_table.register_shield(request)
 
     async def unregister_shield(self, identifier: str) -> None:
         logger.debug(f"SafetyRouter.unregister_shield: {identifier}")
-        return await self.routing_table.unregister_shield(identifier)
-
-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        with tracer.start_as_current_span(name=safety_span_name(shield_id)):
-            logger.debug(f"SafetyRouter.run_shield: {shield_id}")
-            provider = await self.routing_table.get_provider_impl(shield_id)
-            response = await provider.run_shield(
-                shield_id=shield_id,
-                messages=messages,
-                params=params,
-            )
-
-            safety_request_span_attributes(shield_id, messages, response)
+        return await self.routing_table.unregister_shield(UnregisterShieldRequest(identifier=identifier))
+
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        with tracer.start_as_current_span(name=safety_span_name(request.shield_id)):
+            logger.debug(f"SafetyRouter.run_shield: {request.shield_id}")
+            provider = await self.routing_table.get_provider_impl(request.shield_id)
+            response = await provider.run_shield(request)
+            safety_request_span_attributes(request.shield_id, request.messages, response)
             return response
 
-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
         list_shields_response = await self.routing_table.list_shields()
         shields = list_shields_response.data
 
         selected_shield: Shield | None = None
-        provider_model: str | None = model
+        provider_model: str | None = request.model
 
-        if model:
-            matches: list[Shield] = [s for s in shields if model == s.provider_resource_id]
+        if request.model:
+            matches: list[Shield] = [s for s in shields if request.model == s.provider_resource_id]
             if not matches:
                 raise ValueError(
-                    f"No shield associated with provider_resource id {model}: choose from {[s.provider_resource_id for s in shields]}"
+                    f"No shield associated with provider_resource id {request.model}: choose from {[s.provider_resource_id for s in shields]}"
                 )
             if len(matches) > 1:
                 raise ValueError(
-                    f"Multiple shields associated with provider_resource id {model}: matched shields {[s.identifier for s in matches]}"
+                    f"Multiple shields associated with provider_resource id {request.model}: matched shields {[s.identifier for s in matches]}"
                 )
             selected_shield = matches[0]
         else:
@@ -105,9 +97,5 @@ class SafetyRouter(Safety):
         logger.debug(f"SafetyRouter.run_moderation: {shield_id}")
         provider = await self.routing_table.get_provider_impl(shield_id)
 
-        response = await provider.run_moderation(
-            input=input,
-            model=provider_model,
-        )
-
-        return response
+        provider_request = RunModerationRequest(input=request.input, model=provider_model)
+        return await provider.run_moderation(provider_request)
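
SafetyRouter follows suit: run_shield and run_moderation take request objects, and the router forwards the same object (or a rebuilt one with the resolved provider model) to the provider. A minimal caller-side sketch, assuming RunShieldRequest/RunModerationRequest mirror the old keyword parameters read in the diff; the shield id, the response fields (violation, results[0].flagged), and the message shape are assumptions or placeholders.

# Caller-side sketch (hypothetical); `message` is expected to be an
# OpenAI-style chat message object/dict, and the response attributes are assumed.
from llama_stack_api import RunModerationRequest, RunShieldRequest


async def check_user_message(safety_router, message) -> bool:
    shield_response = await safety_router.run_shield(
        RunShieldRequest(shield_id="my-shield", messages=[message])  # placeholder shield id
    )
    moderation = await safety_router.run_moderation(
        RunModerationRequest(input="is this text okay?", model="my-shield")
    )
    # True when neither the shield nor the moderation call flags the content.
    return shield_response.violation is None and not moderation.results[0].flagged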

llama_stack/core/routers/vector_io.py

@@ -39,6 +39,7 @@ from llama_stack_api import (
     VectorStoreFileObject,
     VectorStoreFilesListInBatchResponse,
     VectorStoreFileStatus,
+    VectorStoreListFilesResponse,
     VectorStoreListResponse,
     VectorStoreObject,
     VectorStoreSearchResponsePage,
@@ -148,11 +149,12 @@ class VectorIORouter(VectorIO):
         self,
         params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
     ) -> VectorStoreObject:
-        # Extract llama-stack-specific parameters from extra_body
+        # Extract llama-stack-specific parameters from extra_body or metadata
         extra = params.model_extra or {}
-        embedding_model = extra.get("embedding_model")
-        embedding_dimension = extra.get("embedding_dimension")
-        provider_id = extra.get("provider_id")
+        metadata = params.metadata or {}
+        embedding_model = extra.get("embedding_model", metadata.get("embedding_model"))
+        embedding_dimension = extra.get("embedding_dimension", metadata.get("embedding_dimension"))
+        provider_id = extra.get("provider_id", metadata.get("provider_id"))
 
         # Use default embedding model if not specified
         if (
@@ -166,8 +168,14 @@ class VectorIORouter(VectorIO):
             embedding_model = f"{embedding_provider_id}/{model_id}"
 
         if embedding_model is not None and embedding_dimension is None:
-            embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
-
+            if (
+                self.vector_stores_config
+                and self.vector_stores_config.default_embedding_model is not None
+                and self.vector_stores_config.default_embedding_model.embedding_dimensions
+            ):
+                embedding_dimension = self.vector_stores_config.default_embedding_model.embedding_dimensions
+            else:
+                embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
         # Validate that embedding model exists and is of the correct type
         if embedding_model is not None:
             model = await self.routing_table.get_object_by_identifier("model", embedding_model)
@@ -376,7 +384,7 @@ class VectorIORouter(VectorIO):
         after: str | None = None,
         before: str | None = None,
         filter: VectorStoreFileStatus | None = None,
-    ) -> list[VectorStoreFileObject]:
+    ) -> VectorStoreListFilesResponse:
         logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
         return await self.routing_table.openai_list_files_in_vector_store(
             vector_store_id=vector_store_id,
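
Two behavioral changes stand out in the VectorIORouter hunks: the router now accepts the llama-stack-specific embedding settings either in extra_body (as before) or in the vector store's metadata, and it prefers the configured default embedding dimensions before probing the model. A client-side sketch of the two request shapes follows; it is hypothetical, not taken from this diff: the base URL/port, API key, and embedding model name are placeholders, and a recent openai SDK with top-level vector_stores support is assumed.

# Hypothetical client-side sketch against a llama-stack server's
# OpenAI-compatible endpoint; URL, key, and model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# Option 1: llama-stack-specific knobs via extra_body (unchanged behaviour).
vs_a = client.vector_stores.create(
    name="docs-a",
    extra_body={"embedding_model": "nomic-embed-text", "embedding_dimension": 768},
)

# Option 2 (new in this diff): the same knob carried in the store's metadata.
vs_b = client.vector_stores.create(
    name="docs-b",
    metadata={"embedding_model": "nomic-embed-text"},
)
# When embedding_dimension is omitted, the router now prefers the configured
# default_embedding_model.embedding_dimensions before probing the model itself.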

llama_stack/core/routing_tables/models.py

@@ -16,6 +16,7 @@ from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProv
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
 from llama_stack_api import (
+    GetModelRequest,
     ListModelsResponse,
     Model,
     ModelNotFoundError,
@@ -23,6 +24,8 @@ from llama_stack_api import (
     ModelType,
     OpenAIListModelsResponse,
     OpenAIModel,
+    RegisterModelRequest,
+    UnregisterModelRequest,
 )
 
 from .common import CommonRoutingTableImpl, lookup_model
@@ -171,7 +174,12 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         ]
         return OpenAIListModelsResponse(data=openai_models)
 
-    async def get_model(self, model_id: str) -> Model:
+    async def get_model(self, request_or_model_id: GetModelRequest | str) -> Model:
+        # Support both the public Models API (GetModelRequest) and internal ModelStore interface (string)
+        if isinstance(request_or_model_id, GetModelRequest):
+            model_id = request_or_model_id.model_id
+        else:
+            model_id = request_or_model_id
         return await lookup_model(self, model_id)
 
     async def get_provider_impl(self, model_id: str) -> Any:
@@ -195,12 +203,28 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
     async def register_model(
         self,
-        model_id: str,
+        request: RegisterModelRequest | str | None = None,
+        *,
+        model_id: str | None = None,
         provider_model_id: str | None = None,
         provider_id: str | None = None,
         metadata: dict[str, Any] | None = None,
         model_type: ModelType | None = None,
     ) -> Model:
+        # Support both the public Models API (RegisterModelRequest) and legacy parameter-based interface
+        if isinstance(request, RegisterModelRequest):
+            model_id = request.model_id
+            provider_model_id = request.provider_model_id
+            provider_id = request.provider_id
+            metadata = request.metadata
+            model_type = request.model_type
+        elif isinstance(request, str):
+            # Legacy positional argument: register_model("model-id", ...)
+            model_id = request
+
+        if model_id is None:
+            raise ValueError("Either request or model_id must be provided")
+
         if provider_id is None:
             # If provider_id not specified, use the only provider if it supports this model
             if len(self.impls_by_provider_id) == 1:
@@ -229,7 +253,22 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         registered_model = await self.register_object(model)
         return registered_model
 
-    async def unregister_model(self, model_id: str) -> None:
+    async def unregister_model(
+        self,
+        request: UnregisterModelRequest | str | None = None,
+        *,
+        model_id: str | None = None,
+    ) -> None:
+        # Support both the public Models API (UnregisterModelRequest) and legacy parameter-based interface
+        if isinstance(request, UnregisterModelRequest):
+            model_id = request.model_id
+        elif isinstance(request, str):
+            # Legacy positional argument: unregister_model("model-id")
+            model_id = request
+
+        if model_id is None:
+            raise ValueError("Either request or model_id must be provided")
+
         existing_model = await self.get_model(model_id)
         if existing_model is None:
             raise ModelNotFoundError(model_id)
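
Unlike the routers, ModelsRoutingTable keeps accepting the legacy string/keyword forms alongside the new request objects, dispatching on isinstance() as shown above. A minimal caller-side sketch of the three accepted call shapes, assuming RegisterModelRequest exposes the fields the table unpacks in the diff; the model and provider identifiers are placeholders.

# Caller-side sketch (hypothetical); model/provider ids are placeholders and
# optional request fields are assumed to default to None.
from llama_stack_api import ModelType, RegisterModelRequest, UnregisterModelRequest


async def register_embedding_model(models_table) -> None:
    # Preferred: the public Models API request object.
    await models_table.register_model(
        RegisterModelRequest(
            model_id="my-embedding-model",
            provider_id="my-provider",
            model_type=ModelType.embedding,
            metadata={"embedding_dimension": 384},
        )
    )

    # Legacy shapes still work, routed through the isinstance() checks above.
    await models_table.register_model("my-embedding-model", provider_id="my-provider")
    await models_table.unregister_model(UnregisterModelRequest(model_id="my-embedding-model"))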

llama_stack/core/routing_tables/scoring_functions.py

@@ -9,12 +9,14 @@ from llama_stack.core.datatypes import (
 )
 from llama_stack.log import get_logger
 from llama_stack_api import (
+    GetScoringFunctionRequest,
+    ListScoringFunctionsRequest,
     ListScoringFunctionsResponse,
-    ParamType,
+    RegisterScoringFunctionRequest,
     ResourceType,
     ScoringFn,
-    ScoringFnParams,
     ScoringFunctions,
+    UnregisterScoringFunctionRequest,
 )
 
 from .common import CommonRoutingTableImpl
@@ -23,26 +25,23 @@ logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
-    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
+    async def list_scoring_functions(self, request: ListScoringFunctionsRequest) -> ListScoringFunctionsResponse:
         return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
 
-    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
-        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
+    async def get_scoring_function(self, request: GetScoringFunctionRequest) -> ScoringFn:
+        scoring_fn = await self.get_object_by_identifier("scoring_function", request.scoring_fn_id)
         if scoring_fn is None:
-            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
+            raise ValueError(f"Scoring function '{request.scoring_fn_id}' not found")
         return scoring_fn
 
     async def register_scoring_function(
         self,
-        scoring_fn_id: str,
-        description: str,
-        return_type: ParamType,
-        provider_scoring_fn_id: str | None = None,
-        provider_id: str | None = None,
-        params: ScoringFnParams | None = None,
+        request: RegisterScoringFunctionRequest,
     ) -> None:
+        provider_scoring_fn_id = request.provider_scoring_fn_id
         if provider_scoring_fn_id is None:
-            provider_scoring_fn_id = scoring_fn_id
+            provider_scoring_fn_id = request.scoring_fn_id
+        provider_id = request.provider_id
         if provider_id is None:
             if len(self.impls_by_provider_id) == 1:
                 provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -51,16 +50,17 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
                    "No provider specified and multiple providers available. Please specify a provider_id."
                 )
         scoring_fn = ScoringFnWithOwner(
-            identifier=scoring_fn_id,
-            description=description,
-            return_type=return_type,
+            identifier=request.scoring_fn_id,
+            description=request.description,
+            return_type=request.return_type,
             provider_resource_id=provider_scoring_fn_id,
             provider_id=provider_id,
-            params=params,
+            params=request.params,
         )
         scoring_fn.provider_id = provider_id
         await self.register_object(scoring_fn)
 
-    async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
-        existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
+    async def unregister_scoring_function(self, request: UnregisterScoringFunctionRequest) -> None:
+        get_request = GetScoringFunctionRequest(scoring_fn_id=request.scoring_fn_id)
+        existing_scoring_fn = await self.get_scoring_function(get_request)
         await self.unregister_object(existing_scoring_fn)
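
The scoring-functions routing table completes the same migration: every method takes a single request object, and unregister now builds a GetScoringFunctionRequest internally to look up the existing entry. A minimal caller-side sketch, assuming RegisterScoringFunctionRequest mirrors the old parameters read in the diff and that its optional fields default to None; all identifiers below are placeholders.

# Caller-side sketch (hypothetical); ids and description are placeholders, and
# return_type is passed in to avoid assuming a particular ParamType constructor.
from llama_stack_api import (
    RegisterScoringFunctionRequest,
    UnregisterScoringFunctionRequest,
)


async def add_custom_scoring_fn(scoring_fns_table, return_type) -> None:
    await scoring_fns_table.register_scoring_function(
        RegisterScoringFunctionRequest(
            scoring_fn_id="custom::my-exact-match",
            description="1.0 when the generation exactly matches the expected answer",
            return_type=return_type,  # e.g. a ParamType instance from llama_stack_api
            provider_id="basic",
        )
    )
    await scoring_fns_table.unregister_scoring_function(
        UnregisterScoringFunctionRequest(scoring_fn_id="custom::my-exact-match")
    )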