llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +2 -1
- llama_stack/providers/utils/inference/openai_mixin.py +41 -2
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
- llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
llama_stack/providers/inline/post_training/huggingface/post_training.py
CHANGED

@@ -12,17 +12,19 @@ from llama_stack.providers.inline.post_training.huggingface.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-    AlgorithmConfig,
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-    DPOAlignmentConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )


@@ -69,13 +71,7 @@ class HuggingFacePostTrainingImpl:

     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None = None,
-        algorithm_config: AlgorithmConfig | None = None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
@@ -85,17 +81,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF finetuning")

             recipe = HFFinetuningSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=model,
-                output_dir=checkpoint_dir,
-                job_uuid=job_uuid,
-                lora_config=algorithm_config,
-                config=training_config,
+                model=request.model,
+                output_dir=request.checkpoint_dir,
+                job_uuid=request.job_uuid,
+                lora_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )

@@ -108,17 +104,12 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF finetuning completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
             from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import (
@@ -128,17 +119,17 @@ class HuggingFacePostTrainingImpl:
             on_log_message_cb("Starting HF DPO alignment")

             recipe = HFDPOAlignmentSingleDevice(
-                job_uuid=job_uuid,
+                job_uuid=request.job_uuid,
                 datasetio_api=self.datasetio_api,
                 datasets_api=self.datasets_api,
             )

             resources_allocated, checkpoints = await recipe.train(
-                model=finetuned_model,
-                output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
-                job_uuid=job_uuid,
-                dpo_config=algorithm_config,
-                config=training_config,
+                model=request.finetuned_model,
+                output_dir=f"{self.config.dpo_output_dir}/{request.job_uuid}",
+                job_uuid=request.job_uuid,
+                dpo_config=request.algorithm_config,
+                config=request.training_config,
                 provider_config=self.config,
             )

@@ -153,7 +144,7 @@ class HuggingFacePostTrainingImpl:
             on_status_change_cb(SchedulerJobStatus.completed)
             on_log_message_cb("HF DPO alignment completed")

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     @staticmethod
@@ -169,8 +160,10 @@ class HuggingFacePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None

-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
-        job = self._scheduler.get_job(job_uuid)
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)

         match job.status:
             # TODO: Add support for other statuses to API
@@ -186,7 +179,7 @@ class HuggingFacePostTrainingImpl:
                 raise NotImplementedError()

         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -195,12 +188,14 @@ class HuggingFacePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )

-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)

-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))

     async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
         return ListPostTrainingJobsResponse(
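The hunks above replace each post-training endpoint's loose keyword arguments with a single request model imported from llama_stack_api. A minimal caller sketch follows, assuming SupervisedFineTuneRequest exposes the same field names the provider reads off request (its exact Pydantic definition lives in llama_stack_api and is not part of this diff):

```python
# Sketch only: field names are the ones the provider dereferences above
# (request.job_uuid, request.model, request.training_config, ...); the
# job id and model name are illustrative.
from llama_stack_api import SupervisedFineTuneRequest, TrainingConfig


async def start_finetune(impl, training_config: TrainingConfig) -> str:
    request = SupervisedFineTuneRequest(
        job_uuid="job-123",                        # illustrative id
        model="meta-llama/Llama-3.2-3B-Instruct",  # illustrative model
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        checkpoint_dir=None,
        algorithm_config=None,
    )
    # 0.4.3 callers passed these values as separate keyword arguments
    job = await impl.supervised_fine_tune(request)
    return job.job_uuid
```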
llama_stack/providers/inline/post_training/huggingface/utils.py
CHANGED

@@ -16,7 +16,7 @@ import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM

-from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
+from llama_stack_api import Checkpoint, DatasetIO, IterRowsRequest, TrainingConfig

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -135,10 +135,7 @@ def setup_torch_device(device_str: str) -> torch.device:
 async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]:
     """Load dataset from llama stack dataset provider"""
     try:
-        all_rows = await datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
+        all_rows = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
         if not isinstance(all_rows.data, list):
             raise RuntimeError("Expected dataset data to be a list")
         return all_rows.data
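The same convention reaches datasetio: iterrows now takes an IterRowsRequest instead of keyword arguments. A hedged caller sketch, assuming the request model carries the dataset_id and limit fields used above:

```python
from llama_stack_api import DatasetIO, IterRowsRequest


async def load_all_rows(datasetio_api: DatasetIO, dataset_id: str) -> list[dict]:
    # limit=-1 asks for every row, mirroring load_rows_from_dataset above
    page = await datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))
    return page.data
```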
llama_stack/providers/inline/post_training/torchtune/post_training.py
CHANGED

@@ -12,18 +12,20 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
 from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
 from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 from llama_stack_api import (
-    AlgorithmConfig,
+    CancelTrainingJobRequest,
     Checkpoint,
     DatasetIO,
     Datasets,
-    DPOAlignmentConfig,
+    GetTrainingJobArtifactsRequest,
+    GetTrainingJobStatusRequest,
     JobStatus,
     ListPostTrainingJobsResponse,
     LoraFinetuningConfig,
     PostTrainingJob,
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
-    TrainingConfig,
+    PreferenceOptimizeRequest,
+    SupervisedFineTuneRequest,
 )


@@ -69,15 +71,9 @@ class TorchtunePostTrainingImpl:

     async def supervised_fine_tune(
         self,
-        job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
-        model: str,
-        checkpoint_dir: str | None,
-        algorithm_config: AlgorithmConfig | None,
+        request: SupervisedFineTuneRequest,
     ) -> PostTrainingJob:
-        if isinstance(algorithm_config, LoraFinetuningConfig):
+        if isinstance(request.algorithm_config, LoraFinetuningConfig):

             async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
                 from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
@@ -88,13 +84,13 @@ class TorchtunePostTrainingImpl:

                 recipe = LoraFinetuningSingleDevice(
                     self.config,
-                    job_uuid,
-                    training_config,
-                    hyperparam_search_config,
-                    logger_config,
-                    model,
-                    checkpoint_dir,
-                    algorithm_config,
+                    request.job_uuid,
+                    request.training_config,
+                    request.hyperparam_search_config,
+                    request.logger_config,
+                    request.model,
+                    request.checkpoint_dir,
+                    request.algorithm_config,
                     self.datasetio_api,
                     self.datasets_api,
                 )
@@ -112,17 +108,12 @@ class TorchtunePostTrainingImpl:
         else:
             raise NotImplementedError()

-        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
+        job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, request.job_uuid, handler)
         return PostTrainingJob(job_uuid=job_uuid)

     async def preference_optimize(
         self,
-        job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: dict[str, Any],
-        logger_config: dict[str, Any],
+        request: PreferenceOptimizeRequest,
     ) -> PostTrainingJob:
         raise NotImplementedError()

@@ -144,8 +135,10 @@ class TorchtunePostTrainingImpl:
         data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
         return data[0] if data else None

-    async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
-        job = self._scheduler.get_job(job_uuid)
+    async def get_training_job_status(
+        self, request: GetTrainingJobStatusRequest
+    ) -> PostTrainingJobStatusResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)

         match job.status:
             # TODO: Add support for other statuses to API
@@ -161,7 +154,7 @@ class TorchtunePostTrainingImpl:
             raise NotImplementedError()

         return PostTrainingJobStatusResponse(
-            job_uuid=job_uuid,
+            job_uuid=request.job_uuid,
             status=status,
             scheduled_at=job.scheduled_at,
             started_at=job.started_at,
@@ -170,9 +163,11 @@ class TorchtunePostTrainingImpl:
             resources_allocated=self._get_resources_allocated(job),
         )

-    async def cancel_training_job(self, job_uuid: str) -> None:
-        self._scheduler.cancel(job_uuid)
+    async def cancel_training_job(self, request: CancelTrainingJobRequest) -> None:
+        self._scheduler.cancel(request.job_uuid)

-    async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
-        job = self._scheduler.get_job(job_uuid)
-        return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
+    async def get_training_job_artifacts(
+        self, request: GetTrainingJobArtifactsRequest
+    ) -> PostTrainingJobArtifactsResponse | None:
+        job = self._scheduler.get_job(request.job_uuid)
+        return PostTrainingJobArtifactsResponse(job_uuid=request.job_uuid, checkpoints=self._get_checkpoints(job))
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
CHANGED

@@ -50,6 +50,7 @@ from llama_stack_api import (
     DataConfig,
     DatasetIO,
     Datasets,
+    IterRowsRequest,
     LoraFinetuningConfig,
     OptimizerConfig,
     PostTrainingMetric,
@@ -334,10 +335,7 @@ class LoraFinetuningSingleDevice:
         batch_size: int,
     ) -> tuple[DistributedSampler, DataLoader]:
         async def fetch_rows(dataset_id: str):
-            return await self.datasetio_api.iterrows(
-                dataset_id=dataset_id,
-                limit=-1,
-            )
+            return await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=dataset_id, limit=-1))

         all_rows = await fetch_rows(dataset_id)
         rows = all_rows.data
llama_stack/providers/inline/safety/code_scanner/code_scanner.py
CHANGED

@@ -5,7 +5,7 @@
 # the root directory of this source tree.

 import uuid
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 if TYPE_CHECKING:
     from codeshield.cs import CodeShieldScanResult
@@ -15,9 +15,11 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 from llama_stack_api import (
+    GetShieldRequest,
     ModerationObject,
     ModerationObjectResults,
-    OpenAIMessageParam,
+    RunModerationRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -51,19 +53,14 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             f"Unsupported Code Scanner ID: {shield.provider_resource_id}. Allowed IDs: {ALLOWED_CODE_SCANNER_MODEL_IDS}"
         )

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Shield {shield_id} not found")
+            raise ValueError(f"Shield {request.shield_id} not found")

         from codeshield.cs import CodeShield

-        text = "\n".join([interleaved_content_as_str(m.content) for m in messages])
+        text = "\n".join([interleaved_content_as_str(m.content) for m in request.messages])
         log.info(f"Running CodeScannerShield on {text[50:]}")
         result = await CodeShield.scan_code(text)

@@ -102,11 +99,11 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             metadata=metadata,
         )

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        if model is None:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        if request.model is None:
             raise ValueError("Code scanner moderation requires a model identifier.")

-        inputs = input if isinstance(input, list) else [input]
+        inputs = request.input if isinstance(request.input, list) else [request.input]
         results = []

         from codeshield.cs import CodeShield
@@ -129,4 +126,4 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             )
             results.append(moderation_result)

-        return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)
+        return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=results)
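The safety providers follow the same request-object convention: run_shield takes a RunShieldRequest, and shield lookups now go through GetShieldRequest(identifier=...). A minimal caller sketch, assuming the request model exposes the shield_id and messages fields read above (message construction is illustrative):

```python
from llama_stack_api import OpenAIUserMessageParam, RunShieldRequest


async def scan_code_snippet(safety_impl, shield_id: str, code: str):
    request = RunShieldRequest(
        shield_id=shield_id,
        messages=[OpenAIUserMessageParam(content=code)],
    )
    # per the imports above, the response pairs with an optional SafetyViolation
    return await safety_impl.run_shield(request)
```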
llama_stack/providers/inline/safety/llama_guard/llama_guard.py
CHANGED

@@ -7,7 +7,6 @@
 import re
 import uuid
 from string import Template
-from typing import Any

 from llama_stack.core.datatypes import Api
 from llama_stack.log import get_logger
@@ -17,6 +16,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 from llama_stack_api import (
+    GetShieldRequest,
     ImageContentItem,
     Inference,
     ModerationObject,
@@ -24,6 +24,8 @@ from llama_stack_api import (
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIMessageParam,
     OpenAIUserMessageParam,
+    RunModerationRequest,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -161,17 +163,12 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
         # The routing table handles the removal from the registry
         pass

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any] = None,
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Unknown shield {shield_id}")
+            raise ValueError(f"Unknown shield {request.shield_id}")

-        messages = messages.copy()
+        messages = request.messages.copy()
         # some shields like llama-guard require the first message to be a user message
         # since this might be a tool call, first role might not be user
         if len(messages) > 0 and messages[0].role != "user":
@@ -200,30 +197,30 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):

         return await impl.run(messages)

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        if model is None:
+    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
+        if request.model is None:
             raise ValueError("Llama Guard moderation requires a model identifier.")

-        if isinstance(input, list):
-            messages = input.copy()
+        if isinstance(request.input, list):
+            messages = request.input.copy()
         else:
-            messages = [input]
+            messages = [request.input]

         # convert to user messages format with role
         messages = [OpenAIUserMessageParam(content=m) for m in messages]

         # Determine safety categories based on the model type
         # For known Llama Guard models, use specific categories
-        if model in LLAMA_GUARD_MODEL_IDS:
+        if request.model in LLAMA_GUARD_MODEL_IDS:
             # Use the mapped model for categories but the original model_id for inference
-            mapped_model = LLAMA_GUARD_MODEL_IDS[model]
+            mapped_model = LLAMA_GUARD_MODEL_IDS[request.model]
             safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
         else:
             # For unknown models, use default Llama Guard 3 8B categories
             safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]

         impl = LlamaGuardShield(
-            model=model,
+            model=request.model,
             inference_api=self.inference_api,
             excluded_categories=self.config.excluded_categories,
             safety_categories=safety_categories,
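run_moderation is wrapped the same way. Both moderation implementations above require request.model and normalize request.input to a list, so a caller sketch (the model id is illustrative) looks like:

```python
from llama_stack_api import RunModerationRequest


async def moderate(safety_impl, text: str | list[str]):
    # both providers above raise ValueError when request.model is None
    request = RunModerationRequest(input=text, model="meta-llama/Llama-Guard-3-8B")
    moderation = await safety_impl.run_moderation(request)
    return moderation.results
```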
llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
CHANGED

@@ -4,17 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any
-
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+from llama_stack.providers.utils.safety import ShieldToModerationMixin
 from llama_stack_api import (
-    ModerationObject,
+    GetShieldRequest,
     OpenAIMessageParam,
+    RunShieldRequest,
     RunShieldResponse,
     Safety,
     SafetyViolation,
@@ -31,7 +33,7 @@ log = get_logger(name=__name__, category="safety")
 PROMPT_GUARD_MODEL = "Prompt-Guard-86M"


-class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
+class PromptGuardSafetyImpl(ShieldToModerationMixin, Safety, ShieldsProtocolPrivate):
     shield_store: ShieldStore

     def __init__(self, config: PromptGuardConfig, _deps) -> None:
@@ -51,20 +53,12 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
     async def unregister_shield(self, identifier: str) -> None:
         pass

-    async def run_shield(
-        self,
-        shield_id: str,
-        messages: list[OpenAIMessageParam],
-        params: dict[str, Any],
-    ) -> RunShieldResponse:
-        shield = await self.shield_store.get_shield(shield_id)
+    async def run_shield(self, request: RunShieldRequest) -> RunShieldResponse:
+        shield = await self.shield_store.get_shield(GetShieldRequest(identifier=request.shield_id))
         if not shield:
-            raise ValueError(f"Unknown shield {shield_id}")
-
-        return await self.shield.run(messages)
+            raise ValueError(f"Unknown shield {request.shield_id}")

-    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
-        raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
+        return await self.shield.run(request.messages)


 class PromptGuardShield:
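Note that Prompt Guard's stub run_moderation (which raised NotImplementedError) disappears as the class gains ShieldToModerationMixin. The mixin itself lives in the new llama_stack/providers/utils/safety.py (+114 lines) and is not shown in this diff; a plausible reading is an adapter that answers run_moderation in terms of the provider's own run_shield. The sketch below is an illustrative guess at that shape, not the shipped code:

```python
# Illustrative guess only: the real ShieldToModerationMixin is not in this diff.
import uuid

from llama_stack_api import (
    ModerationObject,
    ModerationObjectResults,
    OpenAIUserMessageParam,
    RunModerationRequest,
    RunShieldRequest,
)


class ShieldToModerationMixinSketch:
    """Adapt a run_shield implementation into the moderation interface."""

    async def run_moderation(self, request: RunModerationRequest) -> ModerationObject:
        if request.model is None:
            raise ValueError("Moderation requires a model identifier.")
        inputs = request.input if isinstance(request.input, list) else [request.input]
        messages = [OpenAIUserMessageParam(content=text) for text in inputs]
        response = await self.run_shield(  # run_shield comes from the host class
            RunShieldRequest(shield_id=request.model, messages=messages)
        )
        flagged = response.violation is not None
        results = [ModerationObjectResults(flagged=flagged)]
        return ModerationObject(id=str(uuid.uuid4()), model=request.model, results=results)
```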
llama_stack/providers/inline/scoring/basic/scoring.py
CHANGED

@@ -3,16 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any

 from llama_stack_api import (
     DatasetIO,
     Datasets,
+    IterRowsRequest,
+    ScoreBatchRequest,
     ScoreBatchResponse,
+    ScoreRequest,
     ScoreResponse,
     Scoring,
     ScoringFn,
-    ScoringFnParams,
     ScoringFunctionsProtocolPrivate,
     ScoringResult,
 )
@@ -75,19 +76,15 @@ class BasicScoringImpl(

     async def score_batch(
         self,
-        dataset_id: str,
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
-        save_results_dataset: bool = False,
+        request: ScoreBatchRequest,
     ) -> ScoreBatchResponse:
-        all_rows = await self.datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
-        )
-        res = await self.score(
+        all_rows = await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=request.dataset_id, limit=-1))
+        score_request = ScoreRequest(
             input_rows=all_rows.data,
-            scoring_functions=scoring_functions,
+            scoring_functions=request.scoring_functions,
         )
-        if save_results_dataset:
+        res = await self.score(score_request)
+        if request.save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
             raise NotImplementedError("Save results dataset not implemented yet")
@@ -98,16 +95,15 @@ class BasicScoringImpl(

     async def score(
         self,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: dict[str, ScoringFnParams | None] = None,
+        request: ScoreRequest,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions.keys():
+        for scoring_fn_id in request.scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
+            scoring_fn_params = request.scoring_functions.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(request.input_rows, scoring_fn_id, scoring_fn_params)
             agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
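The scoring flow now composes two request models: score_batch takes a ScoreBatchRequest, loads rows via IterRowsRequest, and delegates to score with a ScoreRequest. A caller sketch, assuming the fields shown above (the scoring-function id is illustrative):

```python
from llama_stack_api import ScoreBatchRequest


async def score_dataset(scoring_impl, dataset_id: str):
    request = ScoreBatchRequest(
        dataset_id=dataset_id,
        scoring_functions={"basic::equality": None},  # None = use registered defaults
        save_results_dataset=False,  # True still raises NotImplementedError above
    )
    return await scoring_impl.score_batch(request)
```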
llama_stack/providers/inline/scoring/braintrust/braintrust.py
CHANGED

@@ -29,11 +29,13 @@ from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metr
 from llama_stack_api import (
     DatasetIO,
     Datasets,
+    IterRowsRequest,
+    ScoreBatchRequest,
     ScoreBatchResponse,
+    ScoreRequest,
     ScoreResponse,
     Scoring,
     ScoringFn,
-    ScoringFnParams,
     ScoringFunctionsProtocolPrivate,
     ScoringResult,
     ScoringResultRow,
@@ -158,18 +160,17 @@ class BraintrustScoringImpl(

     async def score_batch(
         self,
-        dataset_id: str,
-        scoring_functions: dict[str, ScoringFnParams | None],
-        save_results_dataset: bool = False,
+        request: ScoreBatchRequest,
     ) -> ScoreBatchResponse:
         await self.set_api_key()

-        all_rows = await self.datasetio_api.iterrows(
-            dataset_id=dataset_id,
-            limit=-1,
+        all_rows = await self.datasetio_api.iterrows(IterRowsRequest(dataset_id=request.dataset_id, limit=-1))
+        score_request = ScoreRequest(
+            input_rows=all_rows.data,
+            scoring_functions=request.scoring_functions,
         )
-        res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
-        if save_results_dataset:
+        res = await self.score(score_request)
+        if request.save_results_dataset:
             # TODO: persist and register dataset on to server for reading
             # self.datasets_api.register_dataset()
             raise NotImplementedError("Save results dataset not implemented yet")
@@ -198,21 +199,20 @@ class BraintrustScoringImpl(

     async def score(
         self,
-        input_rows: list[dict[str, Any]],
-        scoring_functions: dict[str, ScoringFnParams | None],
+        request: ScoreRequest,
     ) -> ScoreResponse:
         await self.set_api_key()
         res = {}
-        for scoring_fn_id in scoring_functions:
+        for scoring_fn_id in request.scoring_functions:
             if scoring_fn_id not in self.supported_fn_defs_registry:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")

-            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in input_rows]
+            score_results = [await self.score_row(input_row, scoring_fn_id) for input_row in request.input_rows]
             aggregation_functions = self.supported_fn_defs_registry[scoring_fn_id].params.aggregation_functions

             # override scoring_fn params if provided
-            if scoring_functions[scoring_fn_id] is not None:
-                override_params = scoring_functions[scoring_fn_id]
+            if request.scoring_functions[scoring_fn_id] is not None:
+                override_params = request.scoring_functions[scoring_fn_id]
                 if override_params.aggregation_functions:
                     aggregation_functions = override_params.aggregation_functions

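Per the final hunk, request-time parameter overrides survive the refactor: a non-None value in request.scoring_functions can replace a function's registered aggregation. A hedged usage sketch (function id and params construction are illustrative):

```python
from llama_stack_api import ScoreRequest


async def score_with_override(scoring_impl, rows: list[dict], params_override):
    # params_override is a ScoringFnParams instance; passing None instead
    # keeps the aggregation functions registered for the scoring function
    request = ScoreRequest(
        input_rows=rows,
        scoring_functions={"braintrust::factuality": params_override},
    )
    return await scoring_impl.score(request)
```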