spaik-sdk 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spaik_sdk/__init__.py +21 -0
- spaik_sdk/agent/__init__.py +0 -0
- spaik_sdk/agent/base_agent.py +249 -0
- spaik_sdk/attachments/__init__.py +22 -0
- spaik_sdk/attachments/builder.py +61 -0
- spaik_sdk/attachments/file_storage_provider.py +27 -0
- spaik_sdk/attachments/mime_types.py +118 -0
- spaik_sdk/attachments/models.py +63 -0
- spaik_sdk/attachments/provider_support.py +53 -0
- spaik_sdk/attachments/storage/__init__.py +0 -0
- spaik_sdk/attachments/storage/base_file_storage.py +32 -0
- spaik_sdk/attachments/storage/impl/__init__.py +0 -0
- spaik_sdk/attachments/storage/impl/local_file_storage.py +101 -0
- spaik_sdk/audio/__init__.py +12 -0
- spaik_sdk/audio/options.py +53 -0
- spaik_sdk/audio/providers/__init__.py +1 -0
- spaik_sdk/audio/providers/google_tts.py +77 -0
- spaik_sdk/audio/providers/openai_stt.py +71 -0
- spaik_sdk/audio/providers/openai_tts.py +111 -0
- spaik_sdk/audio/stt.py +61 -0
- spaik_sdk/audio/tts.py +124 -0
- spaik_sdk/config/credentials_provider.py +10 -0
- spaik_sdk/config/env.py +59 -0
- spaik_sdk/config/env_credentials_provider.py +7 -0
- spaik_sdk/config/get_credentials_provider.py +14 -0
- spaik_sdk/image_gen/__init__.py +9 -0
- spaik_sdk/image_gen/image_generator.py +83 -0
- spaik_sdk/image_gen/options.py +24 -0
- spaik_sdk/image_gen/providers/__init__.py +0 -0
- spaik_sdk/image_gen/providers/google.py +75 -0
- spaik_sdk/image_gen/providers/openai.py +60 -0
- spaik_sdk/llm/__init__.py +0 -0
- spaik_sdk/llm/cancellation_handle.py +10 -0
- spaik_sdk/llm/consumption/__init__.py +0 -0
- spaik_sdk/llm/consumption/consumption_estimate.py +26 -0
- spaik_sdk/llm/consumption/consumption_estimate_builder.py +113 -0
- spaik_sdk/llm/consumption/consumption_extractor.py +59 -0
- spaik_sdk/llm/consumption/token_usage.py +31 -0
- spaik_sdk/llm/converters.py +146 -0
- spaik_sdk/llm/cost/__init__.py +1 -0
- spaik_sdk/llm/cost/builtin_cost_provider.py +83 -0
- spaik_sdk/llm/cost/cost_estimate.py +8 -0
- spaik_sdk/llm/cost/cost_provider.py +28 -0
- spaik_sdk/llm/extract_error_message.py +37 -0
- spaik_sdk/llm/langchain_loop_manager.py +270 -0
- spaik_sdk/llm/langchain_service.py +196 -0
- spaik_sdk/llm/message_handler.py +188 -0
- spaik_sdk/llm/streaming/__init__.py +1 -0
- spaik_sdk/llm/streaming/block_manager.py +152 -0
- spaik_sdk/llm/streaming/models.py +42 -0
- spaik_sdk/llm/streaming/streaming_content_handler.py +157 -0
- spaik_sdk/llm/streaming/streaming_event_handler.py +215 -0
- spaik_sdk/llm/streaming/streaming_state_manager.py +58 -0
- spaik_sdk/models/__init__.py +0 -0
- spaik_sdk/models/factories/__init__.py +0 -0
- spaik_sdk/models/factories/anthropic_factory.py +33 -0
- spaik_sdk/models/factories/base_model_factory.py +71 -0
- spaik_sdk/models/factories/google_factory.py +30 -0
- spaik_sdk/models/factories/ollama_factory.py +41 -0
- spaik_sdk/models/factories/openai_factory.py +50 -0
- spaik_sdk/models/llm_config.py +46 -0
- spaik_sdk/models/llm_families.py +7 -0
- spaik_sdk/models/llm_model.py +17 -0
- spaik_sdk/models/llm_wrapper.py +25 -0
- spaik_sdk/models/model_registry.py +156 -0
- spaik_sdk/models/providers/__init__.py +0 -0
- spaik_sdk/models/providers/anthropic_provider.py +29 -0
- spaik_sdk/models/providers/azure_provider.py +31 -0
- spaik_sdk/models/providers/base_provider.py +62 -0
- spaik_sdk/models/providers/google_provider.py +26 -0
- spaik_sdk/models/providers/ollama_provider.py +26 -0
- spaik_sdk/models/providers/openai_provider.py +26 -0
- spaik_sdk/models/providers/provider_type.py +90 -0
- spaik_sdk/orchestration/__init__.py +24 -0
- spaik_sdk/orchestration/base_orchestrator.py +238 -0
- spaik_sdk/orchestration/checkpoint.py +80 -0
- spaik_sdk/orchestration/models.py +103 -0
- spaik_sdk/prompt/__init__.py +0 -0
- spaik_sdk/prompt/get_prompt_loader.py +13 -0
- spaik_sdk/prompt/local_prompt_loader.py +21 -0
- spaik_sdk/prompt/prompt_loader.py +48 -0
- spaik_sdk/prompt/prompt_loader_mode.py +14 -0
- spaik_sdk/py.typed +1 -0
- spaik_sdk/recording/__init__.py +1 -0
- spaik_sdk/recording/base_playback.py +90 -0
- spaik_sdk/recording/base_recorder.py +50 -0
- spaik_sdk/recording/conditional_recorder.py +38 -0
- spaik_sdk/recording/impl/__init__.py +1 -0
- spaik_sdk/recording/impl/local_playback.py +76 -0
- spaik_sdk/recording/impl/local_recorder.py +85 -0
- spaik_sdk/recording/langchain_serializer.py +88 -0
- spaik_sdk/server/__init__.py +1 -0
- spaik_sdk/server/api/routers/__init__.py +0 -0
- spaik_sdk/server/api/routers/api_builder.py +149 -0
- spaik_sdk/server/api/routers/audio_router_factory.py +201 -0
- spaik_sdk/server/api/routers/file_router_factory.py +111 -0
- spaik_sdk/server/api/routers/thread_router_factory.py +284 -0
- spaik_sdk/server/api/streaming/__init__.py +0 -0
- spaik_sdk/server/api/streaming/format_sse_event.py +41 -0
- spaik_sdk/server/api/streaming/negotiate_streaming_response.py +8 -0
- spaik_sdk/server/api/streaming/streaming_negotiator.py +10 -0
- spaik_sdk/server/authorization/__init__.py +0 -0
- spaik_sdk/server/authorization/base_authorizer.py +64 -0
- spaik_sdk/server/authorization/base_user.py +13 -0
- spaik_sdk/server/authorization/dummy_authorizer.py +17 -0
- spaik_sdk/server/job_processor/__init__.py +0 -0
- spaik_sdk/server/job_processor/base_job_processor.py +8 -0
- spaik_sdk/server/job_processor/thread_job_processor.py +32 -0
- spaik_sdk/server/pubsub/__init__.py +1 -0
- spaik_sdk/server/pubsub/cancellation_publisher.py +7 -0
- spaik_sdk/server/pubsub/cancellation_subscriber.py +38 -0
- spaik_sdk/server/pubsub/event_publisher.py +13 -0
- spaik_sdk/server/pubsub/impl/__init__.py +1 -0
- spaik_sdk/server/pubsub/impl/local_cancellation_pubsub.py +48 -0
- spaik_sdk/server/pubsub/impl/signalr_publisher.py +36 -0
- spaik_sdk/server/queue/__init__.py +1 -0
- spaik_sdk/server/queue/agent_job_queue.py +27 -0
- spaik_sdk/server/queue/impl/__init__.py +1 -0
- spaik_sdk/server/queue/impl/azure_queue.py +24 -0
- spaik_sdk/server/response/__init__.py +0 -0
- spaik_sdk/server/response/agent_response_generator.py +39 -0
- spaik_sdk/server/response/response_generator.py +13 -0
- spaik_sdk/server/response/simple_agent_response_generator.py +14 -0
- spaik_sdk/server/services/__init__.py +0 -0
- spaik_sdk/server/services/thread_converters.py +113 -0
- spaik_sdk/server/services/thread_models.py +90 -0
- spaik_sdk/server/services/thread_service.py +91 -0
- spaik_sdk/server/storage/__init__.py +1 -0
- spaik_sdk/server/storage/base_thread_repository.py +51 -0
- spaik_sdk/server/storage/impl/__init__.py +0 -0
- spaik_sdk/server/storage/impl/in_memory_thread_repository.py +100 -0
- spaik_sdk/server/storage/impl/local_file_thread_repository.py +217 -0
- spaik_sdk/server/storage/thread_filter.py +166 -0
- spaik_sdk/server/storage/thread_metadata.py +53 -0
- spaik_sdk/thread/__init__.py +0 -0
- spaik_sdk/thread/adapters/__init__.py +0 -0
- spaik_sdk/thread/adapters/cli/__init__.py +0 -0
- spaik_sdk/thread/adapters/cli/block_display.py +92 -0
- spaik_sdk/thread/adapters/cli/display_manager.py +84 -0
- spaik_sdk/thread/adapters/cli/live_cli.py +235 -0
- spaik_sdk/thread/adapters/event_adapter.py +28 -0
- spaik_sdk/thread/adapters/streaming_block_adapter.py +57 -0
- spaik_sdk/thread/adapters/sync_adapter.py +76 -0
- spaik_sdk/thread/models.py +224 -0
- spaik_sdk/thread/thread_container.py +468 -0
- spaik_sdk/tools/__init__.py +0 -0
- spaik_sdk/tools/impl/__init__.py +0 -0
- spaik_sdk/tools/impl/mcp_tool_provider.py +93 -0
- spaik_sdk/tools/impl/search_tool_provider.py +18 -0
- spaik_sdk/tools/tool_provider.py +131 -0
- spaik_sdk/tracing/__init__.py +13 -0
- spaik_sdk/tracing/agent_trace.py +72 -0
- spaik_sdk/tracing/get_trace_sink.py +15 -0
- spaik_sdk/tracing/local_trace_sink.py +23 -0
- spaik_sdk/tracing/trace_sink.py +19 -0
- spaik_sdk/tracing/trace_sink_mode.py +14 -0
- spaik_sdk/utils/__init__.py +0 -0
- spaik_sdk/utils/init_logger.py +24 -0
- spaik_sdk-0.6.2.dist-info/METADATA +379 -0
- spaik_sdk-0.6.2.dist-info/RECORD +161 -0
- spaik_sdk-0.6.2.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
import asyncio
import inspect
from abc import ABC, abstractmethod
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    Generic,
    Optional,
    TypeVar,
    Union,
)

from spaik_sdk.orchestration.checkpoint import CheckpointProvider
from spaik_sdk.orchestration.models import OrchestratorEvent
from spaik_sdk.utils.init_logger import init_logger
|
|
17
|
+
|
|
18
|
+
logger = init_logger(__name__)

# Type of the state threaded through `step()` calls.
T_State = TypeVar("T_State")
# Type of the final orchestration result emitted via `ok()`.
T_Result = TypeVar("T_Result")


class BaseOrchestrator(ABC, Generic[T_State, T_Result]):
    """
    Code-first orchestration without graph DSLs.

    Subclass this and implement `run()` to define your orchestration logic.
    Use `step()` to execute steps with automatic status emission and optional checkpointing.

    Example:
        class MyOrchestrator(BaseOrchestrator[MyState, MyResult]):
            async def run(self) -> AsyncIterator[OrchestratorEvent[MyResult]]:
                state = MyState(items=[])

                # Run a step - yields status events automatically.
                # The step's new state arrives on `event.state` (not `event.result`).
                async for event in self.step("fetch", "Fetching data", self.fetch_data, state):
                    yield event
                    if event.is_terminal():
                        return  # the step failed and already emitted an error event
                    if event.state is not None:
                        state = event.state

                # Emit progress during processing
                for i, item in enumerate(state.items):
                    yield self.progress("process", i + 1, len(state.items))
                    await self.process_item(item)

                yield self.ok(MyResult(processed=len(state.items)))

            async def fetch_data(self, state: MyState) -> MyState:
                # Your logic here
                return state.copy(items=fetched_items)
    """

    def __init__(
        self,
        checkpoint_provider: Optional[CheckpointProvider[T_State]] = None,
        resume_from: Optional[str] = None,
    ) -> None:
        """
        Args:
            checkpoint_provider: Optional provider for state persistence.
                If None, no checkpointing is performed.
            resume_from: Step ID to resume from. Steps up to and including
                this ID will be skipped, using checkpointed state.
        """
        self.checkpoint_provider = checkpoint_provider
        self.resume_from = resume_from
        # Step IDs that ran to completion in this instance (skipped steps are not added).
        self._completed_steps: set[str] = set()
        # Becomes True once the `resume_from` step has been reached (and skipped).
        self._passed_resume_point = False

        if resume_from is not None:
            logger.info(f"Will resume after step '{resume_from}'")

    @abstractmethod
    def run(self) -> AsyncIterator[OrchestratorEvent[T_Result]]:
        """
        Implement your orchestration logic here.

        Yield OrchestratorEvent instances to emit status updates, progress,
        messages, and the final result.
        """
        ...

    def run_sync(self) -> OrchestratorEvent[T_Result]:
        """
        Run the orchestration synchronously and return the final event.

        Drives `run()` with `asyncio.run`, so this must not be called from a
        thread that already has a running event loop.

        Returns:
            The first terminal event (result or error). If none is emitted, the
            last event emitted, or a `fail` event when `run()` emitted nothing.
        """

        async def _runner() -> OrchestratorEvent[T_Result]:
            last_event: Optional[OrchestratorEvent[T_Result]] = None
            async for event in self.run():
                last_event = event
                if event.is_terminal():
                    return event
            if last_event is None:
                return self.fail("No events emitted during orchestration")
            return last_event

        return asyncio.run(_runner())

    async def step(
        self,
        step_id: str,
        name: str,
        fn: Union[
            Callable[[T_State], T_State],
            Callable[[T_State], Awaitable[T_State]],
        ],
        state: T_State,
    ) -> AsyncIterator[OrchestratorEvent[T_State]]:
        """
        Execute a step with automatic status emission and checkpointing.

        Args:
            step_id: Identifier used as the checkpoint key and for resume matching.
            name: Human-readable step name carried on step status events.
            fn: Sync or async callable that maps the input state to the new state.
            state: Input state passed to `fn`.

        Yields:
            - step_started event
            - on success: step_completed event, then a state_update event whose
              `state` field holds the new state
            - on failure: step_failed event, then an error event

        If resuming and this step is at or before the resume point, yields
        step_skipped followed by a state_update carrying the checkpointed state
        (or an error event when no checkpoint is available).
        """
        if self._should_skip(step_id):
            yield OrchestratorEvent.step_skipped(step_id, name, "Resumed from checkpoint")
            loaded_state = self._load_checkpoint(step_id)
            if loaded_state is not None:
                yield OrchestratorEvent.state_update(loaded_state)
            else:
                yield OrchestratorEvent.fail(f"Checkpoint not found for step '{step_id}'")
            return

        yield OrchestratorEvent.step_started(step_id, name)

        try:
            result = fn(state)
            # Await any awaitable (coroutine, Task, Future, ...), not only coroutine
            # objects, so every `Callable[[T_State], Awaitable[T_State]]` works.
            if inspect.isawaitable(result):
                new_state = await result
            else:
                new_state = result
            self._save_checkpoint(step_id, new_state)
        except Exception as e:
            logger.exception(f"Step '{step_id}' failed")
            yield OrchestratorEvent.step_failed(step_id, name, str(e))
            yield OrchestratorEvent.fail(str(e))
            return

        # Success events are yielded outside the try block so an exception thrown
        # into this generator by the consumer is not misreported as a step failure.
        self._completed_steps.add(step_id)
        yield OrchestratorEvent.step_completed(step_id, name)
        yield OrchestratorEvent.state_update(new_state)

    # --- Convenience factory methods ---

    def ok(self, result: T_Result) -> OrchestratorEvent[T_Result]:
        """Create a success result event"""
        return OrchestratorEvent.ok(result)

    def fail(self, error: str) -> OrchestratorEvent[T_Result]:
        """Create an error event"""
        return OrchestratorEvent.fail(error)

    def msg(self, message: str) -> OrchestratorEvent[T_Result]:
        """Create a message event"""
        return OrchestratorEvent.msg(message)

    def progress(self, step_id: str, current: int, total: int, detail: Optional[str] = None) -> OrchestratorEvent[T_Result]:
        """Create a progress update event"""
        return OrchestratorEvent.progress_update(step_id, current, total, detail)

    def step_started(self, step_id: str, name: str, detail: Optional[str] = None) -> OrchestratorEvent[T_Result]:
        """Create a step started event (for manual step management)"""
        return OrchestratorEvent.step_started(step_id, name, detail)

    def step_completed(self, step_id: str, name: str, detail: Optional[str] = None) -> OrchestratorEvent[T_Result]:
        """Create a step completed event (for manual step management)"""
        return OrchestratorEvent.step_completed(step_id, name, detail)

    def step_failed(self, step_id: str, name: str, error: str) -> OrchestratorEvent[T_Result]:
        """Create a step failed event (for manual step management)"""
        return OrchestratorEvent.step_failed(step_id, name, error)

    # --- Internal helpers ---

    def _should_skip(self, step_id: str) -> bool:
        """
        Check if a step should be skipped due to checkpoint resume.

        Returns True for every step seen (in call order) up to and including
        `resume_from`, and False afterwards. If no step with id `resume_from`
        is ever executed, every step is skipped.
        """
        if self.resume_from is None:
            return False
        if self._passed_resume_point:
            return False
        # We skip until we hit the resume_from step, then skip that one too
        if step_id == self.resume_from:
            self._passed_resume_point = True
        return True

    def _load_checkpoint(self, step_id: str) -> Optional[T_State]:
        """Load checkpointed state for a step; None without a provider or checkpoint."""
        if self.checkpoint_provider is None:
            return None
        return self.checkpoint_provider.load(step_id)

    def _save_checkpoint(self, step_id: str, state: T_State) -> None:
        """Save state to checkpoint after step completion (no-op without a provider)."""
        if self.checkpoint_provider is None:
            return
        self.checkpoint_provider.save(step_id, state)
        logger.debug(f"Saved checkpoint for step '{step_id}'")
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class SimpleOrchestrator(BaseOrchestrator[None, T_Result]):
    """
    Orchestrator without state management.

    Use this when you don't need to pass state between steps,
    or when steps manage their own state internally. Steps executed through
    `run_step()` are neither checkpointed nor skipped on resume.
    """

    async def run_step(
        self,
        step_id: str,
        name: str,
        fn: Union[Callable[[], Any], Callable[[], Awaitable[Any]]],
    ) -> AsyncIterator[OrchestratorEvent[Any]]:
        """
        Execute a step without state management.

        Args:
            step_id: Identifier carried on the emitted step events.
            name: Human-readable step name.
            fn: Zero-argument sync or async callable; its return value is discarded.

        Yields:
            step_started, then step_completed on success, or step_failed followed
            by an error event on failure.
        """
        yield OrchestratorEvent.step_started(step_id, name)

        try:
            outcome = fn()
            # Await any awaitable (coroutine, Task, Future, ...), not only coroutines.
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as e:
            logger.exception(f"Step '{step_id}' failed")
            yield OrchestratorEvent.step_failed(step_id, name, str(e))
            yield OrchestratorEvent.fail(str(e))
            return

        # Yield outside the try block so exceptions thrown into this generator by
        # the consumer are not reported as a failure of this step.
        self._completed_steps.add(step_id)
        yield OrchestratorEvent.step_completed(step_id, name)
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, Dict, Generic, Optional, TypeVar
|
|
3
|
+
|
|
4
|
+
# Type of the state object persisted by a checkpoint provider.
T_State = TypeVar("T_State")


class CheckpointProvider(ABC, Generic[T_State]):
    """
    Abstract interface for storing and retrieving orchestration checkpoints.

    Subclass it to persist orchestration state. The orchestrator calls `save()`
    after a step succeeds and `load()` when replaying a skipped step on resume.
    """

    @abstractmethod
    def save(self, step_id: str, state: T_State) -> None:
        """Persist `state` as the checkpoint for `step_id`."""
        ...

    @abstractmethod
    def load(self, step_id: str) -> Optional[T_State]:
        """Return the checkpoint stored for `step_id`, or None if there is none."""
        ...

    @abstractmethod
    def get_completed_steps(self) -> set[str]:
        """Return the IDs of every step that has a stored checkpoint."""
        ...

    @abstractmethod
    def clear(self) -> None:
        """Remove every stored checkpoint."""
        ...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class InMemoryCheckpointProvider(CheckpointProvider[T_State]):
    """
    Checkpoint provider backed by a plain in-process dict; meant for tests and development.

    States are stored by reference (not copied), so mutating a state object
    after saving it also changes the stored checkpoint.
    """

    def __init__(self) -> None:
        self._checkpoints: Dict[str, T_State] = dict()

    def save(self, step_id: str, state: T_State) -> None:
        """Store `state` under `step_id`, overwriting any earlier checkpoint."""
        self._checkpoints.update({step_id: state})

    def load(self, step_id: str) -> Optional[T_State]:
        """Return the state saved for `step_id`, or None if nothing was saved."""
        try:
            return self._checkpoints[step_id]
        except KeyError:
            return None

    def get_completed_steps(self) -> set[str]:
        """Return the IDs of all steps that currently have a checkpoint."""
        return {saved_step for saved_step in self._checkpoints}

    def clear(self) -> None:
        """Forget every stored checkpoint."""
        self._checkpoints.clear()
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class DictCheckpointProvider(CheckpointProvider[Dict[str, Any]]):
    """
    Checkpoint provider whose state values are plain dicts.

    Useful when checkpoints should later be written to JSON or a database. The
    provider performs no conversion itself: callers convert their state objects
    to and from dicts (e.g. with their own to_dict()/from_dict() helpers).

    If `storage` is given, that dict is used directly (not copied) as the backing
    store, so it can be pre-populated or shared with the caller.
    """

    def __init__(self, storage: Optional[Dict[str, Dict[str, Any]]] = None) -> None:
        self._storage = {} if storage is None else storage

    def save(self, step_id: str, state: Dict[str, Any]) -> None:
        """Store the state dict for `step_id` (by reference), replacing any previous one."""
        self._storage.update({step_id: state})

    def load(self, step_id: str) -> Optional[Dict[str, Any]]:
        """Return the state dict saved for `step_id`, or None when absent."""
        try:
            return self._storage[step_id]
        except KeyError:
            return None

    def get_completed_steps(self) -> set[str]:
        """Return the IDs of all steps with a stored checkpoint."""
        return {saved_step for saved_step in self._storage}

    def clear(self) -> None:
        """Remove every checkpoint from the backing store."""
        self._storage.clear()

    def get_all(self) -> Dict[str, Dict[str, Any]]:
        """Return a shallow copy of all stored checkpoints (for serialization)."""
        return {**self._storage}
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Any, Dict, Generic, Optional, TypeVar
|
|
4
|
+
|
|
5
|
+
# Payload type of OrchestratorEvent: the type of its `state` and `result` fields.
T = TypeVar("T")
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class StepStatus(Enum):
    """Status of an orchestration step, as carried on StepInfo.status."""

    # Not produced by any OrchestratorEvent factory method in this module.
    PENDING = "pending"
    # Set by OrchestratorEvent.step_started.
    RUNNING = "running"
    # Set by OrchestratorEvent.step_completed.
    COMPLETED = "completed"
    # Set by OrchestratorEvent.step_failed.
    FAILED = "failed"
    # Set by OrchestratorEvent.step_skipped.
    SKIPPED = "skipped"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class StepInfo:
    """Snapshot of a single step's status, carried in OrchestratorEvent.step."""

    # Identifier of the step.
    step_id: str
    # Human-readable step name.
    name: str
    # Status of the step at the time the event was created.
    status: StepStatus
    # Optional free text; OrchestratorEvent.step_failed stores the error message
    # here and OrchestratorEvent.step_skipped stores the skip reason.
    detail: Optional[str] = None
    # Extra data; each instance gets its own dict via default_factory.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class OrchestratorEvent(Generic[T]):
    """
    Event emitted during orchestration.

    Union-style dataclass: normally exactly one field is populated per event.
    Prefer the static factory methods over calling the constructor directly.

    `state` holds intermediate step outputs; `result` holds the final
    orchestration result.
    """

    step: Optional[StepInfo] = None
    message: Optional[str] = None
    progress: Optional["ProgressUpdate"] = None
    state: Optional[T] = None  # intermediate output produced by a step
    result: Optional[T] = None  # final orchestration result
    error: Optional[str] = None

    @staticmethod
    def _for_step(step_id: str, name: str, status: StepStatus, detail: Optional[str]) -> "OrchestratorEvent[Any]":
        """Wrap a StepInfo with the given status in an event."""
        info = StepInfo(step_id=step_id, name=name, status=status, detail=detail)
        return OrchestratorEvent(step=info)

    @staticmethod
    def step_started(step_id: str, name: str, detail: Optional[str] = None) -> "OrchestratorEvent[Any]":
        """Event marking a step as running."""
        return OrchestratorEvent._for_step(step_id, name, StepStatus.RUNNING, detail)

    @staticmethod
    def step_completed(step_id: str, name: str, detail: Optional[str] = None) -> "OrchestratorEvent[Any]":
        """Event marking a step as completed."""
        return OrchestratorEvent._for_step(step_id, name, StepStatus.COMPLETED, detail)

    @staticmethod
    def step_failed(step_id: str, name: str, error: str) -> "OrchestratorEvent[Any]":
        """Event marking a step as failed; the error text goes into StepInfo.detail."""
        return OrchestratorEvent._for_step(step_id, name, StepStatus.FAILED, error)

    @staticmethod
    def step_skipped(step_id: str, name: str, reason: Optional[str] = None) -> "OrchestratorEvent[Any]":
        """Event marking a step as skipped; the reason goes into StepInfo.detail."""
        return OrchestratorEvent._for_step(step_id, name, StepStatus.SKIPPED, reason)

    @staticmethod
    def msg(message: str) -> "OrchestratorEvent[Any]":
        """Event carrying a free-text message."""
        event = OrchestratorEvent(message=message)
        return event

    @staticmethod
    def state_update(state: T) -> "OrchestratorEvent[Any]":
        """Event carrying intermediate state produced by a step."""
        event = OrchestratorEvent(state=state)
        return event

    @staticmethod
    def ok(result: T) -> "OrchestratorEvent[Any]":
        """Event carrying the final orchestration result."""
        event = OrchestratorEvent(result=result)
        return event

    @staticmethod
    def fail(error: str) -> "OrchestratorEvent[Any]":
        """Event carrying an orchestration error."""
        event = OrchestratorEvent(error=error)
        return event

    @staticmethod
    def progress_update(step_id: str, current: int, total: int, detail: Optional[str] = None) -> "OrchestratorEvent[Any]":
        """Event carrying progress within a step."""
        update = ProgressUpdate(step_id=step_id, current=current, total=total, detail=detail)
        return OrchestratorEvent(progress=update)

    def is_terminal(self) -> bool:
        """True when the event carries a final result or an error (non-None)."""
        return not (self.result is None and self.error is None)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class ProgressUpdate:
    """Progress inside a single step, e.g. item 3 of 10."""

    step_id: str
    current: int
    total: int
    detail: Optional[str] = None

    @property
    def percent(self) -> float:
        """Completion as a percentage of `total`; 0.0 when `total` is zero."""
        if self.total != 0:
            ratio = self.current / self.total
            return ratio * 100
        return 0.0
|
|
File without changes
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from spaik_sdk.config.env import env_config
|
|
4
|
+
from spaik_sdk.prompt.local_prompt_loader import LocalPromptLoader
|
|
5
|
+
from spaik_sdk.prompt.prompt_loader import PromptLoader
|
|
6
|
+
from spaik_sdk.prompt.prompt_loader_mode import PromptLoaderMode
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_prompt_loader(mode: Optional[PromptLoaderMode] = None) -> PromptLoader:
    """
    Build the prompt loader for `mode`.

    Args:
        mode: Loader mode; when not given, `env_config.get_prompt_loader_mode()` decides.

    Raises:
        ValueError: If the mode has no loader implementation.
    """
    selected = mode or env_config.get_prompt_loader_mode()
    if selected != PromptLoaderMode.LOCAL:
        raise ValueError(f"Unknown PromptLoaderMode: {selected}")
    return LocalPromptLoader()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from spaik_sdk.config.env import env_config
|
|
5
|
+
from spaik_sdk.prompt.prompt_loader import PromptLoader
|
|
6
|
+
from spaik_sdk.utils.init_logger import init_logger
|
|
7
|
+
|
|
8
|
+
# Module-level logger; not currently referenced by the code in this module.
logger = init_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LocalPromptLoader(PromptLoader):
    """Prompt loader that reads `<prompts_dir>/<prompt_path>.md` files from the local filesystem."""

    def __init__(self, prompts_dir: Optional[str] = None):
        """
        Args:
            prompts_dir: Directory holding the prompt markdown files. When falsy
                (None or ""), `env_config.get_prompts_dir()` is used instead.
        """
        super().__init__()
        self.prompts_dir = prompts_dir if prompts_dir else env_config.get_prompts_dir()

    def _load_raw_prompt(self, prompt_path: str) -> str:
        """
        Read a prompt file and return its text with surrounding whitespace stripped.

        Raises:
            FileNotFoundError: If `<prompts_dir>/<prompt_path>.md` is not a regular file.
        """
        full_path = os.path.join(self.prompts_dir, f"{prompt_path}.md")
        # isfile (rather than exists) so a directory named "*.md" is reported as a
        # missing prompt instead of failing later inside open().
        if not os.path.isfile(full_path):
            raise FileNotFoundError(f"Prompt file {full_path} not found")
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
        # can garble or reject non-ASCII characters in prompt files.
        with open(full_path, "r", encoding="utf-8") as f:
            return f.read().strip()
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from spaik_sdk.agent.base_agent import BaseAgent
|
|
6
|
+
|
|
7
|
+
from spaik_sdk.utils.init_logger import init_logger
|
|
8
|
+
|
|
9
|
+
# Module-level logger; not currently referenced by PromptLoader in this module.
logger = init_logger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PromptLoader(ABC):
    """
    Base class for loading prompt templates and filling them with `str.format`.

    Subclasses implement `_load_raw_prompt()`. Raw templates are cached per
    prompt path for the lifetime of the loader instance.
    """

    def __init__(self):
        # Raw (unformatted) templates keyed by prompt path.
        self._prompts_cache: Dict[str, str] = {}

    @abstractmethod
    def _load_raw_prompt(self, prompt_path: str) -> str:
        """Load the raw template text for `prompt_path` from the backing store."""
        pass

    def _get_raw_prompt(self, prompt_path: str) -> str:
        """Return the raw template for `prompt_path`, loading and caching it on first use."""
        if prompt_path in self._prompts_cache:
            return self._prompts_cache[prompt_path]
        else:
            raw_prompt = self._load_raw_prompt(prompt_path)
            self._prompts_cache[prompt_path] = raw_prompt
            return raw_prompt

    def _format_prompt(self, prompt: str, params: Dict[str, Any]) -> str:
        """
        Fill `prompt` with `params` using `str.format(**params)`.

        Raises:
            KeyError: If the template references a name missing from `params`;
                the original KeyError is chained as `__cause__`.
        """
        try:
            return prompt.format(**params)
        except KeyError as e:
            # Take the raw key from the exception args; parsing str(e) breaks for
            # keys whose repr is not wrapped in single quotes.
            missing_key = e.args[0] if e.args else str(e)
            raise KeyError(
                f"Missing required variable '{missing_key}' in params: {params} for prompt: '{prompt}'"
            ) from e

    def get_prompt(self, prompt_path: str, params: Dict[str, Any]) -> str:
        """Load (cached) the template at `prompt_path` and format it with `params`."""
        return self._format_prompt(self._get_raw_prompt(prompt_path), params)

    def _get_raw_agent_prompt(self, agent_class: Type["BaseAgent"], prompt_name: str, version: Optional[str] = None) -> str:
        """Load the raw template at `agent/<AgentClassName>/<prompt_name>[-<version>]`."""
        return self._get_raw_prompt(f"agent/{agent_class.__name__}/{prompt_name}{f'-{version}' if version else ''}")

    def get_system_prompt(self, agent_class: Type["BaseAgent"], params: Dict[str, Any], version: Optional[str] = None) -> str:
        """Return the formatted "system" prompt for `agent_class`."""
        return self.get_agent_prompt(agent_class, "system", params, version)

    def get_agent_prompt(
        self, agent_class: Type["BaseAgent"], prompt_name: str, params: Dict[str, Any], version: Optional[str] = None
    ) -> str:
        """Return the formatted prompt `prompt_name` (optionally versioned) for `agent_class`."""
        raw_prompt = self._get_raw_agent_prompt(agent_class, prompt_name, version)
        return self._format_prompt(raw_prompt, params)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class PromptLoaderMode(Enum):
    """Source from which prompt templates are loaded."""

    LOCAL = "local"

    @classmethod
    def from_name(cls, name: str) -> "PromptLoaderMode":
        """
        Look up a mode by its string value (e.g. "local").

        Raises:
            ValueError: If no mode has value `name`; the message lists valid values.
        """
        found = next((member for member in cls if member.value == name), None)
        if found is not None:
            return found
        valid = ", ".join(member.value for member in cls)
        raise ValueError(f"Unknown PromptLoaderMode '{name}'. Available: {valid}")
|
spaik_sdk/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import Any, Dict, Iterator, Optional
|
|
4
|
+
|
|
5
|
+
from spaik_sdk.recording.langchain_serializer import deserialize_token_data
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BasePlayback(ABC):
    """
    Abstract base class for playing back recorded LLM interactions.

    Recordings are split into numbered sessions starting at 1. Iterating
    (sync or async) yields the items of the current session only. When that
    session is exhausted, iteration stops and playback advances to the next
    session, so a fresh loop replays the following session.
    """

    def __init__(self, recording_name: str = "default", delay: float = 0.001):
        """
        Args:
            recording_name: Name of the recording to play back.
            delay: Seconds to wait before returning each item.
        """
        self.recording_name = recording_name
        self.current_session = 1
        self._iterator: Optional[Iterator[Dict[str, Any]]] = None
        self.delay = delay

    @abstractmethod
    def _load_session_data_impl(self, session_num: int) -> Iterator[Dict[str, Any]]:
        """Load raw data for a specific session number. Returns plain dicts."""
        pass

    def _load_session_data(self, session_num: int) -> Iterator[Dict[str, Any]]:
        """Load data for a specific session number and deserialize LangChain objects."""
        for raw_data in self._load_session_data_impl(session_num):
            # Deserialize LangChain objects after loading from implementation
            yield deserialize_token_data(raw_data)

    @abstractmethod
    def _session_exists(self, session_num: int) -> bool:
        """Check if a session file exists."""
        pass

    @abstractmethod
    def is_available(self) -> bool:
        """Check if playback data is available."""
        pass

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Make the playback object iterable."""
        return self

    def __aiter__(self):
        """Make the playback object async iterable."""
        return self

    def _ensure_session_iterator(self) -> Iterator[Dict[str, Any]]:
        """
        Return the current session's iterator, loading it on first use.

        Raises:
            StopIteration: If the current session does not exist.
        """
        if self._iterator is None:
            if not self._session_exists(self.current_session):
                raise StopIteration("No recorded sessions found")
            self._iterator = self._load_session_data(self.current_session)
        return self._iterator

    def _next_from(self, iterator: Iterator[Dict[str, Any]]) -> Dict[str, Any]:
        """Pull the next item; on exhaustion advance to the next session and stop."""
        try:
            return next(iterator)
        except StopIteration:
            # Current session exhausted, auto-bump to next session
            self.current_session += 1
            self._iterator = None
            raise StopIteration("Current session exhausted")

    async def __anext__(self) -> Dict[str, Any]:
        """
        Async version of __next__.

        Waits with `asyncio.sleep` rather than `time.sleep`, so the event loop
        can run other tasks during the playback delay.
        """
        import asyncio  # local import: only needed for async iteration

        try:
            iterator = self._ensure_session_iterator()
            await asyncio.sleep(self.delay)
            return self._next_from(iterator)
        except StopIteration:
            raise StopAsyncIteration

    def __next__(self) -> Dict[str, Any]:
        """Get the next token/response from the current session only (blocking delay)."""
        iterator = self._ensure_session_iterator()
        time.sleep(self.delay)
        return self._next_from(iterator)

    def reset(self) -> None:
        """Reset playback to first session."""
        self.current_session = 1
        self._iterator = None

    def get_current_session(self) -> int:
        """Get the current session number."""
        return self.current_session

    def get_recording_name(self) -> str:
        """Get the recording name."""
        return self.recording_name

    def next_session(self) -> None:
        """Manually advance to next session."""
        self.current_session += 1
        self._iterator = None

    def has_next_session(self) -> bool:
        """Check if there's a next session available."""
        return self._session_exists(self.current_session + 1)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, Dict
|
|
3
|
+
|
|
4
|
+
from spaik_sdk.recording.langchain_serializer import ensure_json_serializable, serialize_token_data
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BaseRecorder(ABC):
    """
    Abstract base class for recording LLM interactions.

    The public `record_*` methods pass incoming data through
    `serialize_token_data` and then `ensure_json_serializable` before handing
    it to the matching `_record_*_impl` hook, so subclasses receive
    pre-serialized data.
    """

    def __init__(self, recording_name: str = "default"):
        self.recording_name = recording_name
        self.current_session = 1

    def _prepare_for_recording(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Serialize LangChain objects in `payload`, then run ensure_json_serializable."""
        serialized = serialize_token_data(payload)
        return ensure_json_serializable(serialized)

    def record_token(self, token_data: Dict[str, Any]) -> None:
        """Record one streaming token from an LLM response."""
        self._record_token_impl(self._prepare_for_recording(token_data))

    def record_structured(self, data: Dict[str, Any]) -> None:
        """
        Record a structured LLM response.

        NOTE(review): this base implementation does not change `current_session`;
        any session bump is left to `_record_structured_impl` — confirm against
        the concrete recorders.
        """
        self._record_structured_impl(self._prepare_for_recording(data))

    @abstractmethod
    def _record_token_impl(self, token_data: Dict[str, Any]) -> None:
        """Store one already-serialized token."""
        ...

    @abstractmethod
    def _record_structured_impl(self, data: Dict[str, Any]) -> None:
        """Store one already-serialized structured response."""
        ...

    @abstractmethod
    def request_completed(self) -> None:
        """Mark the current request as completed and move to the next session."""
        ...

    @abstractmethod
    def get_current_session(self) -> int:
        """Return the current session number."""
        ...

    def get_recording_name(self) -> str:
        """Return the name this recorder was created with."""
        return self.recording_name
|