llama-stack 0.3.4__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/__init__.py +0 -5
- llama_stack/cli/llama.py +3 -3
- llama_stack/cli/stack/_list_deps.py +12 -23
- llama_stack/cli/stack/list_stacks.py +37 -18
- llama_stack/cli/stack/run.py +121 -11
- llama_stack/cli/stack/utils.py +0 -127
- llama_stack/core/access_control/access_control.py +69 -28
- llama_stack/core/access_control/conditions.py +15 -5
- llama_stack/core/admin.py +267 -0
- llama_stack/core/build.py +6 -74
- llama_stack/core/client.py +1 -1
- llama_stack/core/configure.py +6 -6
- llama_stack/core/conversations/conversations.py +28 -25
- llama_stack/core/datatypes.py +271 -79
- llama_stack/core/distribution.py +15 -16
- llama_stack/core/external.py +3 -3
- llama_stack/core/inspect.py +98 -15
- llama_stack/core/library_client.py +73 -61
- llama_stack/core/prompts/prompts.py +12 -11
- llama_stack/core/providers.py +17 -11
- llama_stack/core/resolver.py +65 -56
- llama_stack/core/routers/__init__.py +8 -12
- llama_stack/core/routers/datasets.py +1 -4
- llama_stack/core/routers/eval_scoring.py +7 -4
- llama_stack/core/routers/inference.py +55 -271
- llama_stack/core/routers/safety.py +52 -24
- llama_stack/core/routers/tool_runtime.py +6 -48
- llama_stack/core/routers/vector_io.py +130 -51
- llama_stack/core/routing_tables/benchmarks.py +24 -20
- llama_stack/core/routing_tables/common.py +1 -4
- llama_stack/core/routing_tables/datasets.py +22 -22
- llama_stack/core/routing_tables/models.py +119 -6
- llama_stack/core/routing_tables/scoring_functions.py +7 -7
- llama_stack/core/routing_tables/shields.py +1 -2
- llama_stack/core/routing_tables/toolgroups.py +17 -7
- llama_stack/core/routing_tables/vector_stores.py +51 -16
- llama_stack/core/server/auth.py +5 -3
- llama_stack/core/server/auth_providers.py +36 -20
- llama_stack/core/server/fastapi_router_registry.py +84 -0
- llama_stack/core/server/quota.py +2 -2
- llama_stack/core/server/routes.py +79 -27
- llama_stack/core/server/server.py +102 -87
- llama_stack/core/stack.py +201 -58
- llama_stack/core/storage/datatypes.py +26 -3
- llama_stack/{providers/utils → core/storage}/kvstore/__init__.py +2 -0
- llama_stack/{providers/utils → core/storage}/kvstore/kvstore.py +55 -24
- llama_stack/{providers/utils → core/storage}/kvstore/mongodb/mongodb.py +13 -10
- llama_stack/{providers/utils → core/storage}/kvstore/postgres/postgres.py +28 -17
- llama_stack/{providers/utils → core/storage}/kvstore/redis/redis.py +41 -16
- llama_stack/{providers/utils → core/storage}/kvstore/sqlite/sqlite.py +1 -1
- llama_stack/core/storage/sqlstore/__init__.py +17 -0
- llama_stack/{providers/utils → core/storage}/sqlstore/authorized_sqlstore.py +69 -49
- llama_stack/{providers/utils → core/storage}/sqlstore/sqlalchemy_sqlstore.py +47 -17
- llama_stack/{providers/utils → core/storage}/sqlstore/sqlstore.py +25 -8
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/config.py +8 -2
- llama_stack/core/utils/config_resolution.py +32 -29
- llama_stack/core/utils/context.py +4 -10
- llama_stack/core/utils/exec.py +9 -0
- llama_stack/core/utils/type_inspection.py +45 -0
- llama_stack/distributions/dell/{run.yaml → config.yaml} +3 -2
- llama_stack/distributions/dell/dell.py +2 -2
- llama_stack/distributions/dell/run-with-safety.yaml +3 -2
- llama_stack/distributions/meta-reference-gpu/{run.yaml → config.yaml} +3 -2
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +2 -2
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +3 -2
- llama_stack/distributions/nvidia/{run.yaml → config.yaml} +4 -4
- llama_stack/distributions/nvidia/nvidia.py +1 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -4
- llama_stack/{apis/datasetio → distributions/oci}/__init__.py +1 -1
- llama_stack/distributions/oci/config.yaml +134 -0
- llama_stack/distributions/oci/oci.py +108 -0
- llama_stack/distributions/open-benchmark/{run.yaml → config.yaml} +5 -4
- llama_stack/distributions/open-benchmark/open_benchmark.py +2 -3
- llama_stack/distributions/postgres-demo/{run.yaml → config.yaml} +4 -3
- llama_stack/distributions/starter/{run.yaml → config.yaml} +64 -13
- llama_stack/distributions/starter/run-with-postgres-store.yaml +64 -13
- llama_stack/distributions/starter/starter.py +8 -5
- llama_stack/distributions/starter-gpu/{run.yaml → config.yaml} +64 -13
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +64 -13
- llama_stack/distributions/template.py +13 -69
- llama_stack/distributions/watsonx/{run.yaml → config.yaml} +4 -3
- llama_stack/distributions/watsonx/watsonx.py +1 -1
- llama_stack/log.py +28 -11
- llama_stack/models/llama/checkpoint.py +6 -6
- llama_stack/models/llama/hadamard_utils.py +2 -0
- llama_stack/models/llama/llama3/generation.py +3 -1
- llama_stack/models/llama/llama3/interface.py +2 -5
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +3 -3
- llama_stack/models/llama/llama3/multimodal/image_transform.py +6 -6
- llama_stack/models/llama/llama3/prompt_templates/system_prompts.py +1 -1
- llama_stack/models/llama/llama3/tool_utils.py +2 -1
- llama_stack/models/llama/llama4/prompt_templates/system_prompts.py +1 -1
- llama_stack/providers/inline/agents/meta_reference/__init__.py +3 -3
- llama_stack/providers/inline/agents/meta_reference/agents.py +44 -261
- llama_stack/providers/inline/agents/meta_reference/config.py +6 -1
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +207 -57
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +308 -47
- llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +162 -96
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +23 -8
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +201 -33
- llama_stack/providers/inline/agents/meta_reference/safety.py +8 -13
- llama_stack/providers/inline/batches/reference/__init__.py +2 -4
- llama_stack/providers/inline/batches/reference/batches.py +78 -60
- llama_stack/providers/inline/datasetio/localfs/datasetio.py +2 -5
- llama_stack/providers/inline/eval/meta_reference/eval.py +16 -61
- llama_stack/providers/inline/files/localfs/files.py +37 -28
- llama_stack/providers/inline/inference/meta_reference/config.py +2 -2
- llama_stack/providers/inline/inference/meta_reference/generators.py +50 -60
- llama_stack/providers/inline/inference/meta_reference/inference.py +403 -19
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +7 -26
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +2 -12
- llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +10 -15
- llama_stack/providers/inline/post_training/common/validator.py +1 -5
- llama_stack/providers/inline/post_training/huggingface/post_training.py +8 -8
- llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +18 -10
- llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +12 -9
- llama_stack/providers/inline/post_training/huggingface/utils.py +27 -6
- llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py +1 -1
- llama_stack/providers/inline/post_training/torchtune/post_training.py +8 -8
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +16 -16
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +13 -9
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +18 -15
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +9 -9
- llama_stack/providers/inline/scoring/basic/scoring.py +6 -13
- llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/ifeval.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_math_response.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +2 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/ifeval_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_math_response_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +1 -2
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +12 -15
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_relevancy.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_similarity.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_precision.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_recall.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_relevancy.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +2 -2
- llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/faithfulness.py +2 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +7 -14
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_405b_simpleqa.py +2 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +1 -2
- llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +1 -3
- llama_stack/providers/inline/tool_runtime/rag/__init__.py +1 -1
- llama_stack/providers/inline/tool_runtime/rag/config.py +8 -1
- llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +7 -6
- llama_stack/providers/inline/tool_runtime/rag/memory.py +64 -48
- llama_stack/providers/inline/vector_io/chroma/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/chroma/config.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/config.py +1 -1
- llama_stack/providers/inline/vector_io/faiss/faiss.py +43 -28
- llama_stack/providers/inline/vector_io/milvus/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/milvus/config.py +1 -1
- llama_stack/providers/inline/vector_io/qdrant/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/qdrant/config.py +1 -1
- llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +1 -1
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +40 -33
- llama_stack/providers/registry/agents.py +7 -3
- llama_stack/providers/registry/batches.py +1 -1
- llama_stack/providers/registry/datasetio.py +1 -1
- llama_stack/providers/registry/eval.py +1 -1
- llama_stack/{apis/datasets/__init__.py → providers/registry/file_processors.py} +5 -1
- llama_stack/providers/registry/files.py +11 -2
- llama_stack/providers/registry/inference.py +22 -3
- llama_stack/providers/registry/post_training.py +1 -1
- llama_stack/providers/registry/safety.py +1 -1
- llama_stack/providers/registry/scoring.py +1 -1
- llama_stack/providers/registry/tool_runtime.py +2 -2
- llama_stack/providers/registry/vector_io.py +7 -7
- llama_stack/providers/remote/datasetio/huggingface/huggingface.py +2 -5
- llama_stack/providers/remote/datasetio/nvidia/datasetio.py +1 -4
- llama_stack/providers/remote/eval/nvidia/eval.py +15 -9
- llama_stack/providers/remote/files/openai/__init__.py +19 -0
- llama_stack/providers/remote/files/openai/config.py +28 -0
- llama_stack/providers/remote/files/openai/files.py +253 -0
- llama_stack/providers/remote/files/s3/files.py +52 -30
- llama_stack/providers/remote/inference/anthropic/anthropic.py +2 -1
- llama_stack/providers/remote/inference/anthropic/config.py +1 -1
- llama_stack/providers/remote/inference/azure/azure.py +1 -3
- llama_stack/providers/remote/inference/azure/config.py +8 -7
- llama_stack/providers/remote/inference/bedrock/__init__.py +1 -1
- llama_stack/providers/remote/inference/bedrock/bedrock.py +82 -105
- llama_stack/providers/remote/inference/bedrock/config.py +24 -3
- llama_stack/providers/remote/inference/cerebras/cerebras.py +5 -5
- llama_stack/providers/remote/inference/cerebras/config.py +12 -5
- llama_stack/providers/remote/inference/databricks/config.py +13 -6
- llama_stack/providers/remote/inference/databricks/databricks.py +16 -6
- llama_stack/providers/remote/inference/fireworks/config.py +5 -5
- llama_stack/providers/remote/inference/fireworks/fireworks.py +1 -1
- llama_stack/providers/remote/inference/gemini/config.py +1 -1
- llama_stack/providers/remote/inference/gemini/gemini.py +13 -14
- llama_stack/providers/remote/inference/groq/config.py +5 -5
- llama_stack/providers/remote/inference/groq/groq.py +1 -1
- llama_stack/providers/remote/inference/llama_openai_compat/config.py +5 -5
- llama_stack/providers/remote/inference/llama_openai_compat/llama.py +8 -6
- llama_stack/providers/remote/inference/nvidia/__init__.py +1 -1
- llama_stack/providers/remote/inference/nvidia/config.py +21 -11
- llama_stack/providers/remote/inference/nvidia/nvidia.py +115 -3
- llama_stack/providers/remote/inference/nvidia/utils.py +1 -1
- llama_stack/providers/remote/inference/oci/__init__.py +17 -0
- llama_stack/providers/remote/inference/oci/auth.py +79 -0
- llama_stack/providers/remote/inference/oci/config.py +75 -0
- llama_stack/providers/remote/inference/oci/oci.py +162 -0
- llama_stack/providers/remote/inference/ollama/config.py +7 -5
- llama_stack/providers/remote/inference/ollama/ollama.py +17 -8
- llama_stack/providers/remote/inference/openai/config.py +4 -4
- llama_stack/providers/remote/inference/openai/openai.py +1 -1
- llama_stack/providers/remote/inference/passthrough/__init__.py +2 -2
- llama_stack/providers/remote/inference/passthrough/config.py +5 -10
- llama_stack/providers/remote/inference/passthrough/passthrough.py +97 -75
- llama_stack/providers/remote/inference/runpod/config.py +12 -5
- llama_stack/providers/remote/inference/runpod/runpod.py +2 -20
- llama_stack/providers/remote/inference/sambanova/config.py +5 -5
- llama_stack/providers/remote/inference/sambanova/sambanova.py +1 -1
- llama_stack/providers/remote/inference/tgi/config.py +7 -6
- llama_stack/providers/remote/inference/tgi/tgi.py +19 -11
- llama_stack/providers/remote/inference/together/config.py +5 -5
- llama_stack/providers/remote/inference/together/together.py +15 -12
- llama_stack/providers/remote/inference/vertexai/config.py +1 -1
- llama_stack/providers/remote/inference/vllm/config.py +5 -5
- llama_stack/providers/remote/inference/vllm/vllm.py +13 -14
- llama_stack/providers/remote/inference/watsonx/config.py +4 -4
- llama_stack/providers/remote/inference/watsonx/watsonx.py +21 -94
- llama_stack/providers/remote/post_training/nvidia/post_training.py +4 -4
- llama_stack/providers/remote/post_training/nvidia/utils.py +1 -1
- llama_stack/providers/remote/safety/bedrock/bedrock.py +6 -6
- llama_stack/providers/remote/safety/bedrock/config.py +1 -1
- llama_stack/providers/remote/safety/nvidia/config.py +1 -1
- llama_stack/providers/remote/safety/nvidia/nvidia.py +11 -5
- llama_stack/providers/remote/safety/sambanova/config.py +1 -1
- llama_stack/providers/remote/safety/sambanova/sambanova.py +6 -6
- llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +11 -6
- llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +12 -7
- llama_stack/providers/remote/tool_runtime/model_context_protocol/config.py +8 -2
- llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +57 -15
- llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +11 -6
- llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +11 -6
- llama_stack/providers/remote/vector_io/chroma/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/chroma/chroma.py +125 -20
- llama_stack/providers/remote/vector_io/chroma/config.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/config.py +1 -1
- llama_stack/providers/remote/vector_io/milvus/milvus.py +27 -21
- llama_stack/providers/remote/vector_io/pgvector/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/pgvector/config.py +1 -1
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +26 -18
- llama_stack/providers/remote/vector_io/qdrant/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/qdrant/config.py +1 -1
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +141 -24
- llama_stack/providers/remote/vector_io/weaviate/__init__.py +1 -1
- llama_stack/providers/remote/vector_io/weaviate/config.py +1 -1
- llama_stack/providers/remote/vector_io/weaviate/weaviate.py +26 -21
- llama_stack/providers/utils/common/data_schema_validator.py +1 -5
- llama_stack/providers/utils/files/form_data.py +1 -1
- llama_stack/providers/utils/inference/embedding_mixin.py +1 -1
- llama_stack/providers/utils/inference/inference_store.py +12 -21
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +79 -79
- llama_stack/providers/utils/inference/model_registry.py +1 -3
- llama_stack/providers/utils/inference/openai_compat.py +44 -1171
- llama_stack/providers/utils/inference/openai_mixin.py +68 -42
- llama_stack/providers/utils/inference/prompt_adapter.py +50 -265
- llama_stack/providers/utils/inference/stream_utils.py +23 -0
- llama_stack/providers/utils/memory/__init__.py +2 -0
- llama_stack/providers/utils/memory/file_utils.py +1 -1
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +181 -84
- llama_stack/providers/utils/memory/vector_store.py +39 -38
- llama_stack/providers/utils/pagination.py +1 -1
- llama_stack/providers/utils/responses/responses_store.py +15 -25
- llama_stack/providers/utils/scoring/aggregation_utils.py +1 -2
- llama_stack/providers/utils/scoring/base_scoring_fn.py +1 -2
- llama_stack/providers/utils/tools/mcp.py +93 -11
- llama_stack/telemetry/constants.py +27 -0
- llama_stack/telemetry/helpers.py +43 -0
- llama_stack/testing/api_recorder.py +25 -16
- {llama_stack-0.3.4.dist-info → llama_stack-0.4.0.dist-info}/METADATA +56 -131
- llama_stack-0.4.0.dist-info/RECORD +588 -0
- llama_stack-0.4.0.dist-info/top_level.txt +2 -0
- llama_stack_api/__init__.py +945 -0
- llama_stack_api/admin/__init__.py +45 -0
- llama_stack_api/admin/api.py +72 -0
- llama_stack_api/admin/fastapi_routes.py +117 -0
- llama_stack_api/admin/models.py +113 -0
- llama_stack_api/agents.py +173 -0
- llama_stack_api/batches/__init__.py +40 -0
- llama_stack_api/batches/api.py +53 -0
- llama_stack_api/batches/fastapi_routes.py +113 -0
- llama_stack_api/batches/models.py +78 -0
- llama_stack_api/benchmarks/__init__.py +43 -0
- llama_stack_api/benchmarks/api.py +39 -0
- llama_stack_api/benchmarks/fastapi_routes.py +109 -0
- llama_stack_api/benchmarks/models.py +109 -0
- {llama_stack/apis → llama_stack_api}/common/content_types.py +1 -43
- {llama_stack/apis → llama_stack_api}/common/errors.py +0 -8
- {llama_stack/apis → llama_stack_api}/common/job_types.py +1 -1
- llama_stack_api/common/responses.py +77 -0
- {llama_stack/apis → llama_stack_api}/common/training_types.py +1 -1
- {llama_stack/apis → llama_stack_api}/common/type_system.py +2 -14
- llama_stack_api/connectors.py +146 -0
- {llama_stack/apis/conversations → llama_stack_api}/conversations.py +23 -39
- {llama_stack/apis/datasetio → llama_stack_api}/datasetio.py +4 -8
- llama_stack_api/datasets/__init__.py +61 -0
- llama_stack_api/datasets/api.py +35 -0
- llama_stack_api/datasets/fastapi_routes.py +104 -0
- llama_stack_api/datasets/models.py +152 -0
- {llama_stack/providers → llama_stack_api}/datatypes.py +166 -10
- {llama_stack/apis/eval → llama_stack_api}/eval.py +8 -40
- llama_stack_api/file_processors/__init__.py +27 -0
- llama_stack_api/file_processors/api.py +64 -0
- llama_stack_api/file_processors/fastapi_routes.py +78 -0
- llama_stack_api/file_processors/models.py +42 -0
- llama_stack_api/files/__init__.py +35 -0
- llama_stack_api/files/api.py +51 -0
- llama_stack_api/files/fastapi_routes.py +124 -0
- llama_stack_api/files/models.py +107 -0
- {llama_stack/apis/inference → llama_stack_api}/inference.py +90 -194
- llama_stack_api/inspect_api/__init__.py +37 -0
- llama_stack_api/inspect_api/api.py +25 -0
- llama_stack_api/inspect_api/fastapi_routes.py +76 -0
- llama_stack_api/inspect_api/models.py +28 -0
- {llama_stack/apis/agents → llama_stack_api/internal}/__init__.py +3 -1
- llama_stack/providers/utils/kvstore/api.py → llama_stack_api/internal/kvstore.py +5 -0
- llama_stack_api/internal/sqlstore.py +79 -0
- {llama_stack/apis/models → llama_stack_api}/models.py +11 -9
- {llama_stack/apis/agents → llama_stack_api}/openai_responses.py +184 -27
- {llama_stack/apis/post_training → llama_stack_api}/post_training.py +7 -11
- {llama_stack/apis/prompts → llama_stack_api}/prompts.py +3 -4
- llama_stack_api/providers/__init__.py +33 -0
- llama_stack_api/providers/api.py +16 -0
- llama_stack_api/providers/fastapi_routes.py +57 -0
- llama_stack_api/providers/models.py +24 -0
- {llama_stack/apis/tools → llama_stack_api}/rag_tool.py +2 -52
- {llama_stack/apis → llama_stack_api}/resource.py +1 -1
- llama_stack_api/router_utils.py +160 -0
- {llama_stack/apis/safety → llama_stack_api}/safety.py +6 -9
- {llama_stack → llama_stack_api}/schema_utils.py +94 -4
- {llama_stack/apis/scoring → llama_stack_api}/scoring.py +3 -3
- {llama_stack/apis/scoring_functions → llama_stack_api}/scoring_functions.py +9 -6
- {llama_stack/apis/shields → llama_stack_api}/shields.py +6 -7
- {llama_stack/apis/tools → llama_stack_api}/tools.py +26 -21
- {llama_stack/apis/vector_io → llama_stack_api}/vector_io.py +133 -152
- {llama_stack/apis/vector_stores → llama_stack_api}/vector_stores.py +1 -1
- llama_stack/apis/agents/agents.py +0 -894
- llama_stack/apis/batches/__init__.py +0 -9
- llama_stack/apis/batches/batches.py +0 -100
- llama_stack/apis/benchmarks/__init__.py +0 -7
- llama_stack/apis/benchmarks/benchmarks.py +0 -108
- llama_stack/apis/common/responses.py +0 -36
- llama_stack/apis/conversations/__init__.py +0 -31
- llama_stack/apis/datasets/datasets.py +0 -251
- llama_stack/apis/datatypes.py +0 -160
- llama_stack/apis/eval/__init__.py +0 -7
- llama_stack/apis/files/__init__.py +0 -7
- llama_stack/apis/files/files.py +0 -199
- llama_stack/apis/inference/__init__.py +0 -7
- llama_stack/apis/inference/event_logger.py +0 -43
- llama_stack/apis/inspect/__init__.py +0 -7
- llama_stack/apis/inspect/inspect.py +0 -94
- llama_stack/apis/models/__init__.py +0 -7
- llama_stack/apis/post_training/__init__.py +0 -7
- llama_stack/apis/prompts/__init__.py +0 -9
- llama_stack/apis/providers/__init__.py +0 -7
- llama_stack/apis/providers/providers.py +0 -69
- llama_stack/apis/safety/__init__.py +0 -7
- llama_stack/apis/scoring/__init__.py +0 -7
- llama_stack/apis/scoring_functions/__init__.py +0 -7
- llama_stack/apis/shields/__init__.py +0 -7
- llama_stack/apis/synthetic_data_generation/__init__.py +0 -7
- llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +0 -77
- llama_stack/apis/telemetry/__init__.py +0 -7
- llama_stack/apis/telemetry/telemetry.py +0 -423
- llama_stack/apis/tools/__init__.py +0 -8
- llama_stack/apis/vector_io/__init__.py +0 -7
- llama_stack/apis/vector_stores/__init__.py +0 -7
- llama_stack/core/server/tracing.py +0 -80
- llama_stack/core/ui/app.py +0 -55
- llama_stack/core/ui/modules/__init__.py +0 -5
- llama_stack/core/ui/modules/api.py +0 -32
- llama_stack/core/ui/modules/utils.py +0 -42
- llama_stack/core/ui/page/__init__.py +0 -5
- llama_stack/core/ui/page/distribution/__init__.py +0 -5
- llama_stack/core/ui/page/distribution/datasets.py +0 -18
- llama_stack/core/ui/page/distribution/eval_tasks.py +0 -20
- llama_stack/core/ui/page/distribution/models.py +0 -18
- llama_stack/core/ui/page/distribution/providers.py +0 -27
- llama_stack/core/ui/page/distribution/resources.py +0 -48
- llama_stack/core/ui/page/distribution/scoring_functions.py +0 -18
- llama_stack/core/ui/page/distribution/shields.py +0 -19
- llama_stack/core/ui/page/evaluations/__init__.py +0 -5
- llama_stack/core/ui/page/evaluations/app_eval.py +0 -143
- llama_stack/core/ui/page/evaluations/native_eval.py +0 -253
- llama_stack/core/ui/page/playground/__init__.py +0 -5
- llama_stack/core/ui/page/playground/chat.py +0 -130
- llama_stack/core/ui/page/playground/tools.py +0 -352
- llama_stack/distributions/dell/build.yaml +0 -33
- llama_stack/distributions/meta-reference-gpu/build.yaml +0 -32
- llama_stack/distributions/nvidia/build.yaml +0 -29
- llama_stack/distributions/open-benchmark/build.yaml +0 -36
- llama_stack/distributions/postgres-demo/__init__.py +0 -7
- llama_stack/distributions/postgres-demo/build.yaml +0 -23
- llama_stack/distributions/postgres-demo/postgres_demo.py +0 -125
- llama_stack/distributions/starter/build.yaml +0 -61
- llama_stack/distributions/starter-gpu/build.yaml +0 -61
- llama_stack/distributions/watsonx/build.yaml +0 -33
- llama_stack/providers/inline/agents/meta_reference/agent_instance.py +0 -1024
- llama_stack/providers/inline/agents/meta_reference/persistence.py +0 -228
- llama_stack/providers/inline/telemetry/__init__.py +0 -5
- llama_stack/providers/inline/telemetry/meta_reference/__init__.py +0 -21
- llama_stack/providers/inline/telemetry/meta_reference/config.py +0 -47
- llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +0 -252
- llama_stack/providers/remote/inference/bedrock/models.py +0 -29
- llama_stack/providers/utils/kvstore/sqlite/config.py +0 -20
- llama_stack/providers/utils/sqlstore/__init__.py +0 -5
- llama_stack/providers/utils/sqlstore/api.py +0 -128
- llama_stack/providers/utils/telemetry/__init__.py +0 -5
- llama_stack/providers/utils/telemetry/trace_protocol.py +0 -142
- llama_stack/providers/utils/telemetry/tracing.py +0 -384
- llama_stack/strong_typing/__init__.py +0 -19
- llama_stack/strong_typing/auxiliary.py +0 -228
- llama_stack/strong_typing/classdef.py +0 -440
- llama_stack/strong_typing/core.py +0 -46
- llama_stack/strong_typing/deserializer.py +0 -877
- llama_stack/strong_typing/docstring.py +0 -409
- llama_stack/strong_typing/exception.py +0 -23
- llama_stack/strong_typing/inspection.py +0 -1085
- llama_stack/strong_typing/mapping.py +0 -40
- llama_stack/strong_typing/name.py +0 -182
- llama_stack/strong_typing/schema.py +0 -792
- llama_stack/strong_typing/serialization.py +0 -97
- llama_stack/strong_typing/serializer.py +0 -500
- llama_stack/strong_typing/slots.py +0 -27
- llama_stack/strong_typing/topological.py +0 -89
- llama_stack/ui/node_modules/flatted/python/flatted.py +0 -149
- llama_stack-0.3.4.dist-info/RECORD +0 -625
- llama_stack-0.3.4.dist-info/top_level.txt +0 -1
- /llama_stack/{providers/utils → core/storage}/kvstore/config.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/mongodb/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/postgres/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/redis/__init__.py +0 -0
- /llama_stack/{providers/utils → core/storage}/kvstore/sqlite/__init__.py +0 -0
- /llama_stack/{apis → providers/inline/file_processor}/__init__.py +0 -0
- /llama_stack/{apis/common → telemetry}/__init__.py +0 -0
- {llama_stack-0.3.4.dist-info → llama_stack-0.4.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.3.4.dist-info → llama_stack-0.4.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.3.4.dist-info → llama_stack-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {llama_stack/core/ui → llama_stack_api/common}/__init__.py +0 -0
- {llama_stack/strong_typing → llama_stack_api}/py.typed +0 -0
- {llama_stack/apis → llama_stack_api}/version.py +0 -0
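The headline change in 0.4.0 is the extraction of the public API surface into a new top-level llama_stack_api package: note the llama_stack_api/ tree above, the second entry in top_level.txt, and the wholesale moves out of llama_stack/apis/*. The Streamlit UI (llama_stack/core/ui) and the vendored strong_typing package are dropped, and telemetry is rebuilt as a slim llama_stack/telemetry module. For downstream code the migration is mostly a matter of flattening imports; a minimal sketch (symbol names taken from the hunks below, assuming they are re-exported from llama_stack_api/__init__.py):

    # llama-stack 0.3.4
    from llama_stack.apis.inference import InferenceProvider
    from llama_stack.apis.models import ModelType
    from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate

    # llama-stack 0.4.0: the same symbols come from one flat package
    from llama_stack_api import InferenceProvider, Model, ModelsProtocolPrivate, ModelType
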
llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
CHANGED
@@ -33,10 +33,6 @@ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import GenerationResult
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    ChatCompletionRequestWithRawContent,
-    CompletionRequestWithRawContent,
-)
 
 log = get_logger(name=__name__, category="inference")
 
@@ -69,10 +65,7 @@ class CancelSentinel(BaseModel):
 
 class TaskRequest(BaseModel):
     type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
-    task: tuple[
-        str,
-        list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
-    ]
+    task: tuple[str, list]
 
 
 class TaskResponse(BaseModel):
@@ -328,10 +321,7 @@ class ModelParallelProcessGroup:
 
     def run_inference(
         self,
-        req: tuple[
-            str,
-            list[CompletionRequestWithRawContent] | list[ChatCompletionRequestWithRawContent],
-        ],
+        req: tuple[str, list],
     ) -> Generator:
         assert not self.running, "inference already running"
 
llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py
CHANGED
@@ -6,24 +6,20 @@
 
 from collections.abc import AsyncIterator
 
-from llama_stack.apis.inference import (
-    InferenceProvider,
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAICompletionRequestWithExtraBody,
-)
-from llama_stack.apis.inference.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-)
-from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionToLlamaStackMixin,
+from llama_stack_api import (
+    InferenceProvider,
+    Model,
+    ModelsProtocolPrivate,
+    ModelType,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
 )
 
 from .config import SentenceTransformersInferenceConfig
@@ -32,7 +28,6 @@ log = get_logger(name=__name__, category="inference")
 
 
 class SentenceTransformersInferenceImpl(
-    OpenAIChatCompletionToLlamaStackMixin,
     SentenceTransformerEmbeddingMixin,
     InferenceProvider,
     ModelsProtocolPrivate,
llama_stack/providers/inline/post_training/common/validator.py
CHANGED
@@ -12,14 +12,10 @@
 
 from typing import Any
 
-from llama_stack.apis.common.type_system import (
-    ChatCompletionInputType,
-    DialogType,
-    StringType,
-)
 from llama_stack.providers.utils.common.data_schema_validator import (
     ColumnName,
 )
+from llama_stack_api import ChatCompletionInputType, DialogType, StringType
 
 EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
     "instruct": [
llama_stack/providers/inline/post_training/huggingface/post_training.py
CHANGED
@@ -6,11 +6,16 @@
 from enum import Enum
 from typing import Any
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.providers.inline.post_training.huggingface.config import (
+    HuggingFacePostTrainingConfig,
+)
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
+from llama_stack_api import (
     AlgorithmConfig,
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
@@ -19,11 +24,6 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
-from llama_stack.providers.inline.post_training.huggingface.config import (
-    HuggingFacePostTrainingConfig,
-)
-from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
-from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 
 
 class TrainingArtifactType(Enum):
llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py
CHANGED
@@ -14,24 +14,24 @@ import torch
 from datasets import Dataset
 from peft import LoraConfig
 from transformers import (
-    AutoModelForCausalLM,
     AutoTokenizer,
 )
 from trl import SFTConfig, SFTTrainer
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.log import get_logger
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
+from llama_stack_api import (
     Checkpoint,
     DataConfig,
+    DatasetIO,
+    Datasets,
     LoraFinetuningConfig,
     TrainingConfig,
 )
-from llama_stack.log import get_logger
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
+    HFAutoModel,
     calculate_training_steps,
     create_checkpoints,
     get_memory_stats,
@@ -338,7 +338,7 @@ class HFFinetuningSingleDevice:
 
     def save_model(
         self,
-        model_obj: AutoModelForCausalLM,
+        model_obj: HFAutoModel,
         trainer: SFTTrainer,
         peft_config: LoraConfig | None,
         output_dir_path: Path,
@@ -350,14 +350,22 @@ class HFFinetuningSingleDevice:
             peft_config: Optional LoRA configuration
             output_dir_path: Path to save the model
         """
+        from typing import cast
+
         logger.info("Saving final model")
         model_obj.config.use_cache = True
 
         if peft_config:
             logger.info("Merging LoRA weights with base model")
-            model_obj = trainer.model.merge_and_unload()
+            # TRL's merge_and_unload returns a HuggingFace model
+            # Both cast() and type: ignore are needed here:
+            # - cast() tells mypy the return type is HFAutoModel for downstream code
+            # - type: ignore suppresses errors on the merge_and_unload() call itself,
+            #   which mypy can't type-check due to TRL library's incomplete type stubs
+            model_obj = cast(HFAutoModel, trainer.model.merge_and_unload())  # type: ignore[union-attr,operator]
         else:
-            model_obj = trainer.model
+            # trainer.model is the trained HuggingFace model
+            model_obj = cast(HFAutoModel, trainer.model)
 
         save_path = output_dir_path / "merged_model"
         logger.info(f"Saving model to {save_path}")
@@ -411,7 +419,7 @@ class HFFinetuningSingleDevice:
         # Initialize trainer
         logger.info("Initializing SFTTrainer")
         trainer = SFTTrainer(
-            model=model_obj,
+            model=model_obj,  # type: ignore[arg-type]
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             peft_config=peft_config,
llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py
CHANGED
@@ -16,15 +16,15 @@ from transformers import (
 )
 from trl import DPOConfig, DPOTrainer
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.log import get_logger
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
+from llama_stack_api import (
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     TrainingConfig,
 )
-from llama_stack.log import get_logger
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
@@ -309,7 +309,7 @@ class HFDPOAlignmentSingleDevice:
             save_total_limit=provider_config.save_total_limit,
             # DPO specific parameters
             beta=dpo_config.beta,
-            loss_type=provider_config.dpo_loss_type,
+            loss_type=provider_config.dpo_loss_type,  # type: ignore[arg-type]
         )
 
     def save_model(
@@ -381,13 +381,16 @@ class HFDPOAlignmentSingleDevice:
 
         # Initialize DPO trainer
         logger.info("Initializing DPOTrainer")
+        # TRL library has incomplete type stubs - use Any to bypass
+        from typing import Any, cast
+
        trainer = DPOTrainer(
-            model=model_obj,
-            ref_model=ref_model,
+            model=cast(Any, model_obj),  # HFAutoModel satisfies PreTrainedModel protocol
+            ref_model=cast(Any, ref_model),
             args=training_args,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
-            processing_class=tokenizer,
+            processing_class=cast(Any, tokenizer),  # AutoTokenizer satisfies interface
         )
 
         try:
llama_stack/providers/inline/post_training/huggingface/utils.py
CHANGED
@@ -9,15 +9,33 @@ import signal
 import sys
 from datetime import UTC, datetime
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any, Protocol
 
 import psutil
 import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.post_training import Checkpoint, TrainingConfig
+from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+
+class HFAutoModel(Protocol):
+    """Protocol describing HuggingFace AutoModel interface.
+
+    This protocol defines the common interface for HuggingFace AutoModelForCausalLM
+    and similar models, providing type safety without requiring type stubs.
+    """
+
+    config: PretrainedConfig
+    device: torch.device
+
+    def to(self, device: torch.device) -> "HFAutoModel": ...
+    def save_pretrained(self, save_directory: str | Path) -> None: ...
+
 
 from llama_stack.log import get_logger
 
 from .config import HuggingFacePostTrainingConfig
@@ -132,7 +150,7 @@ def load_model(
     model: str,
     device: torch.device,
     provider_config: HuggingFacePostTrainingConfig,
-) -> AutoModelForCausalLM:
+) -> HFAutoModel:
     """Load and initialize the model for training.
     Args:
         model: The model identifier to load
@@ -143,6 +161,8 @@ def load_model(
     Raises:
         RuntimeError: If model loading fails
     """
+    from typing import cast
+
     logger.info("Loading the base model")
     try:
         model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config)
@@ -154,9 +174,10 @@ def load_model(
             **provider_config.model_specific_config,
         )
         # Always move model to specified device
-        model_obj = model_obj.to(device)
+        model_obj = model_obj.to(device)  # type: ignore[arg-type]
         logger.info(f"Model loaded and moved to device: {model_obj.device}")
-        return model_obj
+        # Cast to HFAutoModel protocol - transformers models satisfy this interface
+        return cast(HFAutoModel, model_obj)
     except Exception as e:
         raise RuntimeError(f"Failed to load model: {str(e)}") from e
 
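The HFAutoModel protocol introduced above relies on structural typing: any object exposing the listed attributes and methods conforms, so transformers models type-check without stubs, and cast() is purely a static annotation with no runtime cost. A minimal self-contained sketch of the same pattern (all names below are illustrative, not from llama-stack):

    from pathlib import Path
    from typing import Protocol, cast


    class Saveable(Protocol):
        """Structural interface: any object with a matching save_pretrained() conforms."""

        def save_pretrained(self, save_directory: str | Path) -> None: ...


    class FakeModel:
        """Hypothetical stand-in for a transformers model."""

        def save_pretrained(self, save_directory: str | Path) -> None:
            print(f"saving to {save_directory}")


    def save(model: Saveable, out: Path) -> None:
        # Accepts anything that structurally matches Saveable.
        model.save_pretrained(out)


    # cast() has no runtime effect; it only tells the static checker to treat the
    # value as Saveable, mirroring how load_model() returns cast(HFAutoModel, ...).
    save(cast(Saveable, FakeModel()), Path("/tmp/model"))
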
llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
CHANGED
@@ -91,7 +91,7 @@ class TorchtuneCheckpointer:
         if checkpoint_format == "meta" or checkpoint_format is None:
             self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only)
         elif checkpoint_format == "huggingface":
-            # Note: for saving hugging face format checkpoints, we only
+            # Note: for saving hugging face format checkpoints, we only support saving adapter weights now
             self._save_hf_format_checkpoint(model_file_path, state_dict)
         else:
             raise ValueError(f"Unsupported checkpoint format: {format}")
llama_stack/providers/inline/post_training/torchtune/common/utils.py
CHANGED
@@ -21,9 +21,9 @@ from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform
 
-from llama_stack.apis.post_training import DatasetFormat
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import Model
+from llama_stack_api import DatasetFormat
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py
CHANGED
@@ -25,7 +25,7 @@ def llama_stack_instruct_to_torchtune_instruct(
     )
     input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
 
-    assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
+    assert len(input_messages) == 1, "llama stack instruct dataset format only supports 1 user message"
     input_message = input_messages[0]
 
     assert "content" in input_message, "content not found in input message"
llama_stack/providers/inline/post_training/torchtune/post_training.py
CHANGED
@@ -6,11 +6,16 @@
 from enum import Enum
 from typing import Any
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
+)
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
+from llama_stack_api import (
     AlgorithmConfig,
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
@@ -20,11 +25,6 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
-from llama_stack.providers.inline.post_training.torchtune.config import (
-    TorchtunePostTrainingConfig,
-)
-from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
-from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 
 
 class TrainingArtifactType(Enum):
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
CHANGED
@@ -32,17 +32,6 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
 from torchtune.training.metric_logging import DiskLogger
 from tqdm import tqdm
 
-from llama_stack.apis.common.training_types import PostTrainingMetric
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
-    Checkpoint,
-    DataConfig,
-    LoraFinetuningConfig,
-    OptimizerConfig,
-    QATFinetuningConfig,
-    TrainingConfig,
-)
 from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
@@ -56,6 +45,17 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from llama_stack_api import (
+    Checkpoint,
+    DataConfig,
+    DatasetIO,
+    Datasets,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    PostTrainingMetric,
+    QATFinetuningConfig,
+    TrainingConfig,
+)
 
 log = get_logger(name=__name__, category="post_training")
 
@@ -193,7 +193,7 @@ class LoraFinetuningSingleDevice:
         log.info("Optimizer is initialized.")
 
         self._loss_fn = CEWithChunkedOutputLoss()
-        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)  # type: ignore[operator]
         log.info("Loss is initialized.")
 
         assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized"
@@ -284,7 +284,7 @@ class LoraFinetuningSingleDevice:
         if self._is_dora:
             for m in model.modules():
                 if hasattr(m, "initialize_dora_magnitude"):
-                    m.initialize_dora_magnitude()
+                    m.initialize_dora_magnitude()  # type: ignore[operator]
         if lora_weights_state_dict:
             lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
         else:
@@ -353,7 +353,7 @@ class LoraFinetuningSingleDevice:
             dataset_type=self._data_format.value,
         )
 
-        sampler = DistributedSampler(
+        sampler: DistributedSampler = DistributedSampler(
             ds,
             num_replicas=1,
             rank=0,
@@ -389,7 +389,7 @@ class LoraFinetuningSingleDevice:
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )
-        return lr_scheduler
+        return lr_scheduler  # type: ignore[no-any-return]
 
     async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}
@@ -447,7 +447,7 @@ class LoraFinetuningSingleDevice:
         # free logits otherwise it peaks backward memory
         del logits
 
-        return loss
+        return loss  # type: ignore[no-any-return]
 
     async def train(self) -> tuple[dict[str, Any], list[Checkpoint]]:
         """
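A pattern worth noting across these recipes: the suppressions added above are scoped to specific mypy error codes (# type: ignore[operator], # type: ignore[no-any-return], # type: ignore[arg-type]) rather than bare ignores. A standalone illustration of why the code matters (hypothetical function, not from the package):

    def num_chunks(loss_fn: object) -> int:
        # loss_fn is typed as plain object, so mypy flags the attribute access.
        # The bracketed code limits suppression to that one error class; a bare
        # "# type: ignore" would silence every error on the line.
        return loss_fn.num_output_chunks  # type: ignore[attr-defined]
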
llama_stack/providers/inline/safety/code_scanner/code_scanner.py
CHANGED
@@ -10,19 +10,20 @@ from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from codeshield.cs import CodeShieldScanResult
 
-from llama_stack.apis.inference import OpenAIMessageParam
-from llama_stack.apis.safety import (
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+from llama_stack_api import (
+    ModerationObject,
+    ModerationObjectResults,
+    OpenAIMessageParam,
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    Shield,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
-from llama_stack.apis.shields import Shield
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
 
 from .config import CodeScannerConfig
 
@@ -101,7 +102,10 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
             metadata=metadata,
         )
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+        if model is None:
+            raise ValueError("Code scanner moderation requires a model identifier.")
+
         inputs = input if isinstance(input, list) else [input]
         results = []
 
llama_stack/providers/inline/safety/llama_guard/llama_guard.py
CHANGED
@@ -9,29 +9,29 @@ import uuid
 from string import Template
 from typing import Any
 
-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
-from llama_stack.apis.inference import (
+from llama_stack.core.datatypes import Api
+from llama_stack.log import get_logger
+from llama_stack.models.llama.datatypes import Role
+from llama_stack.models.llama.sku_types import CoreModelId
+from llama_stack.providers.utils.inference.prompt_adapter import (
+    interleaved_content_as_str,
+)
+from llama_stack_api import (
+    ImageContentItem,
     Inference,
+    ModerationObject,
+    ModerationObjectResults,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAIMessageParam,
     OpenAIUserMessageParam,
-)
-from llama_stack.apis.safety import (
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    Shield,
+    ShieldsProtocolPrivate,
+    TextContentItem,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
-from llama_stack.apis.shields import Shield
-from llama_stack.core.datatypes import Api
-from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import Role
-from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    interleaved_content_as_str,
-)
 
 from .config import LlamaGuardConfig
 
@@ -200,7 +200,10 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 
         return await impl.run(messages)
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
+        if model is None:
+            raise ValueError("Llama Guard moderation requires a model identifier.")
+
         if isinstance(input, list):
             messages = input.copy()
         else:
llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py
CHANGED
@@ -9,20 +9,20 @@ from typing import Any
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-from llama_stack.apis.inference import OpenAIMessageParam
-from llama_stack.apis.safety import (
+from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
+from llama_stack_api import (
+    ModerationObject,
+    OpenAIMessageParam,
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    Shield,
+    ShieldsProtocolPrivate,
     ShieldStore,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject
-from llama_stack.apis.shields import Shield
-from llama_stack.core.utils.model_utils import model_local_dir
-from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 
 from .config import PromptGuardConfig, PromptGuardType
 
@@ -63,7 +63,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 
         return await self.shield.run(messages)
 
-    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+    async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
         raise NotImplementedError("run_moderation is not implemented for Prompt Guard")
 
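A smaller cross-cutting change in the safety providers above: run_moderation now declares model: str | None = None, and each implementation decides how to handle the missing value (Llama Guard and the code scanner raise; Prompt Guard never implemented moderation). A minimal sketch of the new calling convention (the class below is hypothetical, purely for illustration):

    class ExampleShieldProvider:
        """Illustrative class, not part of llama-stack."""

        async def run_moderation(self, input: str | list[str], model: str | None = None) -> dict:
            # The model argument is optional at the API level; providers that
            # cannot run without one must guard for None themselves.
            if model is None:
                raise ValueError("moderation requires a model identifier")
            inputs = input if isinstance(input, list) else [input]
            return {"model": model, "flagged": [False] * len(inputs)}
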
llama_stack/providers/inline/scoring/basic/scoring.py
CHANGED
@@ -5,21 +5,17 @@
 # the root directory of this source tree.
 from typing import Any
 
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.scoring import (
+from llama_stack_api import (
+    DatasetIO,
+    Datasets,
     ScoreBatchResponse,
     ScoreResponse,
     Scoring,
+    ScoringFn,
+    ScoringFnParams,
+    ScoringFunctionsProtocolPrivate,
     ScoringResult,
 )
-from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-from llama_stack.core.datatypes import Api
-from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_validator import (
-    get_valid_schemas,
-    validate_dataset_schema,
-)
 
 from .config import BasicScoringConfig
 from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
@@ -83,9 +79,6 @@ class BasicScoringImpl(
         scoring_functions: dict[str, ScoringFnParams | None] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
-        dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
-        validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
-
         all_rows = await self.datasetio_api.iterrows(
             dataset_id=dataset_id,
             limit=-1,
llama_stack/providers/inline/scoring/basic/scoring_fn/docvqa_scoring_fn.py
CHANGED
@@ -8,9 +8,8 @@ import json
 import re
 from typing import Any
 
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+from llama_stack_api import ScoringFnParams, ScoringResultRow
 
 from .fn_defs.docvqa import docvqa
 
llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
CHANGED
@@ -6,9 +6,8 @@
 
 from typing import Any
 
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+from llama_stack_api import ScoringFnParams, ScoringResultRow
 
 from .fn_defs.equality import equality
 
llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/docvqa.py
CHANGED
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
+from llama_stack_api import (
     AggregationFunctionType,
     BasicScoringFnParams,
+    NumberType,
     ScoringFn,
 )
 