llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,7 @@ from pydantic import BaseModel
 
 from llama_stack_api import (
     OpenAIChatCompletionToolCall,
+    OpenAIFinishReason,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     OpenAIResponseInput,
@@ -52,7 +53,7 @@ class ChatCompletionResult:
     tool_calls: dict[int, OpenAIChatCompletionToolCall]
     created: int
     model: str
-    finish_reason:
+    finish_reason: OpenAIFinishReason
     message_item_id: str  # For streaming events
     tool_call_item_ids: dict[int, str]  # For streaming events
     content_part_emitted: bool  # Tracking state
@@ -53,6 +53,7 @@ from llama_stack_api import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
     ResponseGuardrailSpec,
+    RunModerationRequest,
     Safety,
 )
 
@@ -468,7 +469,9 @@ async def run_guardrails(safety_api: Safety | None, messages: str, guardrail_ids
         else:
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")
 
-    guardrail_tasks = [
+    guardrail_tasks = [
+        safety_api.run_moderation(RunModerationRequest(input=messages, model=model_id)) for model_id in model_ids
+    ]
     responses = await asyncio.gather(*guardrail_tasks)
 
     for response in responses:
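Guardrail moderation calls now travel in a single request model rather than loose keyword arguments. A minimal caller-side sketch of the new shape, assuming only the names visible in the hunk above (RunModerationRequest with input and model fields, exported from llama_stack_api) and an already-resolved safety_api handle:

import asyncio

from llama_stack_api import RunModerationRequest


async def moderate(safety_api, text: str, model_ids: list[str]):
    # One moderation call per guardrail model, mirroring run_guardrails above.
    tasks = [
        safety_api.run_moderation(RunModerationRequest(input=text, model=model_id))
        for model_id in model_ids
    ]
    return await asyncio.gather(*tasks)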
@@ -7,7 +7,7 @@
 import asyncio
 
 from llama_stack.log import get_logger
-from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
+from llama_stack_api import OpenAIMessageParam, RunShieldRequest, Safety, SafetyViolation, ViolationLevel
 
 log = get_logger(name=__name__, category="agents::meta_reference")
 
@@ -32,7 +32,7 @@ class ShieldRunnerMixin:
     async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
         responses = await asyncio.gather(
             *[
-                self.safety_api.run_shield(shield_id=identifier, messages=messages
+                self.safety_api.run_shield(RunShieldRequest(shield_id=identifier, messages=messages))
                 for identifier in identifiers
             ]
         )
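Shield execution follows the same request-object pattern. A hedged sketch of a direct caller, assuming RunShieldRequest carries the shield_id and messages fields shown above and that OpenAIUserMessageParam accepts a content argument:

from llama_stack_api import OpenAIUserMessageParam, RunShieldRequest


async def run_one_shield(safety_api, shield_id: str, text: str):
    messages = [OpenAIUserMessageParam(content=text)]
    # The keyword arguments used in 0.4.x are now fields on the request model.
    return await safety_api.run_shield(
        RunShieldRequest(shield_id=shield_id, messages=messages)
    )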
@@ -23,6 +23,7 @@ from llama_stack_api import (
     BatchObject,
     ConflictError,
     Files,
+    GetModelRequest,
     Inference,
     ListBatchesResponse,
     Models,
@@ -485,7 +486,7 @@ class ReferenceBatchesImpl(Batches):
 
         if "model" in request_body and isinstance(request_body["model"], str):
             try:
-                await self.models_api.get_model(request_body["model"])
+                await self.models_api.get_model(GetModelRequest(model_id=request_body["model"]))
             except Exception:
                 errors.append(
                     BatchError(
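Model lookups through the Models API take the same shape. A small sketch of the existence check the batch validator now performs, assuming GetModelRequest is exported from llama_stack_api as imported above:

from llama_stack_api import GetModelRequest


async def model_exists(models_api, model_id: str) -> bool:
    try:
        await models_api.get_model(GetModelRequest(model_id=model_id))
        return True
    except Exception:
        # The batch validator treats any lookup failure as an unknown model.
        return False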
@@ -13,19 +13,25 @@ from llama_stack.providers.utils.common.data_schema_validator import ColumnName
 from llama_stack_api import (
     Agents,
     Benchmark,
-    BenchmarkConfig,
     BenchmarksProtocolPrivate,
     DatasetIO,
     Datasets,
     Eval,
     EvaluateResponse,
+    EvaluateRowsRequest,
     Inference,
+    IterRowsRequest,
     Job,
+    JobCancelRequest,
+    JobResultRequest,
     JobStatus,
+    JobStatusRequest,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAICompletionRequestWithExtraBody,
     OpenAISystemMessageParam,
     OpenAIUserMessageParam,
+    RunEvalRequest,
+    ScoreRequest,
     Scoring,
 )
 
@@ -90,10 +96,9 @@ class MetaReferenceEvalImpl(
 
     async def run_eval(
         self,
-
-        benchmark_config: BenchmarkConfig,
+        request: RunEvalRequest,
     ) -> Job:
-        task_def = self.benchmarks[benchmark_id]
+        task_def = self.benchmarks[request.benchmark_id]
         dataset_id = task_def.dataset_id
         scoring_functions = task_def.scoring_functions
 
@@ -101,15 +106,18 @@ class MetaReferenceEvalImpl(
         # dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
 
         all_rows = await self.datasetio_api.iterrows(
-
-
+            IterRowsRequest(
+                dataset_id=dataset_id,
+                limit=(-1 if request.benchmark_config.num_examples is None else request.benchmark_config.num_examples),
+            )
         )
-
-            benchmark_id=benchmark_id,
+        eval_rows_request = EvaluateRowsRequest(
+            benchmark_id=request.benchmark_id,
             input_rows=all_rows.data,
             scoring_functions=scoring_functions,
-            benchmark_config=benchmark_config,
+            benchmark_config=request.benchmark_config,
         )
+        res = await self.evaluate_rows(eval_rows_request)
 
         # TODO: currently needs to wait for generation before returning
         # need job scheduler queue (ray/celery) w/ jobs api
@@ -118,9 +126,9 @@ class MetaReferenceEvalImpl(
         return Job(job_id=job_id, status=JobStatus.completed)
 
     async def _run_model_generation(
-        self, input_rows: list[dict[str, Any]],
+        self, input_rows: list[dict[str, Any]], request: EvaluateRowsRequest
     ) -> list[dict[str, Any]]:
-        candidate = benchmark_config.eval_candidate
+        candidate = request.benchmark_config.eval_candidate
         assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
         sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
 
@@ -165,50 +173,50 @@ class MetaReferenceEvalImpl(
 
     async def evaluate_rows(
         self,
-
-        input_rows: list[dict[str, Any]],
-        scoring_functions: list[str],
-        benchmark_config: BenchmarkConfig,
+        request: EvaluateRowsRequest,
     ) -> EvaluateResponse:
-        candidate = benchmark_config.eval_candidate
+        candidate = request.benchmark_config.eval_candidate
         # Agent evaluation removed
         if candidate.type == "model":
-            generations = await self._run_model_generation(input_rows,
+            generations = await self._run_model_generation(request.input_rows, request)
         else:
             raise ValueError(f"Invalid candidate type: {candidate.type}")
 
         # scoring with generated_answer
         score_input_rows = [
-            input_r | generated_r for input_r, generated_r in zip(input_rows, generations, strict=False)
+            input_r | generated_r for input_r, generated_r in zip(request.input_rows, generations, strict=False)
        ]
 
-        if benchmark_config.scoring_params is not None:
+        if request.benchmark_config.scoring_params is not None:
             scoring_functions_dict = {
-                scoring_fn_id: benchmark_config.scoring_params.get(scoring_fn_id, None)
-                for scoring_fn_id in scoring_functions
+                scoring_fn_id: request.benchmark_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in request.scoring_functions
             }
         else:
-            scoring_functions_dict = dict.fromkeys(scoring_functions)
+            scoring_functions_dict = dict.fromkeys(request.scoring_functions)
 
-
-            input_rows=score_input_rows,
+        score_request = ScoreRequest(
+            input_rows=score_input_rows,
+            scoring_functions=scoring_functions_dict,
         )
+        score_response = await self.scoring_api.score(score_request)
 
         return EvaluateResponse(generations=generations, scores=score_response.results)
 
-    async def job_status(self,
-        if job_id in self.jobs:
-            return Job(job_id=job_id, status=JobStatus.completed)
+    async def job_status(self, request: JobStatusRequest) -> Job:
+        if request.job_id in self.jobs:
+            return Job(job_id=request.job_id, status=JobStatus.completed)
 
-        raise ValueError(f"Job {job_id} not found")
+        raise ValueError(f"Job {request.job_id} not found")
 
-    async def job_cancel(self,
+    async def job_cancel(self, request: JobCancelRequest) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")
 
-    async def job_result(self,
-
+    async def job_result(self, request: JobResultRequest) -> EvaluateResponse:
+        job_status_request = JobStatusRequest(benchmark_id=request.benchmark_id, job_id=request.job_id)
+        job = await self.job_status(job_status_request)
         status = job.status
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")
 
-        return self.jobs[job_id]
+        return self.jobs[request.job_id]
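Every Eval entry point now takes a single typed request. A hedged sketch of driving an evaluation with the new models, assuming only the field names that appear in the hunks above (benchmark_id, benchmark_config, job_id) and that benchmark_config is built the same way as in 0.4.x:

from llama_stack_api import JobResultRequest, JobStatusRequest, RunEvalRequest


async def run_benchmark(eval_api, benchmark_id: str, benchmark_config):
    # benchmark_id and config travel inside one request object.
    job = await eval_api.run_eval(
        RunEvalRequest(benchmark_id=benchmark_id, benchmark_config=benchmark_config)
    )

    # Poll the job and fetch results with the matching request models.
    status = await eval_api.job_status(
        JobStatusRequest(benchmark_id=benchmark_id, job_id=job.job_id)
    )
    result = await eval_api.job_result(
        JobResultRequest(benchmark_id=benchmark_id, job_id=job.job_id)
    )
    return status, result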
@@ -12,17 +12,19 @@ from llama_stack.providers.inline.post_training.huggingface.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )
 
 
@@ -69,13 +71,7 @@ class HuggingFacePostTrainingImpl:
 
     async def supervised_fine_tune(
         self,
-
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None = None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
@@ -85,17 +81,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF finetuning")
 
             recipe = HFFinetuningSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )
 
             resources_allocated, checkpoints = await recipe.train(
-                model=model,
-                output_dir=checkpoint_dir,
-                job_uuid=job_uuid,
-                lora_config=algorithm_config,
-                config=training_config,
+                model=request.model,
+                output_dir=request.checkpoint_dir,
+                job_uuid=request.job_uuid,
+                lora_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )
 
@@ -108,17 +104,12 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF finetuning completed")
 
-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)
 
     async def preference_optimize(
         self,
-
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
@@ -128,17 +119,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF DPO alignment")
 
             recipe = HFDPOAlignmentSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )
 
             resources_allocated, checkpoints = await recipe.train(
-                model=finetuned_model,
-                output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
-                job_uuid=job_uuid,
-                dpo_config=algorithm_config,
-                config=training_config,
+                model=request.finetuned_model,
+                output_dir=f"{self.config.dpo_output_dir}/{request.job_uuid}",
+                job_uuid=request.job_uuid,
+                dpo_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )
 
@@ -153,7 +144,7 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF DPO alignment completed")
 
-        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)
 
     @staticmethod
@@ -169,8 +160,10 @@ class HuggingFacePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None
 
-    async def get_training_job_status(
-
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
 
         match job.status:
             # TODO: Add support for other statuses to API
@@ -186,7 +179,7 @@ class HuggingFacePostTrainingImpl:
                 raise NotImplementedError()
 
         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -195,12 +188,14 @@ class HuggingFacePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )
 
-    async def cancel_training_job(self,
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)
 
-    async def get_training_job_artifacts(
-
-
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))
 
     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
         return ListPostTrainingJobsResponse(
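Post-training follows the same migration to request objects. A hedged sketch of scheduling a fine-tune and checking on it, using only field names the reworked handlers read off the requests (which fields are required on SupervisedFineTuneRequest is an assumption here):

from llama_stack_api import GetTrainingJobStatusRequest, SupervisedFineTuneRequest


async def finetune(post_training_api, job_uuid: str, model: str, training_config):
    job = await post_training_api.supervised_fine_tune(
        SupervisedFineTuneRequest(
            job_uuid=job_uuid,
            model=model,
            training_config=training_config,
            hyperparam_search_config={},  # required dicts in the old signature
            logger_config={},
        )
    )
    # Status and artifacts lookups also take per-call request models.
    return await post_training_api.get_training_job_status(
        GetTrainingJobStatusRequest(job_uuid=job.job_uuid)
    )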
@@ -16,7 +16,7 @@ import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
+from llama_stack_api import Checkpoint, DatasetIO, IterRowsRequest, TrainingConfig
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -135,10 +135,7 @@ def setup_torch_device(device_str: str) -> torch.device:
 async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
-        all_rows = await datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
+        all_rows = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
         if not isinstance(all_rows.data, list):
             raise RuntimeError("Expected dataset data to be a list")
         return all_rows.data
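DatasetIO iteration is wrapped the same way. A sketch of fetching all rows from a dataset with the new request model, matching the call replaced above:

from llama_stack_api import IterRowsRequest


async def load_all_rows(datasetio_api, dataset_id: str):
    # limit=-1 asks for every row, mirroring load_rows_from_dataset above.
    page = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
    if not isinstance(page.data, list):
        raise RuntimeError("Expected dataset data to be a list")
    return page.data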
@@ -22,7 +22,6 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform
 
 from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.models.llama.sku_types import Model
 from llama_stack_api import DatasetFormat
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
@@ -54,18 +53,17 @@ DATA_FORMATS: dict[str, Transform] = {
 }
 
 
-def _validate_model_id(model_id: str) ->
+def _validate_model_id(model_id: str) -> str:
     model = resolve_model(model_id)
     if model is None or model.core_model_id.value not in MODEL_CONFIGS:
         raise ValueError(f"Model {model_id} is not supported.")
-    return model
+    return model.core_model_id.value
 
 
 async def get_model_definition(
     model_id: str,
 ) -> BuildLoraModelCallable:
-
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "model_definition"):
         raise ValueError(f"Model {model_id} does not have model definition.")
     return model_config.model_definition
@@ -74,8 +72,7 @@ async def get_model_definition(
 async def get_tokenizer_type(
     model_id: str,
 ) -> BuildTokenizerCallable:
-
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "tokenizer_type"):
         raise ValueError(f"Model {model_id} does not have tokenizer_type.")
     return model_config.tokenizer_type
@@ -88,8 +85,7 @@ async def get_checkpointer_model_type(
     checkpointer model type is used in checkpointer for some special treatment on some specific model types
     For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
     """
-
-    model_config = MODEL_CONFIGS[model.core_model_id.value]
+    model_config = MODEL_CONFIGS[_validate_model_id(model_id)]
     if not hasattr(model_config, "checkpoint_type"):
         raise ValueError(f"Model {model_id} does not have checkpoint_type.")
     return model_config.checkpoint_type
@@ -12,18 +12,20 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )
 
 
@@ -69,15 +71,9 @@ class TorchtunePostTrainingImpl:
 
     async def supervised_fine_tune(
         self,
-
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
-        if isinstance(algorithm_config, LoraFinetuningConfig):
+        if isinstance(request.algorithm_config, LoraFinetuningConfig):
 
             async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
                 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
@@ -88,13 +84,13 @@ class TorchtunePostTrainingImpl:
 
                 recipe = LoraFinetuningSingleDevice(
                     self.config,
-                    job_uuid,
-                    training_config,
-                    hyperparam_search_config,
-                    logger_config,
-                    model,
-                    checkpoint_dir,
-                    algorithm_config,
+                    request.job_uuid,
+                    request.training_config,
+                    request.hyperparam_search_config,
+                    request.logger_config,
+                    request.model,
+                    request.checkpoint_dir,
+                    request.algorithm_config,
                     self.datasetio_api,
                     self.datasets_api,
                 )
@@ -112,17 +108,12 @@ class TorchtunePostTrainingImpl:
         else:
             raise NotImplementedError()
 
-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)
 
     async def preference_optimize(
         self,
-
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         raise NotImplementedError()
 
@@ -144,8 +135,10 @@ class TorchtunePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None
 
-    async def get_training_job_status(
-
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
 
         match job.status:
             # TODO: Add support for other statuses to API
@@ -161,7 +154,7 @@ class TorchtunePostTrainingImpl:
                 raise NotImplementedError()
 
         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -170,9 +163,11 @@ class TorchtunePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )
 
-    async def cancel_training_job(self,
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)
 
-    async def get_training_job_artifacts(
-
-
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py CHANGED
@@ -50,6 +50,7 @@ from llama_stack_api import (
     DataConfig,
     DatasetIO,
     Datasets,
+    IterRowsRequest,
     LoraFinetuningConfig,
     OptimizerConfig,
     PostTrainingMetric,
@@ -334,10 +335,7 @@ class LoraFinetuningSingleDevice:
         batch_size: int,
     ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
-            return await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
+            return await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
 
         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.data
@@ -5,7 +5,7 @@
|
|
|
5
5
|
# the root directory of this source tree.
|
|
6
6
|
|
|
7
7
|
import uuid
|
|
8
|
-
from typing import TYPE_CHECKING
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
9
|
|
|
10
10
|
if TYPE_CHECKING:
|
|
11
11
|
from codeshield.cs import CodeShieldScanResult
|
|
@@ -15,9 +15,11 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|
|
15
15
|
interleaved_content_as_str,
|
|
16
16
|
)
|
|
17
17
|
from llama_stack_api import (
|
|
18
|
+
GetShieldRequest,
|
|
18
19
|
ModerationObject,
|
|
19
20
|
ModerationObjectResults,
|
|
20
|
-
|
|
21
|
+
RunModerationRequest,
|
|
22
|
+
RunShieldRequest,
|
|
21
23
|
RunShieldResponse,
|
|
22
24
|
Safety,
|
|
23
25
|
SafetyViolation,
|
|
@@ -51,19 +53,14 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|
|
51
53
|
f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
|
|
52
54
|
)
|
|
53
55
|
|
|
54
|
-
async def run_shield(
|
|
55
|
-
self
|
|
56
|
-
shield_id: str,
|
|
57
|
-
messages: list[OpenAIMessageParam],
|
|
58
|
-
params: dict[str, Any] = None,
|
|
59
|
-
) -> RunShieldResponse:
|
|
60
|
-
shield = await self.shield_store.get_shield(shield_id)
|
|
56
|
+
async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
|
|
57
|
+
shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
|
|
61
58
|
if not shield:
|
|
62
|
-
raise ValueError(f"Shield {shield_id} not found")
|
|
59
|
+
raise ValueError(f"Shield {request.shield_id} not found")
|
|
63
60
|
|
|
64
61
|
from codeshield.cs import CodeShield
|
|
65
62
|
|
|
66
|
-
text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
|
|
63
|
+
text = "\n".join([interleaved_content_as_str(m.content) for m in request.messages])
|
|
67
64
|
log.info(f"Running CodeScannerShield on {text[50:]}")
|
|
68
65
|
result = await CodeShield.scan_code(text)
|
|
69
66
|
|
|
@@ -102,11 +99,11 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|
|
102
99
|
metadata=metadata,
|
|
103
100
|
)
|
|
104
101
|
|
|
105
|
-
async def run_moderation(self,
|
|
106
|
-
if model is None:
|
|
102
|
+
async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
|
|
103
|
+
if request.model is None:
|
|
107
104
|
raise ValueError("Code scanner moderation requires a model identifier.")
|
|
108
105
|
|
|
109
|
-
inputs = input if isinstance(input, list) else [input]
|
|
106
|
+
inputs = request.input if isinstance(request.input, list) else [request.input]
|
|
110
107
|
results = []
|
|
111
108
|
|
|
112
109
|
from codeshield.cs import CodeShield
|
|
@@ -129,4 +126,4 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|
|
129
126
|
)
|
|
130
127
|
results.append(moderation_result)
|
|
131
128
|
|
|
132
|
-
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
|
|
129
|
+
return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=results)
|