spaik-sdk 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spaik_sdk/__init__.py +21 -0
- spaik_sdk/agent/__init__.py +0 -0
- spaik_sdk/agent/base_agent.py +249 -0
- spaik_sdk/attachments/__init__.py +22 -0
- spaik_sdk/attachments/builder.py +61 -0
- spaik_sdk/attachments/file_storage_provider.py +27 -0
- spaik_sdk/attachments/mime_types.py +118 -0
- spaik_sdk/attachments/models.py +63 -0
- spaik_sdk/attachments/provider_support.py +53 -0
- spaik_sdk/attachments/storage/__init__.py +0 -0
- spaik_sdk/attachments/storage/base_file_storage.py +32 -0
- spaik_sdk/attachments/storage/impl/__init__.py +0 -0
- spaik_sdk/attachments/storage/impl/local_file_storage.py +101 -0
- spaik_sdk/audio/__init__.py +12 -0
- spaik_sdk/audio/options.py +53 -0
- spaik_sdk/audio/providers/__init__.py +1 -0
- spaik_sdk/audio/providers/google_tts.py +77 -0
- spaik_sdk/audio/providers/openai_stt.py +71 -0
- spaik_sdk/audio/providers/openai_tts.py +111 -0
- spaik_sdk/audio/stt.py +61 -0
- spaik_sdk/audio/tts.py +124 -0
- spaik_sdk/config/credentials_provider.py +10 -0
- spaik_sdk/config/env.py +59 -0
- spaik_sdk/config/env_credentials_provider.py +7 -0
- spaik_sdk/config/get_credentials_provider.py +14 -0
- spaik_sdk/image_gen/__init__.py +9 -0
- spaik_sdk/image_gen/image_generator.py +83 -0
- spaik_sdk/image_gen/options.py +24 -0
- spaik_sdk/image_gen/providers/__init__.py +0 -0
- spaik_sdk/image_gen/providers/google.py +75 -0
- spaik_sdk/image_gen/providers/openai.py +60 -0
- spaik_sdk/llm/__init__.py +0 -0
- spaik_sdk/llm/cancellation_handle.py +10 -0
- spaik_sdk/llm/consumption/__init__.py +0 -0
- spaik_sdk/llm/consumption/consumption_estimate.py +26 -0
- spaik_sdk/llm/consumption/consumption_estimate_builder.py +113 -0
- spaik_sdk/llm/consumption/consumption_extractor.py +59 -0
- spaik_sdk/llm/consumption/token_usage.py +31 -0
- spaik_sdk/llm/converters.py +146 -0
- spaik_sdk/llm/cost/__init__.py +1 -0
- spaik_sdk/llm/cost/builtin_cost_provider.py +83 -0
- spaik_sdk/llm/cost/cost_estimate.py +8 -0
- spaik_sdk/llm/cost/cost_provider.py +28 -0
- spaik_sdk/llm/extract_error_message.py +37 -0
- spaik_sdk/llm/langchain_loop_manager.py +270 -0
- spaik_sdk/llm/langchain_service.py +196 -0
- spaik_sdk/llm/message_handler.py +188 -0
- spaik_sdk/llm/streaming/__init__.py +1 -0
- spaik_sdk/llm/streaming/block_manager.py +152 -0
- spaik_sdk/llm/streaming/models.py +42 -0
- spaik_sdk/llm/streaming/streaming_content_handler.py +157 -0
- spaik_sdk/llm/streaming/streaming_event_handler.py +215 -0
- spaik_sdk/llm/streaming/streaming_state_manager.py +58 -0
- spaik_sdk/models/__init__.py +0 -0
- spaik_sdk/models/factories/__init__.py +0 -0
- spaik_sdk/models/factories/anthropic_factory.py +33 -0
- spaik_sdk/models/factories/base_model_factory.py +71 -0
- spaik_sdk/models/factories/google_factory.py +30 -0
- spaik_sdk/models/factories/ollama_factory.py +41 -0
- spaik_sdk/models/factories/openai_factory.py +50 -0
- spaik_sdk/models/llm_config.py +46 -0
- spaik_sdk/models/llm_families.py +7 -0
- spaik_sdk/models/llm_model.py +17 -0
- spaik_sdk/models/llm_wrapper.py +25 -0
- spaik_sdk/models/model_registry.py +156 -0
- spaik_sdk/models/providers/__init__.py +0 -0
- spaik_sdk/models/providers/anthropic_provider.py +29 -0
- spaik_sdk/models/providers/azure_provider.py +31 -0
- spaik_sdk/models/providers/base_provider.py +62 -0
- spaik_sdk/models/providers/google_provider.py +26 -0
- spaik_sdk/models/providers/ollama_provider.py +26 -0
- spaik_sdk/models/providers/openai_provider.py +26 -0
- spaik_sdk/models/providers/provider_type.py +90 -0
- spaik_sdk/orchestration/__init__.py +24 -0
- spaik_sdk/orchestration/base_orchestrator.py +238 -0
- spaik_sdk/orchestration/checkpoint.py +80 -0
- spaik_sdk/orchestration/models.py +103 -0
- spaik_sdk/prompt/__init__.py +0 -0
- spaik_sdk/prompt/get_prompt_loader.py +13 -0
- spaik_sdk/prompt/local_prompt_loader.py +21 -0
- spaik_sdk/prompt/prompt_loader.py +48 -0
- spaik_sdk/prompt/prompt_loader_mode.py +14 -0
- spaik_sdk/py.typed +1 -0
- spaik_sdk/recording/__init__.py +1 -0
- spaik_sdk/recording/base_playback.py +90 -0
- spaik_sdk/recording/base_recorder.py +50 -0
- spaik_sdk/recording/conditional_recorder.py +38 -0
- spaik_sdk/recording/impl/__init__.py +1 -0
- spaik_sdk/recording/impl/local_playback.py +76 -0
- spaik_sdk/recording/impl/local_recorder.py +85 -0
- spaik_sdk/recording/langchain_serializer.py +88 -0
- spaik_sdk/server/__init__.py +1 -0
- spaik_sdk/server/api/routers/__init__.py +0 -0
- spaik_sdk/server/api/routers/api_builder.py +149 -0
- spaik_sdk/server/api/routers/audio_router_factory.py +201 -0
- spaik_sdk/server/api/routers/file_router_factory.py +111 -0
- spaik_sdk/server/api/routers/thread_router_factory.py +284 -0
- spaik_sdk/server/api/streaming/__init__.py +0 -0
- spaik_sdk/server/api/streaming/format_sse_event.py +41 -0
- spaik_sdk/server/api/streaming/negotiate_streaming_response.py +8 -0
- spaik_sdk/server/api/streaming/streaming_negotiator.py +10 -0
- spaik_sdk/server/authorization/__init__.py +0 -0
- spaik_sdk/server/authorization/base_authorizer.py +64 -0
- spaik_sdk/server/authorization/base_user.py +13 -0
- spaik_sdk/server/authorization/dummy_authorizer.py +17 -0
- spaik_sdk/server/job_processor/__init__.py +0 -0
- spaik_sdk/server/job_processor/base_job_processor.py +8 -0
- spaik_sdk/server/job_processor/thread_job_processor.py +32 -0
- spaik_sdk/server/pubsub/__init__.py +1 -0
- spaik_sdk/server/pubsub/cancellation_publisher.py +7 -0
- spaik_sdk/server/pubsub/cancellation_subscriber.py +38 -0
- spaik_sdk/server/pubsub/event_publisher.py +13 -0
- spaik_sdk/server/pubsub/impl/__init__.py +1 -0
- spaik_sdk/server/pubsub/impl/local_cancellation_pubsub.py +48 -0
- spaik_sdk/server/pubsub/impl/signalr_publisher.py +36 -0
- spaik_sdk/server/queue/__init__.py +1 -0
- spaik_sdk/server/queue/agent_job_queue.py +27 -0
- spaik_sdk/server/queue/impl/__init__.py +1 -0
- spaik_sdk/server/queue/impl/azure_queue.py +24 -0
- spaik_sdk/server/response/__init__.py +0 -0
- spaik_sdk/server/response/agent_response_generator.py +39 -0
- spaik_sdk/server/response/response_generator.py +13 -0
- spaik_sdk/server/response/simple_agent_response_generator.py +14 -0
- spaik_sdk/server/services/__init__.py +0 -0
- spaik_sdk/server/services/thread_converters.py +113 -0
- spaik_sdk/server/services/thread_models.py +90 -0
- spaik_sdk/server/services/thread_service.py +91 -0
- spaik_sdk/server/storage/__init__.py +1 -0
- spaik_sdk/server/storage/base_thread_repository.py +51 -0
- spaik_sdk/server/storage/impl/__init__.py +0 -0
- spaik_sdk/server/storage/impl/in_memory_thread_repository.py +100 -0
- spaik_sdk/server/storage/impl/local_file_thread_repository.py +217 -0
- spaik_sdk/server/storage/thread_filter.py +166 -0
- spaik_sdk/server/storage/thread_metadata.py +53 -0
- spaik_sdk/thread/__init__.py +0 -0
- spaik_sdk/thread/adapters/__init__.py +0 -0
- spaik_sdk/thread/adapters/cli/__init__.py +0 -0
- spaik_sdk/thread/adapters/cli/block_display.py +92 -0
- spaik_sdk/thread/adapters/cli/display_manager.py +84 -0
- spaik_sdk/thread/adapters/cli/live_cli.py +235 -0
- spaik_sdk/thread/adapters/event_adapter.py +28 -0
- spaik_sdk/thread/adapters/streaming_block_adapter.py +57 -0
- spaik_sdk/thread/adapters/sync_adapter.py +76 -0
- spaik_sdk/thread/models.py +224 -0
- spaik_sdk/thread/thread_container.py +468 -0
- spaik_sdk/tools/__init__.py +0 -0
- spaik_sdk/tools/impl/__init__.py +0 -0
- spaik_sdk/tools/impl/mcp_tool_provider.py +93 -0
- spaik_sdk/tools/impl/search_tool_provider.py +18 -0
- spaik_sdk/tools/tool_provider.py +131 -0
- spaik_sdk/tracing/__init__.py +13 -0
- spaik_sdk/tracing/agent_trace.py +72 -0
- spaik_sdk/tracing/get_trace_sink.py +15 -0
- spaik_sdk/tracing/local_trace_sink.py +23 -0
- spaik_sdk/tracing/trace_sink.py +19 -0
- spaik_sdk/tracing/trace_sink_mode.py +14 -0
- spaik_sdk/utils/__init__.py +0 -0
- spaik_sdk/utils/init_logger.py +24 -0
- spaik_sdk-0.6.2.dist-info/METADATA +379 -0
- spaik_sdk-0.6.2.dist-info/RECORD +161 -0
- spaik_sdk-0.6.2.dist-info/WHEEL +4 -0

spaik_sdk/llm/streaming/streaming_event_handler.py
@@ -0,0 +1,215 @@
from typing import AsyncGenerator, Optional, Union

from langchain_core.messages import AIMessage, AIMessageChunk

from spaik_sdk.llm.consumption.token_usage import TokenUsage
from spaik_sdk.llm.streaming.block_manager import BlockManager
from spaik_sdk.llm.streaming.models import EventType, StreamingEvent
from spaik_sdk.llm.streaming.streaming_content_handler import StreamingContentHandler
from spaik_sdk.llm.streaming.streaming_state_manager import StreamingStateManager
from spaik_sdk.recording.base_recorder import BaseRecorder
from spaik_sdk.utils.init_logger import init_logger

logger = init_logger(__name__)

AIMessageType = Union[AIMessage, AIMessageChunk]


class StreamingEventHandler:
    """Handles LangChain 1.x streaming events."""

    def __init__(self, recorder: Optional[BaseRecorder] = None):
        self.recorder = recorder
        self.block_manager = BlockManager()
        self.state_manager = StreamingStateManager()
        self.content_handler = StreamingContentHandler(self.block_manager, self.state_manager)
        self._processed_message_ids: set[str] = set()
        self._final_message: Optional[AIMessageType] = None
        self._got_chat_model_stream: bool = False

    def reset(self) -> None:
        self.block_manager.reset()
        self.state_manager.reset()
        self._processed_message_ids.clear()
        self._final_message = None
        self._got_chat_model_stream = False

    async def process_stream(self, agent_stream) -> AsyncGenerator[StreamingEvent, None]:
        """Process LangChain 1.x agent stream events."""
        self.reset()

        async for event in agent_stream:
            if self.recorder is not None:
                self.recorder.record_token(event)

            event_type = event.get("event", "")
            data = event.get("data", {})
            logger.trace(f"Stream event: {event_type}")

            # on_chat_model_stream - real-time token streaming (preferred)
            if event_type == "on_chat_model_stream":
                self._got_chat_model_stream = True
                chunk = data.get("chunk")
                if isinstance(chunk, AIMessageChunk):
                    async for streaming_event in self._handle_chunk(chunk):
                        yield streaming_event

            # on_chain_stream - complete messages (fallback if no chat_model_stream)
            elif event_type == "on_chain_stream":
                if not self._got_chat_model_stream:
                    ai_message = self._extract_ai_message(data.get("chunk", {}))
                    if ai_message and not self._is_duplicate(ai_message):
                        async for streaming_event in self._handle_message(ai_message):
                            yield streaming_event
                        self._final_message = ai_message

            # on_chat_model_end - usage metadata from the model
            elif event_type == "on_chat_model_end":
                output = data.get("output")
                if isinstance(output, (AIMessage, AIMessageChunk)):
                    self._final_message = output
                    async for streaming_event in self._emit_usage_if_available(output):
                        yield streaming_event

            # on_chain_end - final state
            elif event_type == "on_chain_end":
                output = data.get("output", {})
                if isinstance(output, dict) and "messages" in output:
                    for msg in output["messages"]:
                        if isinstance(msg, (AIMessage, AIMessageChunk)):
                            if self._final_message is None:
                                self._final_message = msg
                            async for streaming_event in self._emit_usage_if_available(msg):
                                yield streaming_event
                            break

            # on_tool_end - tool execution completed
            elif event_type == "on_tool_end":
                output = data.get("output")
                if output is not None:
                    tool_call_id = getattr(output, "tool_call_id", None)
                    content = getattr(output, "content", str(output))
                    if tool_call_id:
                        async for streaming_event in self.content_handler.handle_tool_response(
                            tool_call_id, content if isinstance(content, str) else str(content)
                        ):
                            yield streaming_event

        # End any active thinking session
        async for event in self.content_handler.end_final_thinking_session_if_needed():
            yield event

        # Emit final COMPLETE event
        if self._final_message or self.state_manager.current_message_id:
            yield StreamingEvent(
                event_type=EventType.COMPLETE,
                message=self._final_message,
                blocks=self.block_manager.get_block_ids(),
                message_id=self.state_manager.current_message_id,
            )

    def _is_duplicate(self, message: AIMessageType) -> bool:
        msg_id = getattr(message, "id", None)
        if not msg_id:
            return False
        if msg_id in self._processed_message_ids:
            return True
        self._processed_message_ids.add(msg_id)
        return False

    def _extract_ai_message(self, chunk: dict) -> Optional[AIMessageType]:
        if "messages" in chunk:
            for msg in chunk["messages"]:
                if isinstance(msg, (AIMessage, AIMessageChunk)):
                    return msg
        if "model" in chunk and isinstance(chunk["model"], dict):
            if "messages" in chunk["model"]:
                for msg in chunk["model"]["messages"]:
                    if isinstance(msg, (AIMessage, AIMessageChunk)):
                        return msg
        return None

    async def _handle_chunk(self, chunk: AIMessageChunk) -> AsyncGenerator[StreamingEvent, None]:
        """Handle streaming chunk - real-time content."""
        content = chunk.content

        if isinstance(content, str) and content:
            async for event in self.content_handler.handle_regular_content(content):
                yield event
            self.state_manager.mark_text_content_received()

        elif isinstance(content, list):
            for block in content:
                if isinstance(block, dict):
                    block_type = block.get("type")
                    if block_type == "text":
                        text = block.get("text", "")
                        if text:
                            async for event in self.content_handler.handle_regular_content(text):
                                yield event
                            self.state_manager.mark_text_content_received()
                    elif block_type in ("reasoning", "thinking"):
                        reasoning = block.get("reasoning", "") or block.get("thinking", "")
                        if reasoning:
                            async for event in self.content_handler.handle_reasoning_content(reasoning):
                                yield event
                elif isinstance(block, str) and block:
                    async for event in self.content_handler.handle_regular_content(block):
                        yield event
                    self.state_manager.mark_text_content_received()

        if hasattr(chunk, "tool_calls") and chunk.tool_calls:
            for tool_call in chunk.tool_calls:
                tool_id = tool_call.get("id") if isinstance(tool_call, dict) else getattr(tool_call, "id", None)
                tool_name = tool_call.get("name") if isinstance(tool_call, dict) else getattr(tool_call, "name", None)
                tool_args = tool_call.get("args", {}) if isinstance(tool_call, dict) else getattr(tool_call, "args", {})
                if tool_id and tool_name:
                    async for event in self.content_handler.handle_tool_use(tool_id, tool_name, tool_args):
                        yield event

    async def _handle_message(self, message: AIMessageType) -> AsyncGenerator[StreamingEvent, None]:
        """Handle complete message (from on_chain_stream fallback)."""
        content = message.content

        if isinstance(content, str) and content:
            async for event in self.content_handler.handle_regular_content(content):
                yield event
            self.state_manager.mark_text_content_received()

        elif isinstance(content, list):
            for block in content:
                if isinstance(block, dict):
                    block_type = block.get("type")
                    if block_type == "text":
                        async for event in self.content_handler.handle_regular_content(block.get("text", "")):
                            yield event
                        self.state_manager.mark_text_content_received()
                    elif block_type in ("reasoning", "thinking"):
                        reasoning = block.get("reasoning", "") or block.get("thinking", "")
                        async for event in self.content_handler.handle_reasoning_content(reasoning):
                            yield event
                elif isinstance(block, str) and block:
                    async for event in self.content_handler.handle_regular_content(block):
                        yield event
                    self.state_manager.mark_text_content_received()

        if hasattr(message, "tool_calls") and message.tool_calls:
            for tool_call in message.tool_calls:
                tool_id = tool_call.get("id") if isinstance(tool_call, dict) else getattr(tool_call, "id", None)
                tool_name = tool_call.get("name") if isinstance(tool_call, dict) else getattr(tool_call, "name", None)
                tool_args = tool_call.get("args", {}) if isinstance(tool_call, dict) else getattr(tool_call, "args", {})
                if tool_id and tool_name:
                    async for event in self.content_handler.handle_tool_use(tool_id, tool_name, tool_args):
                        yield event

    async def _emit_usage_if_available(self, message: AIMessageType) -> AsyncGenerator[StreamingEvent, None]:
        """Emit usage metadata if available on message."""
        if hasattr(message, "usage_metadata") and message.usage_metadata:
            yield StreamingEvent(
                event_type=EventType.USAGE_METADATA,
                message_id=self.state_manager.current_message_id,
                usage_metadata=TokenUsage.from_langchain(message.usage_metadata),
            )


__all__ = ["StreamingEventHandler"]
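For orientation, a hedged consumption sketch, not part of the package: it assumes an agent or runnable whose `astream_events()` call yields the dict-shaped events (with `"event"`/`"data"` keys) that `process_stream` expects, which is how LangChain runnables emit them. The names `agent`, `inputs`, and `my_agent` are placeholders.

```python
# Illustrative sketch only; astream_events(version="v2") on the caller's side is an assumption.
from spaik_sdk.llm.streaming.models import EventType
from spaik_sdk.llm.streaming.streaming_event_handler import StreamingEventHandler


async def consume(agent, inputs: dict) -> None:
    handler = StreamingEventHandler()  # optionally pass a BaseRecorder to capture raw events
    async for ev in handler.process_stream(agent.astream_events(inputs, version="v2")):
        if ev.event_type == EventType.COMPLETE:
            # Final event: carries the last AIMessage (if any) and the block ids produced.
            print("done, blocks:", ev.blocks)
        elif ev.event_type == EventType.USAGE_METADATA:
            print("token usage:", ev.usage_metadata)


# e.g. asyncio.run(consume(my_agent, {"messages": [...]}))
```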

spaik_sdk/llm/streaming/streaming_state_manager.py
@@ -0,0 +1,58 @@
from typing import Optional

from spaik_sdk.utils.init_logger import init_logger

logger = init_logger(__name__)


class StreamingStateManager:
    """Manages state for streaming operations."""

    def __init__(self):
        self.current_message_id: Optional[str] = None
        self.streaming_started = False

        # Track mid-response thinking state
        self.last_block_type = None
        self.has_text_content = False
        self.reasoning_blocks_created = 0  # Track how many reasoning blocks we've created
        self.in_thinking_session = False  # Track if we're currently in a thinking session

    def reset(self):
        """Reset all state for new stream."""
        self.current_message_id = None
        self.streaming_started = False
        self.last_block_type = None
        self.has_text_content = False
        self.reasoning_blocks_created = 0
        self.in_thinking_session = False

    def start_thinking_session(self):
        """Mark the start of a thinking session."""
        if not self.in_thinking_session:
            logger.debug("🧠 Starting thinking session")
            self.in_thinking_session = True

    def end_thinking_session(self):
        """Mark the end of a thinking session."""
        if self.in_thinking_session:
            logger.debug("Ending thinking session - got text content")
            self.in_thinking_session = False

    def increment_reasoning_blocks(self):
        """Increment the count of reasoning blocks created."""
        self.reasoning_blocks_created += 1
        logger.debug(f"Created reasoning block #{self.reasoning_blocks_created}")

    def should_create_new_thinking_session(self, reasoning_content: bool, current_block_type: str) -> bool:
        """Determine if we should create a new thinking session (mid-response thinking)."""
        return reasoning_content and self.has_text_content and self.last_block_type == "text" and current_block_type == "thinking"

    def update_block_type(self, block_type: str):
        """Update the last block type."""
        if block_type:
            self.last_block_type = block_type

    def mark_text_content_received(self):
        """Mark that we've received text content."""
        self.has_text_content = True
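To make the mid-response thinking check concrete, here is a small illustrative sequence (assumed usage, derived only from the methods shown above):

```python
from spaik_sdk.llm.streaming.streaming_state_manager import StreamingStateManager

state = StreamingStateManager()

state.start_thinking_session()       # model starts by reasoning
state.end_thinking_session()         # ...then switches to answering
state.mark_text_content_received()   # regular text has been emitted
state.update_block_type("text")

# A reasoning block arriving now, after text, counts as a new mid-response thinking session:
assert state.should_create_new_thinking_session(reasoning_content=True, current_block_type="thinking")
```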

spaik_sdk/models/__init__.py (file without changes)
spaik_sdk/models/factories/__init__.py (file without changes)

spaik_sdk/models/factories/anthropic_factory.py
@@ -0,0 +1,33 @@
from typing import Any, Dict, Optional

from spaik_sdk.models.factories.base_model_factory import BaseModelFactory
from spaik_sdk.models.llm_config import LLMConfig
from spaik_sdk.models.llm_families import LLMFamilies
from spaik_sdk.models.llm_model import LLMModel
from spaik_sdk.models.model_registry import ModelRegistry


class AnthropicModelFactory(BaseModelFactory):
    MODELS = ModelRegistry.get_by_family(LLMFamilies.ANTHROPIC)

    def supports_model(self, model: LLMModel) -> bool:
        return model in AnthropicModelFactory.MODELS

    def get_cache_control(self, config: LLMConfig) -> Optional[Dict[str, Any]]:
        return {"type": "ephemeral"}

    def get_model_specific_config(self, config: LLMConfig) -> Dict[str, Any]:
        allow_reasoning = config.reasoning and not config.structured_response
        model_config: Dict[str, Any] = {
            "model_name": config.model.name,
            "streaming": config.streaming,
            "max_tokens": config.max_output_tokens,
        }

        # Handle thinking mode via model_kwargs for LangChain compatibility
        if allow_reasoning:
            model_config["thinking"] = {"type": "enabled", "budget_tokens": config.reasoning_budget_tokens}
        else:
            model_config["temperature"] = config.temperature

        return model_config

spaik_sdk/models/factories/base_model_factory.py
@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional

if TYPE_CHECKING:
    from spaik_sdk.models.llm_config import LLMConfig

from spaik_sdk.models.llm_model import LLMModel
from spaik_sdk.models.llm_wrapper import LLMWrapper
from spaik_sdk.models.providers.base_provider import BaseProvider


class BaseModelFactory(ABC):
    def create_model(self, config: "LLMConfig", provider: BaseProvider) -> LLMWrapper:
        """Create a model wrapper for the given config and provider instance."""
        # Check if this factory supports the model with this config
        if not self.supports_model_config(config):
            raise ValueError(f"Factory doesn't support model config: {config}")

        # Get provider config and cache control
        provider_config = provider.get_model_config(config)
        cache_control = self.get_cache_control(config)

        # Get model-specific configuration from subclass
        model_specific_config = self.get_model_specific_config(config)

        # Build complete model config
        model_config = {**model_specific_config, **provider_config}

        # Let provider create the langchain model
        langchain_model = provider.create_langchain_model(config, model_config)

        return LLMWrapper(langchain_model, cache_control, config.model)

    @abstractmethod
    def supports_model(self, model: LLMModel) -> bool:
        """Check if this factory supports the given model (basic check)."""
        pass

    def supports_model_config(self, config: "LLMConfig") -> bool:
        return self.supports_model(config.model)

    @abstractmethod
    def get_cache_control(self, config: "LLMConfig") -> Optional[Dict[str, Any]]:
        """Get cache control settings for this factory's models."""
        pass

    @abstractmethod
    def get_model_specific_config(self, config: "LLMConfig") -> Dict[str, Any]:
        """Get model-specific configuration for the given config."""
        pass

    @classmethod
    def create_factory(cls, config: "LLMConfig") -> "BaseModelFactory":
        """Factory method to create appropriate factory instance."""

        from spaik_sdk.models.factories.anthropic_factory import AnthropicModelFactory
        from spaik_sdk.models.factories.google_factory import GoogleModelFactory
        from spaik_sdk.models.factories.ollama_factory import OllamaModelFactory
        from spaik_sdk.models.factories.openai_factory import OpenAIModelFactory

        factories = [
            AnthropicModelFactory(),
            OpenAIModelFactory(),
            GoogleModelFactory(),
            OllamaModelFactory(),
        ]
        for factory in factories:
            if factory.supports_model_config(config):
                return factory

        raise ValueError(f"No factory found that supports model config: {config}")
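A short sketch of the dispatch this method implements (assumed usage; the Ollama family is used because it accepts arbitrary model names, and in normal use `LLMConfig.create_model_wrapper()`, shown later in llm_config.py, drives the full `create_model` flow):

```python
from spaik_sdk.models.factories.base_model_factory import BaseModelFactory
from spaik_sdk.models.llm_config import LLMConfig
from spaik_sdk.models.llm_model import LLMModel

# "llama3" is only an illustrative local model name.
config = LLMConfig(model=LLMModel(family="ollama", name="llama3"))

factory = BaseModelFactory.create_factory(config)   # first factory whose supports_model_config() passes
kwargs = factory.get_model_specific_config(config)  # per-family kwargs, merged with provider config in create_model()
print(type(factory).__name__, kwargs)               # OllamaModelFactory {...}
```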

spaik_sdk/models/factories/google_factory.py
@@ -0,0 +1,30 @@
from typing import Any, Dict, Optional

from spaik_sdk.models.factories.base_model_factory import BaseModelFactory
from spaik_sdk.models.llm_config import LLMConfig
from spaik_sdk.models.llm_families import LLMFamilies
from spaik_sdk.models.llm_model import LLMModel
from spaik_sdk.models.model_registry import ModelRegistry


class GoogleModelFactory(BaseModelFactory):
    MODELS = ModelRegistry.get_by_family(LLMFamilies.GOOGLE)

    def supports_model(self, model: LLMModel) -> bool:
        return model in GoogleModelFactory.MODELS

    def get_cache_control(self, config: LLMConfig) -> Optional[Dict[str, Any]]:
        return {"type": "permanent"}

    def get_model_specific_config(self, config: LLMConfig) -> Dict[str, Any]:
        model_config: Dict[str, Any] = {"model": config.model.name, "temperature": config.temperature}

        if config.reasoning:
            model_config["thinking_budget"] = config.reasoning_budget_tokens
            model_config["include_thoughts"] = True

        # Handle streaming - Google models use disable_streaming instead of streaming
        if not config.streaming:
            model_config["disable_streaming"] = True

        return model_config

spaik_sdk/models/factories/ollama_factory.py
@@ -0,0 +1,41 @@
from typing import Any, Dict, Optional

from spaik_sdk.models.factories.base_model_factory import BaseModelFactory
from spaik_sdk.models.llm_config import LLMConfig
from spaik_sdk.models.llm_families import LLMFamilies
from spaik_sdk.models.llm_model import LLMModel
from spaik_sdk.models.model_registry import ModelRegistry
from spaik_sdk.utils.init_logger import init_logger

logger = init_logger(__name__)


class OllamaModelFactory(BaseModelFactory):
    # Models are dynamically created by users with LLMModel(family="ollama", name="...")
    # So we start with an empty registry and let users add models as needed
    MODELS = ModelRegistry.get_by_family(LLMFamilies.OLLAMA)

    def supports_model(self, model: LLMModel) -> bool:
        return model.family == "ollama"

    def get_cache_control(self, config: LLMConfig) -> Optional[Dict[str, Any]]:
        # Ollama doesn't support prompt caching in the same way as cloud providers
        return None

    def get_model_specific_config(self, config: LLMConfig) -> Dict[str, Any]:
        model_config: Dict[str, Any] = {
            "model": config.model.name,
            "temperature": config.temperature,
        }

        # Enable streaming if requested
        if config.streaming:
            model_config["streaming"] = True

        # Handle reasoning configuration for models that support it (like deepseek-r1)
        # reasoning=True should separate <think> content to additional_kwargs['reasoning_content']
        # reasoning=None/False leaves <think> tags in main content
        if config.reasoning is not None:
            model_config["reasoning"] = config.reasoning

        return model_config
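As a hedged example of the kwargs this produces with default config values (the model name "deepseek-r1" is taken from the comment above and is only illustrative):

```python
from spaik_sdk.models.factories.ollama_factory import OllamaModelFactory
from spaik_sdk.models.llm_config import LLMConfig
from spaik_sdk.models.llm_model import LLMModel

config = LLMConfig(model=LLMModel(family="ollama", name="deepseek-r1"), reasoning=True)
print(OllamaModelFactory().get_model_specific_config(config))
# {'model': 'deepseek-r1', 'temperature': 0.1, 'streaming': True, 'reasoning': True}
```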

spaik_sdk/models/factories/openai_factory.py
@@ -0,0 +1,50 @@
from typing import Any, Dict, Optional

from spaik_sdk.models.factories.base_model_factory import BaseModelFactory
from spaik_sdk.models.llm_config import LLMConfig
from spaik_sdk.models.llm_families import LLMFamilies
from spaik_sdk.models.llm_model import LLMModel
from spaik_sdk.models.model_registry import ModelRegistry
from spaik_sdk.utils.init_logger import init_logger

logger = init_logger(__name__)


class OpenAIModelFactory(BaseModelFactory):
    MODELS = ModelRegistry.get_by_family(LLMFamilies.OPENAI)

    def supports_model(self, model: LLMModel) -> bool:
        return model in OpenAIModelFactory.MODELS

    def supports_model_config(self, config: LLMConfig) -> bool:
        # First check basic model support
        if not self.supports_model(config.model):
            return False
        if config.reasoning and not config.model.reasoning:
            # let's not fail here, but we should log a warning
            logger.warning(f"Model {config.model} does not support reasoning")
        return True

    def get_cache_control(self, config: LLMConfig) -> Optional[Dict[str, Any]]:
        if config.model.prompt_caching:
            return {"type": "permanent"}
        return None

    def get_model_specific_config(self, config: LLMConfig) -> Dict[str, Any]:
        model_config: Dict[str, Any] = {"model": config.model.name, "streaming": config.streaming}
        # Add parallel tool calls if tool usage is enabled
        if config.tool_usage:
            model_config["model_kwargs"] = {"parallel_tool_calls": True}

        # Add model-specific configurations for reasoning models
        if config.model.reasoning:
            # Enable Responses API for reasoning models
            model_config["use_responses_api"] = True

            # Configure reasoning through model_kwargs as per LangChain docs
            if config.reasoning_summary:
                model_config["model_kwargs"] = {"reasoning": {"effort": config.reasoning_effort, "summary": config.reasoning_summary}}
        else:
            model_config["temperature"] = config.temperature

        return model_config

spaik_sdk/models/llm_config.py
@@ -0,0 +1,46 @@
from dataclasses import dataclass, replace
from typing import Optional

from spaik_sdk.models.llm_model import LLMModel
from spaik_sdk.models.llm_wrapper import LLMWrapper
from spaik_sdk.models.providers.base_provider import BaseProvider
from spaik_sdk.models.providers.provider_type import ProviderType


@dataclass
class LLMConfig:
    model: LLMModel
    provider_type: Optional[ProviderType] = None
    reasoning: bool = True
    tool_usage: bool = True
    streaming: bool = True
    reasoning_summary: str = "detailed"  # Options: "auto", "concise", "detailed", None
    reasoning_effort: str = "medium"  # Options: "low", "medium", "high"
    max_output_tokens: int = 8192
    reasoning_budget_tokens: int = 4096
    temperature: float = 0.1
    structured_response: bool = False

    _model_wrapper: Optional[LLMWrapper] = None

    def get_model_wrapper(self) -> LLMWrapper:
        if self._model_wrapper is None:
            self._model_wrapper = self.create_model_wrapper()
        return self._model_wrapper

    def create_model_wrapper(self) -> LLMWrapper:
        provider = self.get_provider()
        factory = self.get_factory()
        return factory.create_model(self, provider)

    def get_provider(self) -> BaseProvider:
        return BaseProvider.create_provider(self.provider_type)

    def get_factory(self):
        # Late import to avoid circular dependency
        from spaik_sdk.models.factories.base_model_factory import BaseModelFactory

        return BaseModelFactory.create_factory(self)

    def as_structured_response_config(self) -> "LLMConfig":
        return replace(self, structured_response=True)
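A minimal usage sketch under stated assumptions: the Ollama family is used so no registry entry is needed, and whether `BaseProvider.create_provider` resolves a provider when no explicit `provider_type` is set depends on provider code not shown here.

```python
from spaik_sdk.models.llm_config import LLMConfig
from spaik_sdk.models.llm_model import LLMModel

config = LLMConfig(
    model=LLMModel(family="ollama", name="llama3"),  # illustrative local model
    reasoning=False,                                 # factories then fall back to plain temperature sampling
    max_output_tokens=2048,
)

wrapper = config.get_model_wrapper()         # lazily built and cached on the config
chat_model = wrapper.get_langchain_model()   # the underlying LangChain BaseChatModel

structured = config.as_structured_response_config()  # copy with structured_response=True
```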

spaik_sdk/models/llm_model.py
@@ -0,0 +1,17 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class LLMModel:
    """Model identifier with extensibility support."""

    family: str
    name: str
    reasoning: bool = True
    prompt_caching: bool = False

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"LLMModel('{self.name}')"
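Because `LLMModel` is a frozen dataclass keyed by `family`/`name`, callers can declare models that are not in `ModelRegistry`, as the comment in ollama_factory.py notes. A small sketch (the model name is hypothetical):

```python
from spaik_sdk.models.llm_model import LLMModel

local_model = LLMModel(family="ollama", name="qwen2.5-coder", reasoning=False)

print(local_model)        # qwen2.5-coder          (__str__ returns the name)
print(repr(local_model))  # LLMModel('qwen2.5-coder')

# frozen=True makes instances immutable and hashable, so they can serve as dict/set keys.
```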

spaik_sdk/models/llm_wrapper.py
@@ -0,0 +1,25 @@
from typing import Any, Dict, Optional

from langchain_core.language_models.chat_models import BaseChatModel

from spaik_sdk.models.llm_model import LLMModel


class LLMWrapper:
    def __init__(self, langchain_model: BaseChatModel, cache_control: Optional[Dict[str, Any]], model_type: LLMModel):
        """Initialize wrapper with pre-created langchain model and cache control."""
        self._langchain_model = langchain_model
        self._cache_control = cache_control
        self._model_type = model_type

    def get_langchain_model(self) -> BaseChatModel:
        """Get the underlying langchain model instance."""
        return self._langchain_model

    def get_cache_control(self) -> Optional[Dict[str, Any]]:
        """Get cache control settings for this model."""
        return self._cache_control

    def get_model_type(self) -> LLMModel:
        """Get the model type enum."""
        return self._model_type