llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  57. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  58. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  59. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  60. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  61. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  62. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  63. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  64. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  65. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  66. llama_stack/providers/registry/agents.py +1 -0
  67. llama_stack/providers/registry/inference.py +1 -9
  68. llama_stack/providers/registry/vector_io.py +136 -16
  69. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  70. llama_stack/providers/remote/files/s3/config.py +5 -3
  71. llama_stack/providers/remote/files/s3/files.py +2 -2
  72. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  73. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  74. llama_stack/providers/remote/inference/together/together.py +4 -0
  75. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  76. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  77. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  78. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  79. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  80. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  81. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  82. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  83. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  84. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  85. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  86. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  87. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  88. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  89. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  90. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  91. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  92. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  93. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  94. llama_stack/providers/utils/bedrock/client.py +3 -3
  95. llama_stack/providers/utils/bedrock/config.py +7 -7
  96. llama_stack/providers/utils/inference/__init__.py +0 -25
  97. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  98. llama_stack/providers/utils/inference/http_client.py +239 -0
  99. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  100. llama_stack/providers/utils/inference/model_registry.py +148 -2
  101. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  102. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  103. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  104. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  105. llama_stack/providers/utils/memory/vector_store.py +46 -19
  106. llama_stack/providers/utils/responses/responses_store.py +7 -7
  107. llama_stack/providers/utils/safety.py +114 -0
  108. llama_stack/providers/utils/tools/mcp.py +44 -3
  109. llama_stack/testing/api_recorder.py +9 -3
  110. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  111. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
  112. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  113. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  114. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  115. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  116. llama_stack/models/llama/hadamard_utils.py +0 -88
  117. llama_stack/models/llama/llama3/args.py +0 -74
  118. llama_stack/models/llama/llama3/dog.jpg +0 -0
  119. llama_stack/models/llama/llama3/generation.py +0 -378
  120. llama_stack/models/llama/llama3/model.py +0 -304
  121. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  122. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  123. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  124. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  125. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  126. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  127. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  128. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  129. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  130. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  131. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  132. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  133. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  134. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  135. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  136. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  137. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  138. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  139. llama_stack/models/llama/llama4/args.py +0 -107
  140. llama_stack/models/llama/llama4/ffn.py +0 -58
  141. llama_stack/models/llama/llama4/moe.py +0 -214
  142. llama_stack/models/llama/llama4/preprocess.py +0 -435
  143. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  144. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  145. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  146. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  147. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  148. llama_stack/models/llama/quantize_impls.py +0 -316
  149. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  150. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  151. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  152. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  153. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  154. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  155. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  156. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  157. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  158. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
  159. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@ from pydantic import BaseModel
 
 from llama_stack_api import (
     OpenAIChatCompletionToolCall,
+    OpenAIFinishReason,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     OpenAIResponseInput,
@@ -52,7 +53,7 @@ class ChatCompletionResult:
     tool_calls: dict[int, OpenAIChatCompletionToolCall]
     created: int
     model: str
-    finish_reason: str
+    finish_reason: OpenAIFinishReason
     message_item_id: str  # For streaming events
     tool_call_item_ids: dict[int, str]  # For streaming events
     content_part_emitted: bool  # Tracking state
@@ -53,6 +53,7 @@ from llama_stack_api import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
     ResponseGuardrailSpec,
+    RunModerationRequest,
     Safety,
 )
 
@@ -468,7 +469,9 @@ async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids
         else:
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")
 
-    guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
+    guardrail_tasks = [
+        safety_api.run_moderation(RunModerationRequest(input=messages, model=model_id)) for model_id in model_ids
+    ]
     responses = await asyncio.gather(*guardrail_tasks)
 
     for response in responses:
@@ -7,7 +7,7 @@
 import asyncio
 
 from llama_stack.log import get_logger
-from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
+from llama_stack_api import OpenAIMessageParam, RunShieldRequest, Safety, SafetyViolation, ViolationLevel
 
 log = get_logger(name=__name__, category="agents::meta_reference")
 
@@ -32,7 +32,7 @@ class ShieldRunnerMixin:
     async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
         responses = await asyncio.gather(
             *[
-                self.safety_api.run_shield(shield_id=identifier, messages=messages, params={})
+                self.safety_api.run_shield(RunShieldRequest(shield_id=identifier, messages=messages))
                 for identifier in identifiers
             ]
         )
@@ -23,6 +23,7 @@ from llama_stack_api import (
     BatchObject,
     ConflictError,
     Files,
+    GetModelRequest,
     Inference,
     ListBatchesResponse,
     Models,
@@ -485,7 +486,7 @@ class ReferenceBatchesImpl(Batches):
 
         if "model" in request_body and isinstance(request_body["model"], str):
             try:
-                await self.models_api.get_model(request_body["model"])
+                await self.models_api.get_model(GetModelRequest(model_id=request_body["model"]))
             except Exception:
                 errors.append(
                     BatchError(
@@ -13,19 +13,25 @@ from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack_api import (
     Agents,
     Benchmark,
-    BenchmarkConfig,
     BenchmarksProtocolPrivate,
     DatasetIO,
     Datasets,
     Eval,
     EvaluateResponse,
+    EvaluateRowsRequest,
     Inference,
+    IterRowsRequest,
     Job,
+    JobCancelRequest,
+    JobResultRequest,
     JobStatus,
+    JobStatusRequest,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAICompletionRequestWithExtraBody,
     OpenAISystemMessageParam,
     OpenAIUserMessageParam,
+    RunEvalRequest,
+    ScoreRequest,
     Scoring,
 )
 
@@ -90,10 +96,9 @@ class MetaReferenceEvalImpl(
 
     async def run_eval(
         self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkConfig,
+        request: RunEvalRequest,
     ) -> Job:
-        task_def = self.benchmarks[benchmark_id]
+        task_def = self.benchmarks[request.benchmark_id]
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
 
@@ -101,15 +106,18 @@ class MetaReferenceEvalImpl(
         # dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
 
         all_rows = await self.datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
+            IterRowsRequest(
+                dataset_id=dataset_id,
+                limit=(-1 if request.benchmark_config.num_examples is None else request.benchmark_config.num_examples),
+            )
         )
-        res = await self.evaluate_rows(
-            benchmark_id=benchmark_id,
+        eval_rows_request = EvaluateRowsRequest(
+            benchmark_id=request.benchmark_id,
             input_rows=all_rows.data,
             scoring_functions=scoring_functions,
-            benchmark_config=benchmark_config,
+            benchmark_config=request.benchmark_config,
         )
+        res = await self.evaluate_rows(eval_rows_request)
 
         # TODO: currently needs to wait for generation before returning
         # need job scheduler queue (ray/celery) w/ jobs api
@@ -118,9 +126,9 @@ class MetaReferenceEvalImpl(
         return Job(job_id=job_id, status=JobStatus.completed)
 
     async def _run_model_generation(
-        self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
+        self, input_rows: list[dict[str, Any]], request: EvaluateRowsRequest
     ) -> list[dict[str, Any]]:
-        candidate = benchmark_config.eval_candidate
+        candidate = request.benchmark_config.eval_candidate
         assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
         sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
 
@@ -165,50 +173,50 @@ class MetaReferenceEvalImpl(
 
     async def evaluate_rows(
         self,
-        benchmark_id: str,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: list[str],
-        benchmark_config: BenchmarkConfig,
+        request: EvaluateRowsRequest,
     ) -> EvaluateResponse:
-        candidate = benchmark_config.eval_candidate
+        candidate = request.benchmark_config.eval_candidate
         # Agent evaluation removed
         if candidate.type == "model":
-            generations = await self._run_model_generation(input_rows, benchmark_config)
+            generations = await self._run_model_generation(request.input_rows, request)
         else:
             raise ValueError(f"Invalid candidate type: {candidate.type}")
 
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(request.input_rows, generations, strict=False)
         ]
 
-        if benchmark_config.scoring_params is not None:
+        if request.benchmark_config.scoring_params is not None:
             scoring_functions_dict = {
-                scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
-                for scoring_fn_id in scoring_functions
+                scoring_fn_id: request.benchmark_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in request.scoring_functions
             }
         else:
-            scoring_functions_dict = dict.fromkeys(scoring_functions)
+            scoring_functions_dict = dict.fromkeys(request.scoring_functions)
 
-        score_response = await self.scoring_api.score(
-            input_rows=score_input_rows, scoring_functions=scoring_functions_dict
+        score_request = ScoreRequest(
+            input_rows=score_input_rows,
+            scoring_functions=scoring_functions_dict,
         )
+        score_response = await self.scoring_api.score(score_request)
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
-        if job_id in self.jobs:
-            return Job(job_id=job_id, status=JobStatus.completed)
+    async def job_status(self, request: JobStatusRequest) -> Job:
+        if request.job_id in self.jobs:
+            return Job(job_id=request.job_id, status=JobStatus.completed)
 
-        raise ValueError(f"Job {job_id} not found")
+        raise ValueError(f"Job {request.job_id} not found")
 
-    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
+    async def job_cancel(self, request: JobCancelRequest) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
-    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        job = await self.job_status(benchmark_id, job_id)
+    async def job_result(self, request: JobResultRequest) -> EvaluateResponse:
+        job_status_request = JobStatusRequest(benchmark_id=request.benchmark_id, job_id=request.job_id)
+        job = await self.job_status(job_status_request)
         status = job.status
         if not status or status != JobStatus.completed:
            raise ValueError(f"Job is not completed, Status: {status.value}")
 
-        return self.jobs[job_id]
+        return self.jobs[request.job_id]
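
Across these eval hunks, MetaReferenceEvalImpl moves from keyword arguments to single request objects (RunEvalRequest, EvaluateRowsRequest, JobStatusRequest, JobCancelRequest, JobResultRequest). A minimal caller-side sketch of the new call shape, assuming an eval_impl handle and a prepared benchmark config (both placeholders; the field names are read off the diff):

from llama_stack_api import JobResultRequest, RunEvalRequest

async def run_benchmark(eval_impl, benchmark_config):
    # 0.4.4 style: eval_impl.run_eval(benchmark_id=..., benchmark_config=...)
    # 0.5.0 style (per the hunks above): a single request object
    job = await eval_impl.run_eval(RunEvalRequest(benchmark_id="my-benchmark", benchmark_config=benchmark_config))
    # job_result likewise takes a request carrying both benchmark_id and job_id
    return await eval_impl.job_result(JobResultRequest(benchmark_id="my-benchmark", job_id=job.job_id))
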
@@ -12,17 +12,19 @@ from llama_stack.providers.inline.post_training.huggingface.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-    AlgorithmConfig,
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-    DPOAlignmentConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )
 
 
@@ -69,13 +71,7 @@ class HuggingFacePostTrainingImpl:
 
     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None = None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
@@ -85,17 +81,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF finetuning")
 
             recipe = HFFinetuningSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )
 
             resources_allocated, checkpoints = await recipe.train(
-                model=model,
-                output_dir=checkpoint_dir,
-                job_uuid=job_uuid,
-                lora_config=algorithm_config,
-                config=training_config,
+                model=request.model,
+                output_dir=request.checkpoint_dir,
+                job_uuid=request.job_uuid,
+                lora_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )
 
@@ -108,17 +104,12 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF finetuning completed")
 
-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)
 
     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
@@ -128,17 +119,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF DPO alignment")
 
             recipe = HFDPOAlignmentSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )
 
             resources_allocated, checkpoints = await recipe.train(
-                model=finetuned_model,
-                output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
-                job_uuid=job_uuid,
-                dpo_config=algorithm_config,
-                config=training_config,
+                model=request.finetuned_model,
+                output_dir=f"{self.config.dpo_output_dir}/{request.job_uuid}",
+                job_uuid=request.job_uuid,
+                dpo_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )
 
@@ -153,7 +144,7 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF DPO alignment completed")
 
-        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)
 
     @staticmethod
@@ -169,8 +160,10 @@ class HuggingFacePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None
 
-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
-        job = self._scheduler.get_job(job_uuid)
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
 
         match job.status:
             # TODO: Add support for other statuses to API
@@ -186,7 +179,7 @@ class HuggingFacePostTrainingImpl:
             raise NotImplementedError()
 
         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -195,12 +188,14 @@ class HuggingFacePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )
 
-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)
 
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
         return ListPostTrainingJobsResponse(
@@ -16,7 +16,7 @@ import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
+from llama_stack_api import Checkpoint, DatasetIO, IterRowsRequest, TrainingConfig
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -135,10 +135,7 @@ def setup_torch_device(device_str: str) -> torch.device:
 async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
-        all_rows = await datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
+        all_rows = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
         if not isinstance(all_rows.data, list):
             raise RuntimeError("Expected dataset data to be a list")
         return all_rows.data
@@ -22,7 +22,6 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform
 
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import Model
 from llama_stack_api import DatasetFormat
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
@@ -54,18 +53,17 @@ DATA_FORMATS: dict[str, Transform] = {
 }
 
 
-def _validate_model_id(model_id: str) -> Model:
+def _validate_model_id(model_id: str) -> str:
     model = resolve_model(model_id)
     if model is None or model.core_model_id.value not in MODEL_CONFIGS:
         raise ValueError(f"Model {model_id} is not supported.")
-    return model
+    return model.core_model_id.value
 
 
 async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "model_definition"):
         raise ValueError(f"Model {model_id} does not have model definition.")
     return model_config.model_definition
@@ -74,8 +72,7 @@ async def get_model_definition(
 async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "tokenizer_type"):
         raise ValueError(f"Model {model_id} does not have tokenizer_type.")
     return model_config.tokenizer_type
@@ -88,8 +85,7 @@ async def get_checkpointer_model_type(
     checkpointer model type is used in checkpointer for some special treatment on some specific model types
     For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
     """
-    model = _validate_model_id(model_id)
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "checkpoint_type"):
         raise ValueError(f"Model {model_id} does not have checkpoint_type.")
     return model_config.checkpoint_type
@@ -12,18 +12,20 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-    AlgorithmConfig,
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-    DPOAlignmentConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )
 
 
@@ -69,15 +71,9 @@ class TorchtunePostTrainingImpl:
 
     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
-        if isinstance(algorithm_config, LoraFinetuningConfig):
+        if isinstance(request.algorithm_config, LoraFinetuningConfig):
 
             async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
                 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
@@ -88,13 +84,13 @@ class TorchtunePostTrainingImpl:
 
                 recipe = LoraFinetuningSingleDevice(
                     self.config,
-                    job_uuid,
-                    training_config,
-                    hyperparam_search_config,
-                    logger_config,
-                    model,
-                    checkpoint_dir,
-                    algorithm_config,
+                    request.job_uuid,
+                    request.training_config,
+                    request.hyperparam_search_config,
+                    request.logger_config,
+                    request.model,
+                    request.checkpoint_dir,
+                    request.algorithm_config,
                     self.datasetio_api,
                     self.datasets_api,
                 )
@@ -112,17 +108,12 @@ class TorchtunePostTrainingImpl:
         else:
             raise NotImplementedError()
 
-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)
 
     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         raise NotImplementedError()
 
@@ -144,8 +135,10 @@ class TorchtunePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None
 
-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
-        job = self._scheduler.get_job(job_uuid)
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
 
         match job.status:
             # TODO: Add support for other statuses to API
@@ -161,7 +154,7 @@ class TorchtunePostTrainingImpl:
             raise NotImplementedError()
 
         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -170,9 +163,11 @@ class TorchtunePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )
 
-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)
 
-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))
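
The same request-object migration applies to both post-training providers above: supervised_fine_tune and preference_optimize now take SupervisedFineTuneRequest and PreferenceOptimizeRequest instead of individual parameters. A rough caller-side sketch, assuming a post_training handle and a prepared TrainingConfig (placeholders; the field names are read off the diff, and the model id and job uuid below are illustrative only):

from llama_stack_api import SupervisedFineTuneRequest

async def start_sft(post_training, training_config):
    # One request object replaces the seven keyword arguments removed in the hunks above.
    request = SupervisedFineTuneRequest(
        job_uuid="sft-job-001",            # placeholder job id
        model="example-model-id",          # placeholder model identifier
        checkpoint_dir=None,
        algorithm_config=None,             # e.g. a LoraFinetuningConfig
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
    )
    job = await post_training.supervised_fine_tune(request)
    return job.job_uuid
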
@@ -50,6 +50,7 @@ from llama_stack_api import (
     DataConfig,
     DatasetIO,
     Datasets,
+    IterRowsRequest,
     LoraFinetuningConfig,
     OptimizerConfig,
     PostTrainingMetric,
@@ -334,10 +335,7 @@ class LoraFinetuningSingleDevice:
         batch_size: int,
     ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
-            return await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
+            return await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
 
         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.data
5
5
  # the root directory of this source tree.
6
6
 
7
7
  import uuid
8
- from typing import TYPE_CHECKING, Any
8
+ from typing import TYPE_CHECKING
9
9
 
10
10
  if TYPE_CHECKING:
11
11
  from codeshield.cs import CodeShieldScanResult
@@ -15,9 +15,11 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
15
15
  interleaved_content_as_str,
16
16
  )
17
17
  from llama_stack_api import (
18
+ GetShieldRequest,
18
19
  ModerationObject,
19
20
  ModerationObjectResults,
20
- OpenAIMessageParam,
21
+ RunModerationRequest,
22
+ RunShieldRequest,
21
23
  RunShieldResponse,
22
24
  Safety,
23
25
  SafetyViolation,
@@ -51,19 +53,14 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
51
53
  f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
52
54
  )
53
55
 
54
- async def run_shield(
55
- self,
56
- shield_id: str,
57
- messages: list[OpenAIMessageParam],
58
- params: dict[str, Any] = None,
59
- ) -> RunShieldResponse:
60
- shield = await self.shield_store.get_shield(shield_id)
56
+ async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
57
+ shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
61
58
  if not shield:
62
- raise ValueError(f"Shield {shield_id} not found")
59
+ raise ValueError(f"Shield {request.shield_id} not found")
63
60
 
64
61
  from codeshield.cs import CodeShield
65
62
 
66
- text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
63
+ text = "\n".join([interleaved_content_as_str(m.content) for m in request.messages])
67
64
  log.info(f"Running CodeScannerShield on {text[50:]}")
68
65
  result = await CodeShield.scan_code(text)
69
66
 
@@ -102,11 +99,11 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
102
99
  metadata=metadata,
103
100
  )
104
101
 
105
- async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
106
- if model is None:
102
+ async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
103
+ if request.model is None:
107
104
  raise ValueError("Code scanner moderation requires a model identifier.")
108
105
 
109
- inputs = input if isinstance(input, list) else [input]
106
+ inputs = request.input if isinstance(request.input, list) else [request.input]
110
107
  results = []
111
108
 
112
109
  from codeshield.cs import CodeShield
@@ -129,4 +126,4 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
129
126
  )
130
127
  results.append(moderation_result)
131
128
 
132
- return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
129
+ return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=results)
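
The safety hunks follow the same pattern: run_shield and run_moderation now accept RunShieldRequest and RunModerationRequest. A minimal caller-side sketch, assuming a safety_api handle, an existing list of OpenAI-style messages, and placeholder shield/model identifiers (none of these names are asserted by the diff beyond the request fields shown above):

from llama_stack_api import RunModerationRequest, RunShieldRequest

async def scan(safety_api, messages):
    # 0.4.4: run_shield(shield_id=..., messages=..., params={}); run_moderation(input, model=...)
    shield_response = await safety_api.run_shield(
        RunShieldRequest(shield_id="code-scanner", messages=messages)  # placeholder shield id
    )
    moderation = await safety_api.run_moderation(
        RunModerationRequest(input="def f(): eval(user_input)", model="code-scanner")  # placeholder input/model
    )
    return shield_response, moderation
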