llama-stack 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/distributions/dell/doc_template.md +209 -0
- llama_stack/distributions/meta-reference-gpu/doc_template.md +119 -0
- llama_stack/distributions/nvidia/doc_template.md +170 -0
- llama_stack/distributions/oci/doc_template.md +140 -0
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/dog.jpg +0 -0
- llama_stack/models/llama/resources/pasta.jpeg +0 -0
- llama_stack/models/llama/resources/small_dog.jpg +0 -0
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +136 -11
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h +9 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift +189 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift +238 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift +12 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl/SystemPrompts.swift +89 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.pbxproj +550 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/contents.xcworkspacedata +7 -0
- llama_stack/providers/inline/ios/inference/LocalInferenceImpl.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- llama_stack/providers/remote/datasetio/nvidia/README.md +74 -0
- llama_stack/providers/remote/eval/nvidia/README.md +134 -0
- llama_stack/providers/remote/files/s3/README.md +266 -0
- llama_stack/providers/remote/inference/nvidia/NVIDIA.md +203 -0
- llama_stack/providers/remote/post_training/nvidia/README.md +151 -0
- llama_stack/providers/remote/safety/nvidia/README.md +78 -0
- llama_stack/providers/utils/responses/responses_store.py +34 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.4.4.dist-info}/METADATA +2 -2
- {llama_stack-0.4.3.dist-info → llama_stack-0.4.4.dist-info}/RECORD +31 -142
- llama_stack-0.4.4.dist-info/top_level.txt +1 -0
- llama_stack-0.4.3.dist-info/top_level.txt +0 -2
- llama_stack_api/__init__.py +0 -945
- llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/admin/api.py +0 -72
- llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/admin/models.py +0 -113
- llama_stack_api/agents.py +0 -173
- llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/batches/api.py +0 -53
- llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/batches/models.py +0 -78
- llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/common/errors.py +0 -95
- llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/common/responses.py +0 -77
- llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/connectors.py +0 -146
- llama_stack_api/conversations.py +0 -270
- llama_stack_api/datasetio.py +0 -55
- llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/datatypes.py +0 -373
- llama_stack_api/eval.py +0 -137
- llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/files/api.py +0 -51
- llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/files/models.py +0 -107
- llama_stack_api/inference.py +0 -1169
- llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/__init__.py +0 -945
- llama_stack_api/llama_stack_api/admin/__init__.py +0 -45
- llama_stack_api/llama_stack_api/admin/api.py +0 -72
- llama_stack_api/llama_stack_api/admin/fastapi_routes.py +0 -117
- llama_stack_api/llama_stack_api/admin/models.py +0 -113
- llama_stack_api/llama_stack_api/agents.py +0 -173
- llama_stack_api/llama_stack_api/batches/__init__.py +0 -40
- llama_stack_api/llama_stack_api/batches/api.py +0 -53
- llama_stack_api/llama_stack_api/batches/fastapi_routes.py +0 -113
- llama_stack_api/llama_stack_api/batches/models.py +0 -78
- llama_stack_api/llama_stack_api/benchmarks/__init__.py +0 -43
- llama_stack_api/llama_stack_api/benchmarks/api.py +0 -39
- llama_stack_api/llama_stack_api/benchmarks/fastapi_routes.py +0 -109
- llama_stack_api/llama_stack_api/benchmarks/models.py +0 -109
- llama_stack_api/llama_stack_api/common/__init__.py +0 -5
- llama_stack_api/llama_stack_api/common/content_types.py +0 -101
- llama_stack_api/llama_stack_api/common/errors.py +0 -95
- llama_stack_api/llama_stack_api/common/job_types.py +0 -38
- llama_stack_api/llama_stack_api/common/responses.py +0 -77
- llama_stack_api/llama_stack_api/common/training_types.py +0 -47
- llama_stack_api/llama_stack_api/common/type_system.py +0 -146
- llama_stack_api/llama_stack_api/connectors.py +0 -146
- llama_stack_api/llama_stack_api/conversations.py +0 -270
- llama_stack_api/llama_stack_api/datasetio.py +0 -55
- llama_stack_api/llama_stack_api/datasets/__init__.py +0 -61
- llama_stack_api/llama_stack_api/datasets/api.py +0 -35
- llama_stack_api/llama_stack_api/datasets/fastapi_routes.py +0 -104
- llama_stack_api/llama_stack_api/datasets/models.py +0 -152
- llama_stack_api/llama_stack_api/datatypes.py +0 -373
- llama_stack_api/llama_stack_api/eval.py +0 -137
- llama_stack_api/llama_stack_api/file_processors/__init__.py +0 -27
- llama_stack_api/llama_stack_api/file_processors/api.py +0 -64
- llama_stack_api/llama_stack_api/file_processors/fastapi_routes.py +0 -78
- llama_stack_api/llama_stack_api/file_processors/models.py +0 -42
- llama_stack_api/llama_stack_api/files/__init__.py +0 -35
- llama_stack_api/llama_stack_api/files/api.py +0 -51
- llama_stack_api/llama_stack_api/files/fastapi_routes.py +0 -124
- llama_stack_api/llama_stack_api/files/models.py +0 -107
- llama_stack_api/llama_stack_api/inference.py +0 -1169
- llama_stack_api/llama_stack_api/inspect_api/__init__.py +0 -37
- llama_stack_api/llama_stack_api/inspect_api/api.py +0 -25
- llama_stack_api/llama_stack_api/inspect_api/fastapi_routes.py +0 -76
- llama_stack_api/llama_stack_api/inspect_api/models.py +0 -28
- llama_stack_api/llama_stack_api/internal/__init__.py +0 -9
- llama_stack_api/llama_stack_api/internal/kvstore.py +0 -28
- llama_stack_api/llama_stack_api/internal/sqlstore.py +0 -81
- llama_stack_api/llama_stack_api/models.py +0 -171
- llama_stack_api/llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/llama_stack_api/post_training.py +0 -370
- llama_stack_api/llama_stack_api/prompts.py +0 -203
- llama_stack_api/llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/llama_stack_api/providers/api.py +0 -16
- llama_stack_api/llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/llama_stack_api/providers/models.py +0 -24
- llama_stack_api/llama_stack_api/py.typed +0 -0
- llama_stack_api/llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/llama_stack_api/resource.py +0 -37
- llama_stack_api/llama_stack_api/router_utils.py +0 -160
- llama_stack_api/llama_stack_api/safety.py +0 -132
- llama_stack_api/llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/llama_stack_api/scoring.py +0 -93
- llama_stack_api/llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/llama_stack_api/shields.py +0 -93
- llama_stack_api/llama_stack_api/tools.py +0 -226
- llama_stack_api/llama_stack_api/vector_io.py +0 -941
- llama_stack_api/llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/llama_stack_api/version.py +0 -9
- llama_stack_api/models.py +0 -171
- llama_stack_api/openai_responses.py +0 -1468
- llama_stack_api/post_training.py +0 -370
- llama_stack_api/prompts.py +0 -203
- llama_stack_api/providers/__init__.py +0 -33
- llama_stack_api/providers/api.py +0 -16
- llama_stack_api/providers/fastapi_routes.py +0 -57
- llama_stack_api/providers/models.py +0 -24
- llama_stack_api/py.typed +0 -0
- llama_stack_api/rag_tool.py +0 -168
- llama_stack_api/resource.py +0 -37
- llama_stack_api/router_utils.py +0 -160
- llama_stack_api/safety.py +0 -132
- llama_stack_api/schema_utils.py +0 -208
- llama_stack_api/scoring.py +0 -93
- llama_stack_api/scoring_functions.py +0 -211
- llama_stack_api/shields.py +0 -93
- llama_stack_api/tools.py +0 -226
- llama_stack_api/vector_io.py +0 -941
- llama_stack_api/vector_stores.py +0 -53
- llama_stack_api/version.py +0 -9
- {llama_stack-0.4.3.dist-info → llama_stack-0.4.4.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.4.4.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.3.dist-info → llama_stack-0.4.4.dist-info}/licenses/LICENSE +0 -0
llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py (+136 -11)

@@ -324,6 +324,125 @@ class OpenAIResponsesImpl:
             messages=messages,
         )
 
+    def _prepare_input_items_for_storage(
+        self,
+        input: str | list[OpenAIResponseInput],
+    ) -> list[OpenAIResponseInput]:
+        """Prepare input items for storage, adding IDs where needed.
+
+        This method is called once at the start of streaming to prepare input items
+        that will be reused across multiple persistence calls during streaming.
+        """
+        new_input_id = f"msg_{uuid.uuid4()}"
+        input_items_data: list[OpenAIResponseInput] = []
+
+        if isinstance(input, str):
+            input_content = OpenAIResponseInputMessageContentText(text=input)
+            input_content_item = OpenAIResponseMessage(
+                role="user",
+                content=[input_content],
+                id=new_input_id,
+            )
+            input_items_data = [input_content_item]
+        else:
+            for input_item in input:
+                if isinstance(input_item, OpenAIResponseMessage):
+                    input_item_dict = input_item.model_dump()
+                    if "id" not in input_item_dict:
+                        input_item_dict["id"] = new_input_id
+                    input_items_data.append(OpenAIResponseMessage(**input_item_dict))
+                else:
+                    input_items_data.append(input_item)
+
+        return input_items_data
+
+    async def _persist_streaming_state(
+        self,
+        stream_chunk: OpenAIResponseObjectStream,
+        orchestrator,
+        input_items: list[OpenAIResponseInput],
+        output_items: list,
+    ) -> None:
+        """Persist response state at significant streaming events.
+
+        This enables clients to poll GET /v1/responses/{response_id} during streaming
+        to see in-progress turn state instead of empty results.
+
+        Persistence occurs at:
+        - response.in_progress: Initial INSERT with empty output
+        - response.output_item.done: UPDATE with accumulated output items
+        - response.completed/response.incomplete: Final UPDATE with complete state
+        - response.failed: UPDATE with error state
+
+        :param stream_chunk: The current streaming event.
+        :param orchestrator: The streaming orchestrator (for snapshotting response).
+        :param input_items: Pre-prepared input items for storage.
+        :param output_items: Accumulated output items so far.
+        """
+        try:
+            match stream_chunk.type:
+                case "response.in_progress":
+                    # Initial persistence when response starts
+                    in_progress_response = stream_chunk.response
+                    await self.responses_store.upsert_response_object(
+                        response_object=in_progress_response,
+                        input=input_items,
+                        messages=[],
+                    )
+
+                case "response.output_item.done":
+                    # Incremental update when an output item completes (tool call, message)
+                    current_snapshot = orchestrator._snapshot_response(
+                        status="in_progress",
+                        outputs=output_items,
+                    )
+                    # Get current messages (filter out system messages)
+                    messages_to_store = list(
+                        filter(
+                            lambda x: not isinstance(x, OpenAISystemMessageParam),
+                            orchestrator.final_messages or orchestrator.ctx.messages,
+                        )
+                    )
+                    await self.responses_store.upsert_response_object(
+                        response_object=current_snapshot,
+                        input=input_items,
+                        messages=messages_to_store,
+                    )
+
+                case "response.completed" | "response.incomplete":
+                    # Final persistence when response finishes
+                    final_response = stream_chunk.response
+                    messages_to_store = list(
+                        filter(
+                            lambda x: not isinstance(x, OpenAISystemMessageParam),
+                            orchestrator.final_messages,
+                        )
+                    )
+                    await self.responses_store.upsert_response_object(
+                        response_object=final_response,
+                        input=input_items,
+                        messages=messages_to_store,
+                    )
+
+                case "response.failed":
+                    # Persist failed state so GET shows error
+                    failed_response = stream_chunk.response
+                    # Preserve any accumulated non-system messages for failed responses
+                    messages_to_store = list(
+                        filter(
+                            lambda x: not isinstance(x, OpenAISystemMessageParam),
+                            orchestrator.final_messages or orchestrator.ctx.messages,
+                        )
+                    )
+                    await self.responses_store.upsert_response_object(
+                        response_object=failed_response,
+                        input=input_items,
+                        messages=messages_to_store,
+                    )
+        except Exception as e:
+            # Best-effort persistence: log error but don't fail the stream
+            logger.warning(f"Failed to persist streaming state for {stream_chunk.type}: {e}")
+
     async def create_openai_response(
         self,
         input: str | list[OpenAIResponseInput],

@@ -528,6 +647,10 @@ class OpenAIResponsesImpl:
 
         # Type as ConversationItem to avoid list invariance issues
         output_items: list[ConversationItem] = []
+
+        # Prepare input items for storage once (used by all persistence calls)
+        input_items_for_storage = self._prepare_input_items_for_storage(all_input)
+
         try:
             async for stream_chunk in orchestrator.create_response():
                 match stream_chunk.type:

@@ -541,6 +664,16 @@ class OpenAIResponsesImpl:
                     case _:
                         pass # Other event types
 
+                # Incremental persistence: persist on significant state changes
+                # This enables clients to poll GET /v1/responses/{response_id} during streaming
+                if store:
+                    await self._persist_streaming_state(
+                        stream_chunk=stream_chunk,
+                        orchestrator=orchestrator,
+                        input_items=input_items_for_storage,
+                        output_items=output_items,
+                    )
+
                 # Store and sync before yielding terminal events
                 # This ensures the storage/syncing happens even if the consumer breaks after receiving the event
                 if (

@@ -548,18 +681,10 @@ class OpenAIResponsesImpl:
                     and final_response
                     and failed_response is None
                 ):
-                    messages_to_store = list(
-                        filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
-                    )
-                    if store:
-                        # TODO: we really should work off of output_items instead of "final_messages"
-                        await self._store_response(
-                            response=final_response,
-                            input=all_input,
-                            messages=messages_to_store,
-                        )
-
                     if conversation:
+                        messages_to_store = list(
+                            filter(lambda x: not isinstance(x, OpenAISystemMessageParam), orchestrator.final_messages)
+                        )
                         await self._sync_response_to_conversation(conversation, input, output_items)
                         await self.responses_store.store_conversation_messages(conversation, messages_to_store)
 
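The `_persist_streaming_state` hook above writes the response row while the stream is still running, so the `GET /v1/responses/{response_id}` endpoint named in its docstring can be polled mid-turn. A minimal client-side sketch, assuming a stack served at `http://localhost:8321` and that the stored response JSON exposes `status` and `output` fields (check the exact payload shape against your server version):

```python
import time

import httpx

BASE_URL = "http://localhost:8321/v1"  # assumed local llama-stack endpoint
response_id = "resp_123"  # placeholder: id returned by the streaming create call

# Poll the stored response while another client is still streaming it.
# Before 0.4.4 the row was only written at the terminal event, so a
# concurrent poll saw nothing useful; with incremental persistence it
# should show in-progress state and accumulated output items.
with httpx.Client(timeout=10.0) as client:
    while True:
        r = client.get(f"{BASE_URL}/responses/{response_id}")
        r.raise_for_status()
        body = r.json()
        print(body.get("status"), len(body.get("output", [])), "output items")
        if body.get("status") in ("completed", "incomplete", "failed"):
            break
        time.sleep(1)
```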
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.h (new file, +9 lines)

#import <Foundation/Foundation.h>

//! Project version number for LocalInference.
FOUNDATION_EXPORT double LocalInferenceVersionNumber;

//! Project version string for LocalInference.
FOUNDATION_EXPORT const unsigned char LocalInferenceVersionString[];

// In this header, you should import all the public headers of your framework using statements like #import <LocalInference/PublicHeader.h>
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/LocalInference.swift (new file, +189 lines)

import Foundation

import LLaMARunner
import LlamaStackClient

class RunnerHolder: ObservableObject {
  var runner: Runner?
}

public class LocalInference: Inference {
  private var runnerHolder = RunnerHolder()
  private let runnerQueue: DispatchQueue

  public init (queue: DispatchQueue) {
    runnerQueue = queue
  }

  public func loadModel(modelPath: String, tokenizerPath: String, completion: @escaping (Result<Void, Error>) -> Void) {
    runnerHolder.runner = runnerHolder.runner ?? Runner(
      modelPath: modelPath,
      tokenizerPath: tokenizerPath
    )

    runnerQueue.async {
      let runner = self.runnerHolder.runner
      do {
        try runner!.load()
        completion(.success(()))
      } catch let loadError {
        print("error: " + loadError.localizedDescription)
        completion(.failure(loadError))
      }
    }
  }

  public func stop() {
    runnerHolder.runner?.stop()
  }

  public func chatCompletion(request: Components.Schemas.ChatCompletionRequest) -> AsyncStream<Components.Schemas.ChatCompletionResponseStreamChunk> {
    return AsyncStream { continuation in
      let workItem = DispatchWorkItem {
        do {
          var tokens: [String] = []

          let prompt = try encodeDialogPrompt(messages: prepareMessages(request: request))
          var stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload? = nil
          var buffer = ""
          var ipython = false
          var echoDropped = false

          try self.runnerHolder.runner?.generate(prompt, sequenceLength: 4096) { token in
            buffer += token

            // HACK: Workaround until LlamaRunner exposes echo param
            if (!echoDropped) {
              if (buffer.hasPrefix(prompt)) {
                buffer = String(buffer.dropFirst(prompt.count))
                echoDropped = true
              }
              return
            }

            tokens.append(token)

            if !ipython && (buffer.starts(with: "<|python_tag|>") || buffer.starts(with: "[") ) {
              ipython = true
              continuation.yield(
                Components.Schemas.ChatCompletionResponseStreamChunk(
                  event: Components.Schemas.ChatCompletionResponseEvent(
                    event_type: .progress,
                    delta: .tool_call(Components.Schemas.ToolCallDelta(
                      _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
                      tool_call: .case1(""),
                      parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.started
                    ))
                  )
                )
              )

              if (buffer.starts(with: "<|python_tag|>")) {
                buffer = String(buffer.dropFirst("<|python_tag|>".count))
              }
            }

            // TODO: Non-streaming lobprobs

            var text = ""
            if token == "<|eot_id|>" {
              stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_turn
            } else if token == "<|eom_id|>" {
              stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
            } else {
              text = token
            }

            var delta: Components.Schemas.ContentDelta
            if ipython {
              delta = .tool_call(Components.Schemas.ToolCallDelta(
                _type: .tool_call,
                tool_call: .case1(text),
                parse_status: .in_progress
              ))
            } else {
              delta = .text(Components.Schemas.TextDelta(
                _type: Components.Schemas.TextDelta._typePayload.text,
                text: text
              ))
            }

            if stopReason == nil {
              continuation.yield(
                Components.Schemas.ChatCompletionResponseStreamChunk(
                  event: Components.Schemas.ChatCompletionResponseEvent(
                    event_type: .progress,
                    delta: delta
                  )
                )
              )
            }
          }

          if stopReason == nil {
            stopReason = Components.Schemas.CompletionMessage.stop_reasonPayload.out_of_tokens
          }

          let message = decodeAssistantMessage(tokens: tokens.joined(), stopReason: stopReason!)
          // TODO: non-streaming support

          let didParseToolCalls = message.tool_calls?.count ?? 0 > 0
          if ipython && !didParseToolCalls {
            continuation.yield(
              Components.Schemas.ChatCompletionResponseStreamChunk(
                event: Components.Schemas.ChatCompletionResponseEvent(
                  event_type: .progress,
                  delta: .tool_call(Components.Schemas.ToolCallDelta(
                    _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
                    tool_call: .case1(""),
                    parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.failed
                  ))
                )
                // TODO: stopReason
              )
            )
          }

          for toolCall in message.tool_calls! {
            continuation.yield(
              Components.Schemas.ChatCompletionResponseStreamChunk(
                event: Components.Schemas.ChatCompletionResponseEvent(
                  event_type: .progress,
                  delta: .tool_call(Components.Schemas.ToolCallDelta(
                    _type: Components.Schemas.ToolCallDelta._typePayload.tool_call,
                    tool_call: Components.Schemas.ToolCallDelta.tool_callPayload.ToolCall(toolCall),
                    parse_status: Components.Schemas.ToolCallDelta.parse_statusPayload.succeeded
                  ))
                )
                // TODO: stopReason
              )
            )
          }

          continuation.yield(
            Components.Schemas.ChatCompletionResponseStreamChunk(
              event: Components.Schemas.ChatCompletionResponseEvent(
                event_type: .complete,
                delta: .text(Components.Schemas.TextDelta(
                  _type: Components.Schemas.TextDelta._typePayload.text,
                  text: ""
                ))
              )
              // TODO: stopReason
            )
          )
        }
        catch (let error) {
          print("Inference error: " + error.localizedDescription)
        }
      }
      runnerQueue.async(execute: workItem)
    }
  }
}
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/Parsing.swift (new file, +238 lines)

import Foundation

import LlamaStackClient

func encodeHeader(role: String) -> String {
  return "<|start_header_id|>\(role)<|end_header_id|>\n\n"
}

func encodeDialogPrompt(messages: [Components.Schemas.Message]) -> String {
  var prompt = ""

  prompt.append("<|begin_of_text|>")
  for message in messages {
    let msg = encodeMessage(message: message)
    prompt += msg
  }

  prompt.append(encodeHeader(role: "assistant"))

  return prompt
}

func getRole(message: Components.Schemas.Message) -> String {
  switch (message) {
  case .user(let m):
    return m.role.rawValue
  case .system(let m):
    return m.role.rawValue
  case .tool(let m):
    return m.role.rawValue
  case .assistant(let m):
    return m.role.rawValue
  }
}

func encodeMessage(message: Components.Schemas.Message) -> String {
  var prompt = encodeHeader(role: getRole(message: message))

  switch (message) {
  case .assistant(let m):
    if (m.tool_calls?.count ?? 0 > 0) {
      prompt += "<|python_tag|>"
    }
  default:
    break
  }

  func _processContent(_ content: Any) -> String {
    func _process(_ c: Any) {
      if let str = c as? String {
        prompt += str
      }
    }

    if let str = content as? String {
      _process(str)
    } else if let list = content as? [Any] {
      for c in list {
        _process(c)
      }
    }

    return ""
  }

  switch (message) {
  case .user(let m):
    prompt += _processContent(m.content)
  case .system(let m):
    prompt += _processContent(m.content)
  case .tool(let m):
    prompt += _processContent(m.content)
  case .assistant(let m):
    prompt += _processContent(m.content)
  }

  var eom = false

  switch (message) {
  case .user(let m):
    switch (m.content) {
    case .case1(let c):
      prompt += _processContent(c)
    case .InterleavedContentItem(let c):
      prompt += _processContent(c)
    case .case3(let c):
      prompt += _processContent(c)
    }
  case .assistant(let m):
    // TODO: Support encoding past tool call history
    // for t in m.tool_calls {
    //   _processContent(t.)
    //}
    eom = m.stop_reason == Components.Schemas.CompletionMessage.stop_reasonPayload.end_of_message
  case .system(_):
    break
  case .tool(_):
    break
  }

  if (eom) {
    prompt += "<|eom_id|>"
  } else {
    prompt += "<|eot_id|>"
  }

  return prompt
}

func prepareMessages(request: Components.Schemas.ChatCompletionRequest) throws -> [Components.Schemas.Message] {
  var existingMessages = request.messages
  var existingSystemMessage: Components.Schemas.Message?
  // TODO: Existing system message

  var messages: [Components.Schemas.Message] = []

  let defaultGen = SystemDefaultGenerator()
  let defaultTemplate = defaultGen.gen()

  var sysContent = ""

  // TODO: Built-in tools

  sysContent += try defaultTemplate.render()

  messages.append(.system(Components.Schemas.SystemMessage(
    role: .system,
    content: .case1(sysContent)
  )))

  if request.tools?.isEmpty == false {
    // TODO: Separate built-ins and custom tools (right now everything treated as custom)
    let toolGen = FunctionTagCustomToolGenerator()
    let toolTemplate = try toolGen.gen(customTools: request.tools!)
    let tools = try toolTemplate.render()
    messages.append(.user(Components.Schemas.UserMessage(
      role: .user,
      content: .case1(tools))
    ))
  }

  messages.append(contentsOf: existingMessages)

  return messages
}

struct FunctionCall {
  let name: String
  let params: [String: Any]
}

public func maybeExtractCustomToolCalls(input: String) -> [Components.Schemas.ToolCall] {
  guard input.hasPrefix("[") && input.hasSuffix("]") else {
    return []
  }

  do {
    let trimmed = input.trimmingCharacters(in: CharacterSet(charactersIn: "[]"))
    let calls = trimmed.components(separatedBy: "),").map { $0.hasSuffix(")") ? $0 : $0 + ")" }

    var result: [Components.Schemas.ToolCall] = []

    for call in calls {
      guard let nameEndIndex = call.firstIndex(of: "("),
            let paramsStartIndex = call.firstIndex(of: "{"),
            let paramsEndIndex = call.lastIndex(of: "}") else {
        return []
      }

      let name = String(call[..<nameEndIndex]).trimmingCharacters(in: .whitespacesAndNewlines)
      let paramsString = String(call[paramsStartIndex...paramsEndIndex])

      guard let data = paramsString.data(using: .utf8),
            let params = try? JSONSerialization.jsonObject(with: data, options: []) as? [String: Any] else {
        return []
      }

      var props: [String : Components.Schemas.ToolCall.argumentsPayload.additionalPropertiesPayload] = [:]
      for (param_name, param) in params {
        switch (param) {
        case let value as String:
          props[param_name] = .case1(value)
        case let value as Int:
          props[param_name] = .case2(value)
        case let value as Double:
          props[param_name] = .case3(value)
        case let value as Bool:
          props[param_name] = .case4(value)
        default:
          return []
        }
      }

      result.append(
        Components.Schemas.ToolCall(
          call_id: UUID().uuidString,
          tool_name: .case2(name), // custom_tool
          arguments: .init(additionalProperties: props)
        )
      )
    }

    return result.isEmpty ? [] : result
  } catch {
    return []
  }
}

func decodeAssistantMessage(tokens: String, stopReason: Components.Schemas.CompletionMessage.stop_reasonPayload) -> Components.Schemas.CompletionMessage {
  var content = tokens

  let roles = ["user", "system", "assistant"]
  for role in roles {
    let headerStr = encodeHeader(role: role)
    if content.hasPrefix(headerStr) {
      content = String(content.dropFirst(encodeHeader(role: role).count))
    }
  }

  if content.hasPrefix("<|python_tag|>") {
    content = String(content.dropFirst("<|python_tag|>".count))
  }

  if content.hasSuffix("<|eot_id|>") {
    content = String(content.dropLast("<|eot_id|>".count))
  } else {
    content = String(content.dropLast("<|eom_id|>".count))
  }

  return Components.Schemas.CompletionMessage(
    role: .assistant,
    content: .case1(content),
    stop_reason: stopReason,
    tool_calls: maybeExtractCustomToolCalls(input: content)
  )
}
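`encodeHeader` and `encodeDialogPrompt` above assemble the Llama 3 chat layout: `<|begin_of_text|>`, a header per message, `<|eot_id|>` terminators, and a trailing open assistant header. A rough Python illustration of the same string assembly, ignoring the `<|eom_id|>` and `<|python_tag|>` tool-call cases the Swift code also handles (illustrative only, not part of the package):

```python
def encode_header(role: str) -> str:
    # Mirrors encodeHeader() in Parsing.swift
    return f"<|start_header_id|>{role}<|end_header_id|>\n\n"


def encode_dialog_prompt(messages: list[dict]) -> str:
    # messages: [{"role": "user", "content": "..."}, ...] -- simplified view
    prompt = "<|begin_of_text|>"
    for message in messages:
        prompt += encode_header(message["role"]) + message["content"] + "<|eot_id|>"
    # Leave an open assistant header so generation continues from here.
    prompt += encode_header("assistant")
    return prompt


print(encode_dialog_prompt([{"role": "user", "content": "What is the capital of France?"}]))
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
# What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```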
llama_stack/providers/inline/ios/inference/LocalInferenceImpl/PromptTemplate.swift (new file, +12 lines)

import Foundation
import Stencil

public struct PromptTemplate {
  let template: String
  let data: [String: Any]

  public func render() throws -> String {
    let template = Template(templateString: self.template)
    return try template.render(self.data)
  }
}