spaik-sdk 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spaik_sdk/__init__.py +21 -0
- spaik_sdk/agent/__init__.py +0 -0
- spaik_sdk/agent/base_agent.py +249 -0
- spaik_sdk/attachments/__init__.py +22 -0
- spaik_sdk/attachments/builder.py +61 -0
- spaik_sdk/attachments/file_storage_provider.py +27 -0
- spaik_sdk/attachments/mime_types.py +118 -0
- spaik_sdk/attachments/models.py +63 -0
- spaik_sdk/attachments/provider_support.py +53 -0
- spaik_sdk/attachments/storage/__init__.py +0 -0
- spaik_sdk/attachments/storage/base_file_storage.py +32 -0
- spaik_sdk/attachments/storage/impl/__init__.py +0 -0
- spaik_sdk/attachments/storage/impl/local_file_storage.py +101 -0
- spaik_sdk/audio/__init__.py +12 -0
- spaik_sdk/audio/options.py +53 -0
- spaik_sdk/audio/providers/__init__.py +1 -0
- spaik_sdk/audio/providers/google_tts.py +77 -0
- spaik_sdk/audio/providers/openai_stt.py +71 -0
- spaik_sdk/audio/providers/openai_tts.py +111 -0
- spaik_sdk/audio/stt.py +61 -0
- spaik_sdk/audio/tts.py +124 -0
- spaik_sdk/config/credentials_provider.py +10 -0
- spaik_sdk/config/env.py +59 -0
- spaik_sdk/config/env_credentials_provider.py +7 -0
- spaik_sdk/config/get_credentials_provider.py +14 -0
- spaik_sdk/image_gen/__init__.py +9 -0
- spaik_sdk/image_gen/image_generator.py +83 -0
- spaik_sdk/image_gen/options.py +24 -0
- spaik_sdk/image_gen/providers/__init__.py +0 -0
- spaik_sdk/image_gen/providers/google.py +75 -0
- spaik_sdk/image_gen/providers/openai.py +60 -0
- spaik_sdk/llm/__init__.py +0 -0
- spaik_sdk/llm/cancellation_handle.py +10 -0
- spaik_sdk/llm/consumption/__init__.py +0 -0
- spaik_sdk/llm/consumption/consumption_estimate.py +26 -0
- spaik_sdk/llm/consumption/consumption_estimate_builder.py +113 -0
- spaik_sdk/llm/consumption/consumption_extractor.py +59 -0
- spaik_sdk/llm/consumption/token_usage.py +31 -0
- spaik_sdk/llm/converters.py +146 -0
- spaik_sdk/llm/cost/__init__.py +1 -0
- spaik_sdk/llm/cost/builtin_cost_provider.py +83 -0
- spaik_sdk/llm/cost/cost_estimate.py +8 -0
- spaik_sdk/llm/cost/cost_provider.py +28 -0
- spaik_sdk/llm/extract_error_message.py +37 -0
- spaik_sdk/llm/langchain_loop_manager.py +270 -0
- spaik_sdk/llm/langchain_service.py +196 -0
- spaik_sdk/llm/message_handler.py +188 -0
- spaik_sdk/llm/streaming/__init__.py +1 -0
- spaik_sdk/llm/streaming/block_manager.py +152 -0
- spaik_sdk/llm/streaming/models.py +42 -0
- spaik_sdk/llm/streaming/streaming_content_handler.py +157 -0
- spaik_sdk/llm/streaming/streaming_event_handler.py +215 -0
- spaik_sdk/llm/streaming/streaming_state_manager.py +58 -0
- spaik_sdk/models/__init__.py +0 -0
- spaik_sdk/models/factories/__init__.py +0 -0
- spaik_sdk/models/factories/anthropic_factory.py +33 -0
- spaik_sdk/models/factories/base_model_factory.py +71 -0
- spaik_sdk/models/factories/google_factory.py +30 -0
- spaik_sdk/models/factories/ollama_factory.py +41 -0
- spaik_sdk/models/factories/openai_factory.py +50 -0
- spaik_sdk/models/llm_config.py +46 -0
- spaik_sdk/models/llm_families.py +7 -0
- spaik_sdk/models/llm_model.py +17 -0
- spaik_sdk/models/llm_wrapper.py +25 -0
- spaik_sdk/models/model_registry.py +156 -0
- spaik_sdk/models/providers/__init__.py +0 -0
- spaik_sdk/models/providers/anthropic_provider.py +29 -0
- spaik_sdk/models/providers/azure_provider.py +31 -0
- spaik_sdk/models/providers/base_provider.py +62 -0
- spaik_sdk/models/providers/google_provider.py +26 -0
- spaik_sdk/models/providers/ollama_provider.py +26 -0
- spaik_sdk/models/providers/openai_provider.py +26 -0
- spaik_sdk/models/providers/provider_type.py +90 -0
- spaik_sdk/orchestration/__init__.py +24 -0
- spaik_sdk/orchestration/base_orchestrator.py +238 -0
- spaik_sdk/orchestration/checkpoint.py +80 -0
- spaik_sdk/orchestration/models.py +103 -0
- spaik_sdk/prompt/__init__.py +0 -0
- spaik_sdk/prompt/get_prompt_loader.py +13 -0
- spaik_sdk/prompt/local_prompt_loader.py +21 -0
- spaik_sdk/prompt/prompt_loader.py +48 -0
- spaik_sdk/prompt/prompt_loader_mode.py +14 -0
- spaik_sdk/py.typed +1 -0
- spaik_sdk/recording/__init__.py +1 -0
- spaik_sdk/recording/base_playback.py +90 -0
- spaik_sdk/recording/base_recorder.py +50 -0
- spaik_sdk/recording/conditional_recorder.py +38 -0
- spaik_sdk/recording/impl/__init__.py +1 -0
- spaik_sdk/recording/impl/local_playback.py +76 -0
- spaik_sdk/recording/impl/local_recorder.py +85 -0
- spaik_sdk/recording/langchain_serializer.py +88 -0
- spaik_sdk/server/__init__.py +1 -0
- spaik_sdk/server/api/routers/__init__.py +0 -0
- spaik_sdk/server/api/routers/api_builder.py +149 -0
- spaik_sdk/server/api/routers/audio_router_factory.py +201 -0
- spaik_sdk/server/api/routers/file_router_factory.py +111 -0
- spaik_sdk/server/api/routers/thread_router_factory.py +284 -0
- spaik_sdk/server/api/streaming/__init__.py +0 -0
- spaik_sdk/server/api/streaming/format_sse_event.py +41 -0
- spaik_sdk/server/api/streaming/negotiate_streaming_response.py +8 -0
- spaik_sdk/server/api/streaming/streaming_negotiator.py +10 -0
- spaik_sdk/server/authorization/__init__.py +0 -0
- spaik_sdk/server/authorization/base_authorizer.py +64 -0
- spaik_sdk/server/authorization/base_user.py +13 -0
- spaik_sdk/server/authorization/dummy_authorizer.py +17 -0
- spaik_sdk/server/job_processor/__init__.py +0 -0
- spaik_sdk/server/job_processor/base_job_processor.py +8 -0
- spaik_sdk/server/job_processor/thread_job_processor.py +32 -0
- spaik_sdk/server/pubsub/__init__.py +1 -0
- spaik_sdk/server/pubsub/cancellation_publisher.py +7 -0
- spaik_sdk/server/pubsub/cancellation_subscriber.py +38 -0
- spaik_sdk/server/pubsub/event_publisher.py +13 -0
- spaik_sdk/server/pubsub/impl/__init__.py +1 -0
- spaik_sdk/server/pubsub/impl/local_cancellation_pubsub.py +48 -0
- spaik_sdk/server/pubsub/impl/signalr_publisher.py +36 -0
- spaik_sdk/server/queue/__init__.py +1 -0
- spaik_sdk/server/queue/agent_job_queue.py +27 -0
- spaik_sdk/server/queue/impl/__init__.py +1 -0
- spaik_sdk/server/queue/impl/azure_queue.py +24 -0
- spaik_sdk/server/response/__init__.py +0 -0
- spaik_sdk/server/response/agent_response_generator.py +39 -0
- spaik_sdk/server/response/response_generator.py +13 -0
- spaik_sdk/server/response/simple_agent_response_generator.py +14 -0
- spaik_sdk/server/services/__init__.py +0 -0
- spaik_sdk/server/services/thread_converters.py +113 -0
- spaik_sdk/server/services/thread_models.py +90 -0
- spaik_sdk/server/services/thread_service.py +91 -0
- spaik_sdk/server/storage/__init__.py +1 -0
- spaik_sdk/server/storage/base_thread_repository.py +51 -0
- spaik_sdk/server/storage/impl/__init__.py +0 -0
- spaik_sdk/server/storage/impl/in_memory_thread_repository.py +100 -0
- spaik_sdk/server/storage/impl/local_file_thread_repository.py +217 -0
- spaik_sdk/server/storage/thread_filter.py +166 -0
- spaik_sdk/server/storage/thread_metadata.py +53 -0
- spaik_sdk/thread/__init__.py +0 -0
- spaik_sdk/thread/adapters/__init__.py +0 -0
- spaik_sdk/thread/adapters/cli/__init__.py +0 -0
- spaik_sdk/thread/adapters/cli/block_display.py +92 -0
- spaik_sdk/thread/adapters/cli/display_manager.py +84 -0
- spaik_sdk/thread/adapters/cli/live_cli.py +235 -0
- spaik_sdk/thread/adapters/event_adapter.py +28 -0
- spaik_sdk/thread/adapters/streaming_block_adapter.py +57 -0
- spaik_sdk/thread/adapters/sync_adapter.py +76 -0
- spaik_sdk/thread/models.py +224 -0
- spaik_sdk/thread/thread_container.py +468 -0
- spaik_sdk/tools/__init__.py +0 -0
- spaik_sdk/tools/impl/__init__.py +0 -0
- spaik_sdk/tools/impl/mcp_tool_provider.py +93 -0
- spaik_sdk/tools/impl/search_tool_provider.py +18 -0
- spaik_sdk/tools/tool_provider.py +131 -0
- spaik_sdk/tracing/__init__.py +13 -0
- spaik_sdk/tracing/agent_trace.py +72 -0
- spaik_sdk/tracing/get_trace_sink.py +15 -0
- spaik_sdk/tracing/local_trace_sink.py +23 -0
- spaik_sdk/tracing/trace_sink.py +19 -0
- spaik_sdk/tracing/trace_sink_mode.py +14 -0
- spaik_sdk/utils/__init__.py +0 -0
- spaik_sdk/utils/init_logger.py +24 -0
- spaik_sdk-0.6.2.dist-info/METADATA +379 -0
- spaik_sdk-0.6.2.dist-info/RECORD +161 -0
- spaik_sdk-0.6.2.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,468 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import time
|
|
3
|
+
import uuid
|
|
4
|
+
from typing import Callable, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from langchain_core.messages import BaseMessage, SystemMessage
|
|
7
|
+
|
|
8
|
+
from spaik_sdk.attachments.storage.base_file_storage import BaseFileStorage
|
|
9
|
+
from spaik_sdk.llm.consumption.token_usage import TokenUsage
|
|
10
|
+
from spaik_sdk.llm.converters import convert_thread_message_to_langchain, convert_thread_message_to_langchain_multimodal
|
|
11
|
+
from spaik_sdk.thread.models import (
|
|
12
|
+
BlockAddedEvent,
|
|
13
|
+
BlockFullyAddedEvent,
|
|
14
|
+
MessageAddedEvent,
|
|
15
|
+
MessageBlock,
|
|
16
|
+
MessageBlockType,
|
|
17
|
+
MessageFullyAddedEvent,
|
|
18
|
+
StreamingEndedEvent,
|
|
19
|
+
StreamingUpdatedEvent,
|
|
20
|
+
ThreadEvent,
|
|
21
|
+
ThreadMessage,
|
|
22
|
+
ToolCallResponse,
|
|
23
|
+
ToolCallStartedEvent,
|
|
24
|
+
ToolResponseReceivedEvent,
|
|
25
|
+
)
|
|
26
|
+
from spaik_sdk.utils.init_logger import init_logger
|
|
27
|
+
|
|
28
|
+
logger = init_logger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ThreadContainer:
|
|
32
|
+
def __init__(self, system_prompt: Optional[str] = None):
|
|
33
|
+
self.messages: List[ThreadMessage] = []
|
|
34
|
+
self.streaming_content: Dict[str, str] = {}
|
|
35
|
+
self.tool_call_responses: Dict[str, ToolCallResponse] = {}
|
|
36
|
+
self.system_prompt = system_prompt
|
|
37
|
+
# Single event stream with multiple subscribers
|
|
38
|
+
self._subscribers: List[Callable[[ThreadEvent], None]] = []
|
|
39
|
+
|
|
40
|
+
# Version tracking
|
|
41
|
+
self._version = 0
|
|
42
|
+
self._last_activity_time = int(time.time() * 1000)
|
|
43
|
+
self.thread_id = str(uuid.uuid4())
|
|
44
|
+
self.job_id = "unknown"
|
|
45
|
+
|
|
46
|
+
def subscribe(self, callback: Callable[[ThreadEvent], None]) -> None:
|
|
47
|
+
"""Subscribe to the event stream"""
|
|
48
|
+
if callback not in self._subscribers:
|
|
49
|
+
self._subscribers.append(callback)
|
|
50
|
+
|
|
51
|
+
def unsubscribe(self, callback: Callable[[ThreadEvent], None]) -> None:
|
|
52
|
+
"""Unsubscribe from the event stream"""
|
|
53
|
+
if callback in self._subscribers:
|
|
54
|
+
self._subscribers.remove(callback)
|
|
55
|
+
|
|
56
|
+
def _emit_event(self, event: ThreadEvent) -> None:
|
|
57
|
+
"""Emit a typed event to all subscribers"""
|
|
58
|
+
for callback in self._subscribers:
|
|
59
|
+
try:
|
|
60
|
+
callback(event)
|
|
61
|
+
except Exception as e:
|
|
62
|
+
logger.error(f"Event callback error: {e}")
|
|
63
|
+
|
|
64
|
+
def _increment_version(self) -> None:
|
|
65
|
+
"""Increment version and update activity time"""
|
|
66
|
+
self._version += 1
|
|
67
|
+
self._last_activity_time = int(time.time() * 1000)
|
|
68
|
+
|
|
69
|
+
def get_version(self) -> int:
    """Get current version for incremental updates.

    Returns:
        The counter that ``_increment_version`` advances by one on each call.
    """
    return self._version
|
|
72
|
+
|
|
73
|
+
def get_last_activity_time(self) -> int:
    """Get timestamp of last activity.

    Returns:
        Epoch time in milliseconds, as recorded at construction or by the
        most recent ``_increment_version`` call.
    """
    return self._last_activity_time
|
|
76
|
+
|
|
77
|
+
def add_streaming_message_chunk(self, block_id: str, content: str) -> None:
    """Append a streamed chunk to the accumulated content of ``block_id``.

    Starts a new accumulation when the block has no content yet, then emits
    a ``StreamingUpdatedEvent`` carrying both the chunk and the running total,
    and bumps the version.
    """
    accumulated = self.streaming_content.get(block_id, "") + content
    self.streaming_content[block_id] = accumulated

    # Emit streaming update event
    update = StreamingUpdatedEvent(block_id=block_id, content=content, total_content=accumulated)
    self._emit_event(update)

    self._increment_version()
|
|
88
|
+
|
|
89
|
+
def add_message(self, msg: ThreadMessage) -> None:
    """Append ``msg`` to the thread and announce it.

    Subscribers receive a ``MessageAddedEvent`` holding a deep copy of the
    message; the version is bumped afterwards.
    """
    self.messages.append(msg)
    snapshot = copy.deepcopy(msg)
    self._emit_event(MessageAddedEvent(message=snapshot))
    self._increment_version()
|
|
94
|
+
|
|
95
|
+
def add_message_block(self, message_id: str, block: MessageBlock) -> None:
|
|
96
|
+
"""Add a message block to an existing message by message_id"""
|
|
97
|
+
for message in self.messages:
|
|
98
|
+
if message.id == message_id:
|
|
99
|
+
message.blocks.append(block)
|
|
100
|
+
|
|
101
|
+
# Emit block added event
|
|
102
|
+
self._emit_event(BlockAddedEvent(message_id=message_id, block_id=block.id, block=copy.deepcopy(block)))
|
|
103
|
+
|
|
104
|
+
# If it's a tool block, emit tool call started
|
|
105
|
+
if block.type == MessageBlockType.TOOL_USE and block.tool_call_id:
|
|
106
|
+
tool_name = block.tool_name or "unknown"
|
|
107
|
+
|
|
108
|
+
self._emit_event(
|
|
109
|
+
ToolCallStartedEvent(tool_call_id=block.tool_call_id, tool_name=tool_name, message_id=message_id, block_id=block.id)
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
self._increment_version()
|
|
113
|
+
break
|
|
114
|
+
|
|
115
|
+
def add_tool_call_response(self, response: ToolCallResponse) -> None:
    """Record a tool call response and attach it to its tool-use block.

    The response is stored under ``response.id``. The first block per message
    whose ``tool_call_id`` matches is marked as no longer streaming and
    receives the response and error; the search stops once a matched block
    has a truthy ID. A ``ToolResponseReceivedEvent`` is always emitted
    (``block_id`` is ``None`` when nothing matched) and the version is bumped.
    """
    self.tool_call_responses[response.id] = response

    # Find the corresponding block ID
    matched_block_id = None
    for message in self.messages:
        owner = next((blk for blk in message.blocks if blk.tool_call_id == response.id), None)
        if owner is not None:
            owner.streaming = False
            owner.tool_call_response = response.response
            owner.tool_call_error = response.error
            matched_block_id = owner.id
        if matched_block_id:
            break

    # Emit tool response event
    received = ToolResponseReceivedEvent(
        tool_call_id=response.id,
        response=response.response,
        error=response.error,
        block_id=matched_block_id,
    )
    self._emit_event(received)

    self._increment_version()
|
|
138
|
+
|
|
139
|
+
def update_tool_use_block_with_response(self, tool_call_id: str, response: str, error: Optional[str] = None) -> None:
    """Store a tool result and mark the matching tool-use block as finished.

    Builds a ``ToolCallResponse`` and hands it to ``add_tool_call_response``,
    then clears the ``streaming`` flag on the first block (in message order)
    whose ``tool_call_id`` matches and bumps the version again. When no block
    matches, only a debug line is logged.
    """
    logger.debug(f"🔧 DEBUG: Updating tool response for {tool_call_id}")
    logger.debug(f"🔧 DEBUG: Response: {response[:100]}...")
    logger.debug(f"🔧 DEBUG: Error: {error}")

    # Add the tool response to our responses dict
    self.add_tool_call_response(ToolCallResponse(id=tool_call_id, response=response, error=error))

    # Find the message and block with this tool_call_id and mark it as non-streaming
    every_block = (blk for message in self.messages for blk in message.blocks)
    match = next((blk for blk in every_block if blk.tool_call_id == tool_call_id), None)
    if match is None:
        logger.debug(f"🔧 DEBUG: Could not find block for tool_call_id: {tool_call_id}")
        return

    match.streaming = False
    logger.debug(f"🔧 DEBUG: Found and updated block {match.id} for tool {tool_call_id}")
    self._increment_version()
|
|
159
|
+
|
|
160
|
+
def add_error_message(self, error_text: str, author_id: str = "system", author_name: str = "system") -> str:
    """Append an AI-flagged message holding a single ERROR block.

    Args:
        error_text: Text stored as the error block's content.
        author_id: Author ID recorded on the message.
        author_name: Author name recorded on the message.

    Returns:
        The generated ID of the new message.
    """
    message_id = str(uuid.uuid4())
    error_block = MessageBlock(
        id=str(uuid.uuid4()),
        streaming=False,
        type=MessageBlockType.ERROR,
        content=error_text,
    )
    self.add_message(
        ThreadMessage(
            id=message_id,
            ai=True,
            author_id=author_id,
            author_name=author_name,
            timestamp=int(time.time() * 1000),
            blocks=[error_block],
        )
    )
    return message_id
|
|
180
|
+
|
|
181
|
+
def finalize_streaming_blocks(self, message_id: str, block_ids: List[str]) -> None:
    """Mark specified blocks as non-streaming (completed).

    Only the first message whose ID equals ``message_id`` is examined. Each of
    its blocks that is listed in ``block_ids`` and still streaming has its
    flag cleared, takes over the accumulated text from ``streaming_content``
    when present, and triggers a ``BlockFullyAddedEvent``.

    Args:
        message_id: ID of the message that owns the blocks.
        block_ids: IDs of the blocks to finalize.
    """
    # NOTE(review): nesting was reconstructed from a listing that lost its
    # indentation; confirm that the two trailing checks sit at method level.
    completed_blocks = []

    for message in self.messages:
        if message.id == message_id:
            for block in message.blocks:
                if block.id in block_ids and block.streaming:
                    block.streaming = False
                    # Move content from streaming_content to block.content when streaming finishes
                    if block.id in self.streaming_content:
                        block.content = self.streaming_content[block.id]
                        # lets not do this, access is needed at least for now
                        # # Optionally remove from streaming_content to save memory
                        # del self.streaming_content[block.id]
                    completed_blocks.append(block.id)

                    # Emit block fully added event for each completed block
                    self._emit_event(BlockFullyAddedEvent(block_id=block.id, message_id=message_id, block=copy.deepcopy(block)))
            break

    # Only bump the version when at least one block actually changed state.
    if completed_blocks:
        self._increment_version()

    # Check if streaming has ended
    if not self.is_streaming_active():
        self._emit_event(StreamingEndedEvent(message_id=message_id, completed_blocks=completed_blocks))
|
|
208
|
+
|
|
209
|
+
def cancel_generation(self) -> None:
    """Cancel the generation.

    Clears the ``streaming`` flag on every block of every message, copies any
    accumulated streaming text into the block, and emits a
    ``BlockFullyAddedEvent`` per block.
    """
    # NOTE(review): nesting was reconstructed from a listing that lost its
    # indentation; confirm that the trailing streaming-ended check belongs
    # inside the per-message loop (it reads the loop variables ``message``
    # and ``completed_blocks``).
    logger.info(f"Cancelling generation. Current streaming content: {self.streaming_content}")
    logger.info(f"Messages: {self.messages}")
    for message in self.messages:
        completed_blocks = []
        for block in message.blocks:
            block.streaming = False
            # Move content from streaming_content to block.content when streaming finishes
            if block.id in self.streaming_content:
                block.content = self.streaming_content[block.id]
                logger.info(f"🔧 Block {block.id} content: {block.content}")
            completed_blocks.append(block.id)
            self._emit_event(BlockFullyAddedEvent(block_id=block.id, message_id=message.id, block=copy.deepcopy(block)))
        # Every block above already has streaming=False, so this call does not
        # re-finalize them; it still runs its own streaming-ended check.
        self.finalize_streaming_blocks(message.id, completed_blocks)
        if completed_blocks:
            self._increment_version()

        # Check if streaming has ended
        if not self.is_streaming_active():
            self._emit_event(StreamingEndedEvent(message_id=message.id, completed_blocks=completed_blocks))
|
|
230
|
+
|
|
231
|
+
def complete_generation(self) -> None:
|
|
232
|
+
"""Mark the message as fully added and emit the event"""
|
|
233
|
+
latest_message = self.get_latest_ai_message()
|
|
234
|
+
if latest_message:
|
|
235
|
+
self._emit_event(MessageFullyAddedEvent(message=copy.deepcopy(latest_message)))
|
|
236
|
+
|
|
237
|
+
def is_streaming_active(self) -> bool:
|
|
238
|
+
"""Check if any blocks are currently streaming"""
|
|
239
|
+
for message in self.messages:
|
|
240
|
+
for block in message.blocks:
|
|
241
|
+
if block.streaming:
|
|
242
|
+
return True
|
|
243
|
+
return False
|
|
244
|
+
|
|
245
|
+
def get_latest_ai_message(self) -> Optional[ThreadMessage]:
|
|
246
|
+
"""Get the most recent AI message"""
|
|
247
|
+
for message in reversed(self.messages):
|
|
248
|
+
if message.ai:
|
|
249
|
+
return message
|
|
250
|
+
return None
|
|
251
|
+
|
|
252
|
+
def get_latest_message(self) -> ThreadMessage:
    """Get the most recent message.

    Raises:
        IndexError: If the thread contains no messages.
    """
    return self.messages[-1]
|
|
255
|
+
|
|
256
|
+
def get_message_by_id(self, message_id: str) -> Optional[ThreadMessage]:
|
|
257
|
+
"""Get message by ID"""
|
|
258
|
+
for message in self.messages:
|
|
259
|
+
if message.id == message_id:
|
|
260
|
+
return message
|
|
261
|
+
return None
|
|
262
|
+
|
|
263
|
+
def get_block_content(self, block: MessageBlock) -> str:
|
|
264
|
+
"""Get content for a specific block"""
|
|
265
|
+
# First check if block has content directly
|
|
266
|
+
if block.content is not None:
|
|
267
|
+
return block.content
|
|
268
|
+
|
|
269
|
+
# For streaming blocks, check streaming_content
|
|
270
|
+
if block.id in self.streaming_content:
|
|
271
|
+
return self.streaming_content[block.id]
|
|
272
|
+
|
|
273
|
+
if block.type == MessageBlockType.TOOL_USE:
|
|
274
|
+
if block.tool_call_id and block.tool_call_id in self.tool_call_responses:
|
|
275
|
+
response = self.tool_call_responses[block.tool_call_id]
|
|
276
|
+
return response.response
|
|
277
|
+
|
|
278
|
+
return ""
|
|
279
|
+
|
|
280
|
+
def get_system_prompt(self) -> str:
|
|
281
|
+
if self.system_prompt is None:
|
|
282
|
+
raise ValueError("System prompt is not set")
|
|
283
|
+
return self.system_prompt
|
|
284
|
+
|
|
285
|
+
def get_streaming_blocks(self) -> List[str]:
|
|
286
|
+
"""Get list of currently streaming block IDs"""
|
|
287
|
+
streaming_blocks = []
|
|
288
|
+
for message in self.messages:
|
|
289
|
+
for block in message.blocks:
|
|
290
|
+
if block.streaming:
|
|
291
|
+
streaming_blocks.append(block.id)
|
|
292
|
+
return streaming_blocks
|
|
293
|
+
|
|
294
|
+
def has_errors(self) -> bool:
|
|
295
|
+
"""Check if there are any error blocks or tool errors"""
|
|
296
|
+
for message in self.messages:
|
|
297
|
+
for block in message.blocks:
|
|
298
|
+
if block.type == MessageBlockType.ERROR:
|
|
299
|
+
return True
|
|
300
|
+
|
|
301
|
+
for response in self.tool_call_responses.values():
|
|
302
|
+
if response.error:
|
|
303
|
+
return True
|
|
304
|
+
|
|
305
|
+
return False
|
|
306
|
+
|
|
307
|
+
def get_final_text_content(self) -> str:
|
|
308
|
+
"""Get clean final text from the latest AI message"""
|
|
309
|
+
latest_message = self.get_latest_ai_message()
|
|
310
|
+
if not latest_message:
|
|
311
|
+
return ""
|
|
312
|
+
|
|
313
|
+
text_parts = []
|
|
314
|
+
for block in latest_message.blocks:
|
|
315
|
+
if block.type == MessageBlockType.PLAIN and not block.streaming:
|
|
316
|
+
content = self.get_block_content(block)
|
|
317
|
+
if content:
|
|
318
|
+
text_parts.append(content)
|
|
319
|
+
|
|
320
|
+
return " ".join(text_parts).strip()
|
|
321
|
+
|
|
322
|
+
def _find_message_id_by_block(self, block_id: str) -> Optional[str]:
|
|
323
|
+
"""Find message ID that contains the given block ID"""
|
|
324
|
+
for message in self.messages:
|
|
325
|
+
for block in message.blocks:
|
|
326
|
+
if block.id == block_id:
|
|
327
|
+
return message.id
|
|
328
|
+
return None
|
|
329
|
+
|
|
330
|
+
def get_langchain_messages(self) -> List[BaseMessage]:
    """Convert the thread into LangChain messages, system prompt first.

    Raises:
        ValueError: If no system prompt is set (raised by ``get_system_prompt``).
    """
    system_message = SystemMessage(content=self.get_system_prompt())
    history = [convert_thread_message_to_langchain(thread_message) for thread_message in self.messages]
    return [system_message, *history]
|
|
335
|
+
|
|
336
|
+
async def get_langchain_messages_multimodal(self, file_storage: BaseFileStorage, provider_family: str = "openai") -> List[BaseMessage]:
    """Convert the thread into LangChain messages using the multimodal converter.

    The system prompt comes first; thread messages are converted one at a
    time, in order, each conversion awaited before the next starts.

    Args:
        file_storage: Storage handed to the converter for each message.
        provider_family: Provider family handed to the converter.

    Raises:
        ValueError: If no system prompt is set (raised by ``get_system_prompt``).
    """
    converted: List[BaseMessage] = [SystemMessage(content=self.get_system_prompt())]
    converted += [
        await convert_thread_message_to_langchain_multimodal(thread_message, file_storage, provider_family)
        for thread_message in self.messages
    ]
    return converted
|
|
343
|
+
|
|
344
|
+
def get_nof_messages_including_system(self) -> int:
    """Get number of messages including system message.

    The system message is always counted as one, whether or not a system
    prompt has actually been set.
    """
    return len(self.messages) + 1
|
|
347
|
+
|
|
348
|
+
def add_consumption_metadata(self, message_id: str, consumption_metadata: TokenUsage) -> None:
|
|
349
|
+
"""Add consumption metadata to a specific message"""
|
|
350
|
+
for message in self.messages:
|
|
351
|
+
if message.id == message_id:
|
|
352
|
+
message.consumption_metadata = consumption_metadata
|
|
353
|
+
self._increment_version()
|
|
354
|
+
break
|
|
355
|
+
|
|
356
|
+
def get_total_consumption(self) -> TokenUsage:
    """Sum the token counters of every message that carries a ``TokenUsage``.

    Returns:
        A new ``TokenUsage`` holding the per-counter totals; messages without
        usage metadata (or with metadata of another type) are ignored.
    """
    counters = (
        "input_tokens",
        "output_tokens",
        "total_tokens",
        "reasoning_tokens",
        "cache_creation_tokens",
        "cache_read_tokens",
    )
    running_total = TokenUsage()

    for message in self.messages:
        usage = message.consumption_metadata
        if not usage or not isinstance(usage, TokenUsage):
            continue
        for counter in counters:
            setattr(running_total, counter, getattr(running_total, counter) + getattr(usage, counter))

    return running_total
|
|
371
|
+
|
|
372
|
+
def get_consumption_by_message(self, message_id: str) -> Optional[TokenUsage]:
|
|
373
|
+
"""Get consumption metadata for a specific message"""
|
|
374
|
+
for message in self.messages:
|
|
375
|
+
if message.id == message_id and message.consumption_metadata:
|
|
376
|
+
return message.consumption_metadata
|
|
377
|
+
return None
|
|
378
|
+
|
|
379
|
+
def get_latest_token_usage(self) -> Optional[TokenUsage]:
|
|
380
|
+
"""Get consumption metadata for the latest message"""
|
|
381
|
+
latest_message = self.get_latest_ai_message()
|
|
382
|
+
if latest_message and latest_message.consumption_metadata:
|
|
383
|
+
return latest_message.consumption_metadata
|
|
384
|
+
return None
|
|
385
|
+
|
|
386
|
+
def __str__(self) -> str:
|
|
387
|
+
"""String representation of the entire thread container"""
|
|
388
|
+
lines = ["=== THREAD CONTAINER ==="]
|
|
389
|
+
lines.append(f"Version: {self._version} | Active streaming: {self.is_streaming_active()}")
|
|
390
|
+
|
|
391
|
+
lines.append(f"\n📨 MESSAGES ({len(self.messages)}):")
|
|
392
|
+
for i, msg in enumerate(self.messages):
|
|
393
|
+
author = "🤖 AI" if msg.ai else f"👤 {msg.author_id}"
|
|
394
|
+
lines.append(f" [{i}] {author} | {msg.id} | {msg.timestamp}")
|
|
395
|
+
for j, block in enumerate(msg.blocks):
|
|
396
|
+
if block.type == MessageBlockType.ERROR:
|
|
397
|
+
stream_indicator = "❌"
|
|
398
|
+
elif block.type == MessageBlockType.REASONING:
|
|
399
|
+
stream_indicator = "🧠" if not block.streaming else "🤔"
|
|
400
|
+
elif block.streaming:
|
|
401
|
+
stream_indicator = "🌊"
|
|
402
|
+
else:
|
|
403
|
+
stream_indicator = "✅"
|
|
404
|
+
tool_info = f" | tool_id: {block.tool_call_id}" if block.tool_call_id else ""
|
|
405
|
+
|
|
406
|
+
# Get content preview for the block
|
|
407
|
+
content = self.get_block_content(block)
|
|
408
|
+
content_preview = content[:50] + "..." if len(content) > 50 else content
|
|
409
|
+
content_info = f" | {repr(content_preview)}" if content else " (no content)"
|
|
410
|
+
|
|
411
|
+
lines.append(f" [{j}] {stream_indicator} {block.type.value} | {block.id}{tool_info}{content_info}")
|
|
412
|
+
|
|
413
|
+
lines.append(f"\n🌊 STREAMING CONTENT ({len(self.streaming_content)}):")
|
|
414
|
+
for block_id, content in self.streaming_content.items():
|
|
415
|
+
preview = content[:50] + "..." if len(content) > 50 else content
|
|
416
|
+
lines.append(f" {block_id}: {repr(preview)}")
|
|
417
|
+
|
|
418
|
+
lines.append(f"\n🔧 TOOL CALL RESPONSES ({len(self.tool_call_responses)}):")
|
|
419
|
+
for tool_id, response in self.tool_call_responses.items():
|
|
420
|
+
error_indicator = "❌" if response.error else "✅"
|
|
421
|
+
preview = response.response[:50] + "..." if len(response.response) > 50 else response.response
|
|
422
|
+
lines.append(f" {error_indicator} {tool_id}: {repr(preview)}")
|
|
423
|
+
if response.error:
|
|
424
|
+
lines.append(f" Error: {response.error}")
|
|
425
|
+
|
|
426
|
+
# Add consumption summary
|
|
427
|
+
total_consumption = self.get_total_consumption()
|
|
428
|
+
consumption_messages = sum(1 for msg in self.messages if msg.consumption_metadata)
|
|
429
|
+
total_messages = len(self.messages)
|
|
430
|
+
lines.append(f"\n📊 CONSUMPTION SUMMARY ({consumption_messages}/{total_messages} messages):")
|
|
431
|
+
|
|
432
|
+
if consumption_messages > 0:
|
|
433
|
+
lines.append(f" Total tokens: {total_consumption.total_tokens:,}")
|
|
434
|
+
lines.append(f" Input: {total_consumption.input_tokens:,} | Output: {total_consumption.output_tokens:,}")
|
|
435
|
+
if total_consumption.reasoning_tokens > 0:
|
|
436
|
+
lines.append(f" Reasoning: {total_consumption.reasoning_tokens:,}")
|
|
437
|
+
if total_consumption.cache_creation_tokens > 0 or total_consumption.cache_read_tokens > 0:
|
|
438
|
+
lines.append(
|
|
439
|
+
f" Cache create: {total_consumption.cache_creation_tokens:,} | Cache read: {total_consumption.cache_read_tokens:,}"
|
|
440
|
+
)
|
|
441
|
+
else:
|
|
442
|
+
lines.append(" No consumption data available")
|
|
443
|
+
|
|
444
|
+
return "\n".join(lines)
|
|
445
|
+
|
|
446
|
+
def print_all(self) -> None:
    """Print everything to console for debugging.

    Writes the ``__str__`` rendering of this container to standard output.
    """
    print(str(self))
|
|
449
|
+
|
|
450
|
+
def create_serializable_copy(self) -> "ThreadContainer":
|
|
451
|
+
"""Create a copy of this ThreadContainer that can be safely pickled"""
|
|
452
|
+
# Create new instance without calling __init__ to avoid subscriber initialization
|
|
453
|
+
copy = ThreadContainer.__new__(ThreadContainer)
|
|
454
|
+
|
|
455
|
+
# Copy all serializable attributes
|
|
456
|
+
copy.messages = self.messages.copy()
|
|
457
|
+
copy.streaming_content = self.streaming_content.copy()
|
|
458
|
+
copy.tool_call_responses = self.tool_call_responses.copy()
|
|
459
|
+
copy.system_prompt = self.system_prompt
|
|
460
|
+
copy._version = self._version
|
|
461
|
+
copy._last_activity_time = self._last_activity_time
|
|
462
|
+
copy.thread_id = self.thread_id
|
|
463
|
+
copy.job_id = self.job_id
|
|
464
|
+
|
|
465
|
+
# Initialize empty subscribers list (will be empty when loaded)
|
|
466
|
+
copy._subscribers = []
|
|
467
|
+
|
|
468
|
+
return copy
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from typing import Dict, List, Literal, TypedDict, Union
|
|
2
|
+
|
|
3
|
+
from langchain_core.tools import BaseTool
|
|
4
|
+
from langchain_mcp_adapters.client import MultiServerMCPClient
|
|
5
|
+
from langchain_mcp_adapters.tools import load_mcp_tools
|
|
6
|
+
from mcp import ClientSession, StdioServerParameters
|
|
7
|
+
from mcp.client.stdio import stdio_client
|
|
8
|
+
|
|
9
|
+
from spaik_sdk.tools.tool_provider import ToolProvider
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class StdioServerConfig(TypedDict, total=False):
    """Configuration dictionary for an MCP server reached over the stdio transport.

    Declared with ``total=False``, so every key is optional for type checkers.
    """

    transport: Literal["stdio"]  # transport discriminator, always "stdio"
    command: str  # command used to start the server
    args: List[str]  # arguments passed to the command
    env: Dict[str, str]  # environment variables for the server process
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class HttpServerConfig(TypedDict, total=False):
    """Configuration dictionary for an MCP server reached over the http transport.

    Declared with ``total=False``, so every key is optional for type checkers.
    """

    transport: Literal["http"]  # transport discriminator, always "http"
    url: str  # server URL
    headers: Dict[str, str]  # headers to send with requests
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
McpServerConfig = Union[StdioServerConfig, HttpServerConfig]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class McpToolProvider(ToolProvider):
    """Tool provider backed by one or more MCP servers.

    Tools are fetched through ``MultiServerMCPClient`` on the first
    ``load_tools`` call and cached on the instance. ``get_tools`` is
    synchronous and therefore requires the tools to have been loaded (or
    passed to the constructor) beforehand.
    """

    def __init__(
        self,
        servers: Dict[str, McpServerConfig],
        tools: List[BaseTool] | None = None,
    ):
        """Store the server configuration and optional pre-loaded tools.

        Args:
            servers: Mapping of server name to its connection configuration.
            tools: Already-loaded tools; when given, ``load_tools`` returns
                them without contacting any server.
        """
        self._servers = servers
        self._tools = tools

    async def load_tools(self) -> List[BaseTool]:
        """Fetch the tools from the configured servers, once, and cache them."""
        if self._tools is not None:
            return self._tools

        client = MultiServerMCPClient(self._servers)  # type: ignore[arg-type]
        self._tools = await client.get_tools()
        return self._tools

    def get_tools(self) -> List[BaseTool]:
        """Return the cached tools.

        Raises:
            RuntimeError: If the tools have not been loaded yet.
        """
        if self._tools is None:
            raise RuntimeError(
                "MCP tools not loaded. Call `await provider.load_tools()` first, or use `McpToolProvider.create()` async factory method."
            )
        return self._tools

    @classmethod
    async def create(cls, servers: Dict[str, McpServerConfig]) -> "McpToolProvider":
        """Async factory: build a provider and load its tools before returning it."""
        provider = cls(servers)
        await provider.load_tools()
        return provider

    @staticmethod
    async def load_from_stdio(
        command: str,
        args: List[str] | None = None,
        env: Dict[str, str] | None = None,
    ) -> List[BaseTool]:
        """Load the tools of a single stdio MCP server.

        The previous implementation opened a ``ClientSession`` inside an
        ``async with`` block and returned tools bound to that session, so the
        session was already closed by the time the caller could invoke them.
        The tools are now obtained through ``MultiServerMCPClient``, the same
        path ``load_tools`` uses.

        Args:
            command: Command used to start the server.
            args: Arguments passed to the command (defaults to none).
            env: Environment variables for the server process; omitted from
                the connection when ``None``.
        """
        # NOTE(review): relies on MultiServerMCPClient.get_tools() returning
        # tools that open their own session per invocation - confirm against
        # the installed langchain-mcp-adapters version.
        connection: StdioServerConfig = {"transport": "stdio", "command": command, "args": args or []}
        if env is not None:
            connection["env"] = env

        client = MultiServerMCPClient({"stdio": connection})  # type: ignore[arg-type]
        return await client.get_tools()
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def mcp_server_stdio(
    command: str,
    args: List[str] | None = None,
    env: Dict[str, str] | None = None,
) -> StdioServerConfig:
    """Build a stdio server configuration dictionary.

    Args:
        command: Command used to start the server.
        args: Arguments for the command; a missing or empty value becomes ``[]``.
        env: Environment variables; the ``env`` key is only set when this is
            non-empty.
    """
    resolved_args = args or []
    config: StdioServerConfig = {"transport": "stdio", "command": command, "args": resolved_args}
    if not env:
        return config
    config["env"] = env
    return config
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def mcp_server_http(
    url: str,
    headers: Dict[str, str] | None = None,
) -> HttpServerConfig:
    """Build an http server configuration dictionary.

    Args:
        url: Server URL.
        headers: Request headers; the ``headers`` key is only set when this is
            non-empty.
    """
    config: HttpServerConfig = {"transport": "http", "url": url}
    if not headers:
        return config
    config["headers"] = headers
    return config
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from langchain_tavily import TavilySearch
|
|
4
|
+
|
|
5
|
+
from spaik_sdk.config.get_credentials_provider import credentials_provider
|
|
6
|
+
from spaik_sdk.tools.tool_provider import ToolProvider
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SearchToolProvider(ToolProvider):
    """Tool provider exposing a single ``TavilySearch`` tool."""

    def __init__(self, max_results: int = 10) -> None:
        """Export the Tavily API key, then create the search tool.

        Args:
            max_results: Forwarded to ``TavilySearch`` as ``max_results``.
        """
        # The environment variable must be set before the tool is constructed.
        self._init_env()
        self.tool = TavilySearch(max_results=max_results)

    def _init_env(self) -> None:
        """Copy the "tavily" provider key into the ``TAVILY_API_KEY`` environment variable."""
        # NOTE(review): presumably TavilySearch reads TAVILY_API_KEY from the
        # environment - confirm. This overwrites any value already set for the
        # whole process.
        os.environ["TAVILY_API_KEY"] = credentials_provider.get_provider_key("tavily")

    def get_tools(self) -> list:
        """Return a new list containing the single search tool."""
        return [self.tool]
|