llama-stack 0.4.3__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- {llama_stack_api/internal → llama_stack/core/connectors}/__init__.py +2 -2
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +183 -60
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +2 -1
- llama_stack/providers/utils/inference/openai_mixin.py +41 -2
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +40 -6
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +131 -275
- llama_stack-0.5.0rc1.dist-info/top_level.txt +1 -0
- llama_stack/distributions/meta-reference-gpu/__init__.py +0 -7
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift
@@ -0,0 +1,189 @@
+import Foundation
+
+import LLaMARunner
+import LlamaStackClient
+
+class RunnerHolder: ObservableObject {
+  var runner: Runner?
+}
+
+public class LocalInference: Inference {
+  private var runnerHolder = RunnerHolder()
+  private let runnerQueue: DispatchQueue
+
+  public init (queue: DispatchQueue) {
+    runnerQueue = queue
+  }
+
+  public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
+    runnerHolder.runner = runnerHolder.runner ?? Runner(
+      modelPath: modelPath,
+      tokenizerPath: tokenizerPath
+    )
+
+
+    runnerQueue.async {
+      let runner = self.runnerHolder.runner
+      do {
+        try runner!.load()
+        completion(.success(()))
+      } catch let loadError {
+        print("error: " + loadError.localizedDescription)
+        completion(.failure(loadError))
+      }
+    }
+  }
+
+  public func stop() {
+    runnerHolder.runner?.stop()
+  }
+
+  public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
+    return AsyncStream { continuation in
+      let workItem = DispatchWorkItem {
+        do {
+          var tokens: [String] = []
+
+          let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
+          var stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload? = nil
+          var buffer = ""
+          var ipython = false
+          var echoDropped = false
+
+          try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
+            buffer += token
+
+            // HACK: Workaround until LlamaRunner exposes echo param
+            if (!echoDropped) {
+              if (buffer.hasPrefix(prompt)) {
+                buffer = String(buffer.dropFirst(prompt.count))
+                echoDropped = true
+              }
+              return
+            }
+
+            tokens.append(token)
+
+            if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
+              ipython = true
+              continuation.yield(
+                Components.Schemas.ChatCompletionResponseStreamChunk(
+                  event: Components.Schemas.ChatCompletionResponseEvent(
+                    event_type: .progress,
+                    delta: .tool_call(Components.Schemas.ToolCallDelta(
+                      _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
+                      tool_call: .case1(""),
+                      parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.started
+                    )
+                    )
+                  )
+                )
+              )
+
+              if (buffer.starts(with: "<|python_tag|>")) {
+                buffer = String(buffer.dropFirst("<|python_tag|>".count))
+              }
+            }
+
+            // TODO: Non-streaming logprobs
+
+            var text = ""
+            if token == "<|eot_id|>" {
+              stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_turn
+            } else if token == "<|eom_id|>" {
+              stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
+            } else {
+              text = token
+            }
+
+            var delta: Components.Schemas.ContentDelta
+            if ipython {
+              delta = .tool_call(Components.Schemas.ToolCallDelta(
+                _type: .tool_call,
+                tool_call: .case1(text),
+                parse_status: .in_progress
+              ))
+            } else {
+              delta = .text(Components.Schemas.TextDelta(
+                _type: Components.Schemas.TextDelta._typePayload.text,
+                text: text
+              )
+              )
+            }
+
+            if stopReason == nil {
+              continuation.yield(
+                Components.Schemas.ChatCompletionResponseStreamChunk(
+                  event: Components.Schemas.ChatCompletionResponseEvent(
+                    event_type: .progress,
+                    delta: delta
+                  )
+                )
+              )
+            }
+          }
+
+          if stopReason == nil {
+            stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.out_of_tokens
+          }
+
+          let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
+          // TODO: non-streaming support
+
+          let didParseToolCalls = message.tool_calls?.count ?? 0 > 0
+          if ipython && !didParseToolCalls {
+            continuation.yield(
+              Components.Schemas.ChatCompletionResponseStreamChunk(
+                event: Components.Schemas.ChatCompletionResponseEvent(
+                  event_type: .progress,
+                  delta: .tool_call(Components.Schemas.ToolCallDelta(
+                    _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
+                    tool_call: .case1(""),
+                    parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.failed
+                  )
+                  )
+                )
+                // TODO: stopReason
+              )
+            )
+          }
+
+          for toolCall in message.tool_calls! {
+            continuation.yield(
+              Components.Schemas.ChatCompletionResponseStreamChunk(
+                event: Components.Schemas.ChatCompletionResponseEvent(
+                  event_type: .progress,
+                  delta: .tool_call(Components.Schemas.ToolCallDelta(
+                    _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
+                    tool_call: Components.Schemas.ToolCallDelta.tool_callPayload.ToolCall(toolCall),
+                    parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.succeeded
+                  )
+                  )
+                )
+                // TODO: stopReason
+              )
+            )
+          }
+
+          continuation.yield(
+            Components.Schemas.ChatCompletionResponseStreamChunk(
+              event: Components.Schemas.ChatCompletionResponseEvent(
+                event_type: .complete,
+                delta: .text(Components.Schemas.TextDelta(
+                  _type: Components.Schemas.TextDelta._typePayload.text,
+                  text: ""
+                )
+                )
+              )
+              // TODO: stopReason
+            )
+          )
+        }
+        catch (let error) {
+          print("Inference error: " + error.localizedDescription)
+        }
+      }
+      runnerQueue.async(execute: workItem)
+    }
+  }
+}
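For orientation, here is a minimal sketch of how an app might drive this new on-device provider. The queue label, model paths, and the `request` value are placeholders, and the chunk accessors assume the generated LlamaStackClient Swift types used above:

    import Foundation
    import LlamaStackClient

    // Hypothetical wiring; paths and queue label are placeholders.
    let inference = LocalInference(queue: DispatchQueue(label: "inference"))

    inference.loadModel(
      modelPath: "/path/to/model.pte",          // placeholder
      tokenizerPath: "/path/to/tokenizer.model" // placeholder
    ) { result in
      if case .failure(let error) = result {
        print("load failed: \(error)")
      }
    }

    // Consume the AsyncStream of chunks (inside an async context),
    // assuming `request` is an already-built ChatCompletionRequest.
    Task {
      for await chunk in inference.chatCompletion(request: request) {
        switch chunk.event.delta {
        case .text(let t):
          print(t.text, terminator: "")
        case .tool_call(let tc):
          // Tool-call deltas carry parse_status: started / in_progress / succeeded / failed.
          debugPrint(tc)
        default:
          break
        }
      }
    }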
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift
@@ -0,0 +1,238 @@
+import Foundation
+
+import LlamaStackClient
+
+func encodeHeader(role: String) -> String {
+  return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
+}
+
+func encodeDialogPrompt(messages: [Components.Schemas.Message]) -> String {
+  var prompt = ""
+
+  prompt.append("<|begin_of_text|>")
+  for message in messages {
+    let msg = encodeMessage(message: message)
+    prompt += msg
+  }
+
+  prompt.append(encodeHeader(role: "assistant"))
+
+  return prompt
+}
+
+func getRole(message: Components.Schemas.Message) -> String {
+  switch (message) {
+  case .user(let m):
+    return m.role.rawValue
+  case .system(let m):
+    return m.role.rawValue
+  case .tool(let m):
+    return m.role.rawValue
+  case .assistant(let m):
+    return m.role.rawValue
+  }
+}
+
+func encodeMessage(message: Components.Schemas.Message) -> String {
+  var prompt = encodeHeader(role: getRole(message: message))
+
+  switch (message) {
+  case .assistant(let m):
+    if (m.tool_calls?.count ?? 0 > 0) {
+      prompt += "<|python_tag|>"
+    }
+  default:
+    break
+  }
+
+  func _processContent(_ content: Any) -> String {
+    func _process(_ c: Any) {
+      if let str = c as? String {
+        prompt += str
+      }
+    }
+
+    if let str = content as? String {
+      _process(str)
+    } else if let list = content as? [Any] {
+      for c in list {
+        _process(c)
+      }
+    }
+
+    return ""
+  }
+
+  switch (message) {
+  case .user(let m):
+    prompt += _processContent(m.content)
+  case .system(let m):
+    prompt += _processContent(m.content)
+  case .tool(let m):
+    prompt += _processContent(m.content)
+  case .assistant(let m):
+    prompt += _processContent(m.content)
+  }
+
+  var eom = false
+
+  switch (message) {
+  case .user(let m):
+    switch (m.content) {
+    case .case1(let c):
+      prompt += _processContent(c)
+    case .InterleavedContentItem(let c):
+      prompt += _processContent(c)
+    case .case3(let c):
+      prompt += _processContent(c)
+    }
+  case .assistant(let m):
+    // TODO: Support encoding past tool call history
+    // for t in m.tool_calls {
+    //   _processContent(t.)
+    //}
+    eom = m.stop_reason == Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
+  case .system(_):
+    break
+  case .tool(_):
+    break
+  }
+
+  if (eom) {
+    prompt += "<|eom_id|>"
+  } else {
+    prompt += "<|eot_id|>"
+  }
+
+  return prompt
+}
+
+func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.Message] {
+  var existingMessages = request.messages
+  var existingSystemMessage: Components.Schemas.Message?
+  // TODO: Existing system message
+
+  var messages: [Components.Schemas.Message] = []
+
+  let defaultGen = SystemDefaultGenerator()
+  let defaultTemplate = defaultGen.gen()
+
+  var sysContent = ""
+
+  // TODO: Built-in tools
+
+  sysContent += try defaultTemplate.render()
+
+  messages.append(.system(Components.Schemas.SystemMessage(
+    role: .system,
+    content: .case1(sysContent)
+  ))
+  )
+
+  if request.tools?.isEmpty == false {
+    // TODO: Separate built-ins and custom tools (right now everything treated as custom)
+    let toolGen = FunctionTagCustomToolGenerator()
+    let toolTemplate = try toolGen.gen(customTools: request.tools!)
+    let tools = try toolTemplate.render()
+    messages.append(.user(Components.Schemas.UserMessage(
+      role: .user,
+      content: .case1(tools))
+    ))
+  }
+
+  messages.append(contentsOf: existingMessages)
+
+  return messages
+}
+
+struct FunctionCall {
+  let name: String
+  let params: [String: Any]
+}
+
+public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
+  guard input.hasPrefix("[") && input.hasSuffix("]") else {
+    return []
+  }
+
+  do {
+    let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
+    let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }
+
+    var result: [Components.Schemas.ToolCall] = []
+
+    for call in calls {
+      guard let nameEndIndex = call.firstIndex(of: "("),
+            let paramsStartIndex = call.firstIndex(of: "{"),
+            let paramsEndIndex = call.lastIndex(of: "}") else {
+        return []
+      }
+
+      let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
+      let paramsString = String(call[paramsStartIndex...paramsEndIndex])
+
+      guard let data = paramsString.data(using: .utf8),
+            let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
+        return []
+      }
+
+      var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
+      for (param_name, param) in params {
+        switch (param) {
+        case let value as String:
+          props[param_name] = .case1(value)
+        case let value as Int:
+          props[param_name] = .case2(value)
+        case let value as Double:
+          props[param_name] = .case3(value)
+        case let value as Bool:
+          props[param_name] = .case4(value)
+        default:
+          return []
+        }
+      }
+
+      result.append(
+        Components.Schemas.ToolCall(
+          call_id: UUID().uuidString,
+          tool_name: .case2(name), // custom_tool
+          arguments: .init(additionalProperties: props)
+        )
+      )
+    }
+
+    return result.isEmpty ? [] : result
+  } catch {
+    return []
+  }
+}
+
+func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload) -> Components.Schemas.CompletionMessage {
+  var content = tokens
+
+  let roles = ["user", "system", "assistant"]
+  for role in roles {
+    let headerStr = encodeHeader(role: role)
+    if content.hasPrefix(headerStr) {
+      content = String(content.dropFirst(encodeHeader(role: role).count))
+    }
+  }
+
+  if content.hasPrefix("<|python_tag|>") {
+    content = String(content.dropFirst("<|python_tag|>".count))
+  }
+
+
+  if content.hasSuffix("<|eot_id|>") {
+    content = String(content.dropLast("<|eot_id|>".count))
+  } else {
+    content = String(content.dropLast("<|eom_id|>".count))
+  }
+
+  return Components.Schemas.CompletionMessage(
+    role: .assistant,
+    content: .case1(content),
+    stop_reason: stopReason,
+    tool_calls: maybeExtractCustomToolCalls(input: content)
+  )
+}
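The custom tool-call parser above expects the model to emit a bracketed, Python-style call list whose arguments form a JSON object. A quick illustration against made-up inputs (the tool name and values are invented for the example):

    // Accepted: bracketed call with a JSON-object argument payload.
    let calls = maybeExtractCustomToolCalls(
      input: "[get_weather({\"city\": \"Tokyo\", \"days\": 3})]"
    )
    // -> one ToolCall with tool_name .case2("get_weather") and
    //    arguments {"city": .case1("Tokyo"), "days": .case2(3)}

    // Rejected (returns []): no surrounding brackets, non-JSON params.
    let none = maybeExtractCustomToolCalls(input: "get_weather(city=Tokyo)")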
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift
@@ -0,0 +1,12 @@
+import Foundation
+import Stencil
+
+public struct PromptTemplate {
+  let template: String
+  let data: [String: Any]
+
+  public func render() throws -> String {
+    let template = Template(templateString: self.template)
+    return try template.render(self.data)
+  }
+}
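`PromptTemplate` is a thin wrapper around Stencil rendering; a minimal usage sketch (the template text is illustrative):

    // Stencil substitutes {{ name }} from the data dictionary.
    let greeting = PromptTemplate(
      template: "Hello, {{ name }}!",
      data: ["name": "Llama"]
    )
    let rendered = try greeting.render()  // "Hello, Llama!"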
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift
@@ -0,0 +1,89 @@
+import Foundation
+
+import LlamaStackClient
+
+func convertToNativeSwiftType(_ value: Any) -> Any {
+  switch value {
+  case let number as NSNumber:
+    if CFGetTypeID(number) == CFBooleanGetTypeID() {
+      return number.boolValue
+    }
+    if floor(number.doubleValue) == number.doubleValue {
+      return number.intValue
+    }
+    return number.doubleValue
+  case let string as String:
+    return string
+  case let array as [Any]:
+    return array.map(convertToNativeSwiftType)
+  case let dict as [String: Any]:
+    return dict.mapValues(convertToNativeSwiftType)
+  case is NSNull:
+    return NSNull()
+  default:
+    return value
+  }
+}
+
+public class SystemDefaultGenerator {
+  public init() {}
+
+  public func gen() -> PromptTemplate {
+    let templateStr = """
+      Cutting Knowledge Date: December 2023
+      Today Date: {{ today }}
+      """
+
+    let dateFormatter = DateFormatter()
+    dateFormatter.dateFormat = "dd MMMM yyyy"
+
+    return PromptTemplate(
+      template: templateStr,
+      data: ["today": dateFormatter.string(from: Date())]
+    )
+  }
+}
+
+
+public class FunctionTagCustomToolGenerator {
+  public init() {}
+
+  public func gen(customTools: [Components.Schemas.ToolDefinition]) throws -> PromptTemplate {
+    // TODO: required params
+    // TODO: {{#unless @last}},{{/unless}}
+
+    let templateStr = """
+      You are an expert in composing functions. You are given a question and a set of possible functions.
+      Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+      If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
+      also point it out. You should only return the function call in tools call sections.
+
+      If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
+      You SHOULD NOT include any other text in the response.
+
+      Here is a list of functions in JSON format that you can invoke.
+
+      [
+      {% for t in custom_tools %}
+      {
+          "name": "{{t.tool_name}}",
+          "description": "{{t.description}}",
+          "input_schema": { {{t.input_schema}} }
+      }
+
+      {{/let}}
+      {% endfor -%}
+      ]
+      """
+
+    let encoder = JSONEncoder()
+    return PromptTemplate(
+      template: templateStr,
+      data: ["custom_tools": try customTools.map {
+        let data = try encoder.encode($0)
+        let obj = try JSONSerialization.jsonObject(with: data)
+        return convertToNativeSwiftType(obj)
+      }]
+    )
+  }
+}
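These two generators produce the system preamble that `prepareMessages` in Parsing.swift prepends to every dialog. A sketch of rendering the default header (the date varies at runtime):

    let header = try SystemDefaultGenerator().gen().render()
    // Cutting Knowledge Date: December 2023
    // Today Date: 05 March 2025   (example output; formatted "dd MMMM yyyy")

    // The custom-tool prompt is built the same way from the request's tools, e.g.:
    // let toolPrompt = try FunctionTagCustomToolGenerator().gen(customTools: tools).render()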