spaik_sdk-0.6.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. spaik_sdk/__init__.py +21 -0
  2. spaik_sdk/agent/__init__.py +0 -0
  3. spaik_sdk/agent/base_agent.py +249 -0
  4. spaik_sdk/attachments/__init__.py +22 -0
  5. spaik_sdk/attachments/builder.py +61 -0
  6. spaik_sdk/attachments/file_storage_provider.py +27 -0
  7. spaik_sdk/attachments/mime_types.py +118 -0
  8. spaik_sdk/attachments/models.py +63 -0
  9. spaik_sdk/attachments/provider_support.py +53 -0
  10. spaik_sdk/attachments/storage/__init__.py +0 -0
  11. spaik_sdk/attachments/storage/base_file_storage.py +32 -0
  12. spaik_sdk/attachments/storage/impl/__init__.py +0 -0
  13. spaik_sdk/attachments/storage/impl/local_file_storage.py +101 -0
  14. spaik_sdk/audio/__init__.py +12 -0
  15. spaik_sdk/audio/options.py +53 -0
  16. spaik_sdk/audio/providers/__init__.py +1 -0
  17. spaik_sdk/audio/providers/google_tts.py +77 -0
  18. spaik_sdk/audio/providers/openai_stt.py +71 -0
  19. spaik_sdk/audio/providers/openai_tts.py +111 -0
  20. spaik_sdk/audio/stt.py +61 -0
  21. spaik_sdk/audio/tts.py +124 -0
  22. spaik_sdk/config/credentials_provider.py +10 -0
  23. spaik_sdk/config/env.py +59 -0
  24. spaik_sdk/config/env_credentials_provider.py +7 -0
  25. spaik_sdk/config/get_credentials_provider.py +14 -0
  26. spaik_sdk/image_gen/__init__.py +9 -0
  27. spaik_sdk/image_gen/image_generator.py +83 -0
  28. spaik_sdk/image_gen/options.py +24 -0
  29. spaik_sdk/image_gen/providers/__init__.py +0 -0
  30. spaik_sdk/image_gen/providers/google.py +75 -0
  31. spaik_sdk/image_gen/providers/openai.py +60 -0
  32. spaik_sdk/llm/__init__.py +0 -0
  33. spaik_sdk/llm/cancellation_handle.py +10 -0
  34. spaik_sdk/llm/consumption/__init__.py +0 -0
  35. spaik_sdk/llm/consumption/consumption_estimate.py +26 -0
  36. spaik_sdk/llm/consumption/consumption_estimate_builder.py +113 -0
  37. spaik_sdk/llm/consumption/consumption_extractor.py +59 -0
  38. spaik_sdk/llm/consumption/token_usage.py +31 -0
  39. spaik_sdk/llm/converters.py +146 -0
  40. spaik_sdk/llm/cost/__init__.py +1 -0
  41. spaik_sdk/llm/cost/builtin_cost_provider.py +83 -0
  42. spaik_sdk/llm/cost/cost_estimate.py +8 -0
  43. spaik_sdk/llm/cost/cost_provider.py +28 -0
  44. spaik_sdk/llm/extract_error_message.py +37 -0
  45. spaik_sdk/llm/langchain_loop_manager.py +270 -0
  46. spaik_sdk/llm/langchain_service.py +196 -0
  47. spaik_sdk/llm/message_handler.py +188 -0
  48. spaik_sdk/llm/streaming/__init__.py +1 -0
  49. spaik_sdk/llm/streaming/block_manager.py +152 -0
  50. spaik_sdk/llm/streaming/models.py +42 -0
  51. spaik_sdk/llm/streaming/streaming_content_handler.py +157 -0
  52. spaik_sdk/llm/streaming/streaming_event_handler.py +215 -0
  53. spaik_sdk/llm/streaming/streaming_state_manager.py +58 -0
  54. spaik_sdk/models/__init__.py +0 -0
  55. spaik_sdk/models/factories/__init__.py +0 -0
  56. spaik_sdk/models/factories/anthropic_factory.py +33 -0
  57. spaik_sdk/models/factories/base_model_factory.py +71 -0
  58. spaik_sdk/models/factories/google_factory.py +30 -0
  59. spaik_sdk/models/factories/ollama_factory.py +41 -0
  60. spaik_sdk/models/factories/openai_factory.py +50 -0
  61. spaik_sdk/models/llm_config.py +46 -0
  62. spaik_sdk/models/llm_families.py +7 -0
  63. spaik_sdk/models/llm_model.py +17 -0
  64. spaik_sdk/models/llm_wrapper.py +25 -0
  65. spaik_sdk/models/model_registry.py +156 -0
  66. spaik_sdk/models/providers/__init__.py +0 -0
  67. spaik_sdk/models/providers/anthropic_provider.py +29 -0
  68. spaik_sdk/models/providers/azure_provider.py +31 -0
  69. spaik_sdk/models/providers/base_provider.py +62 -0
  70. spaik_sdk/models/providers/google_provider.py +26 -0
  71. spaik_sdk/models/providers/ollama_provider.py +26 -0
  72. spaik_sdk/models/providers/openai_provider.py +26 -0
  73. spaik_sdk/models/providers/provider_type.py +90 -0
  74. spaik_sdk/orchestration/__init__.py +24 -0
  75. spaik_sdk/orchestration/base_orchestrator.py +238 -0
  76. spaik_sdk/orchestration/checkpoint.py +80 -0
  77. spaik_sdk/orchestration/models.py +103 -0
  78. spaik_sdk/prompt/__init__.py +0 -0
  79. spaik_sdk/prompt/get_prompt_loader.py +13 -0
  80. spaik_sdk/prompt/local_prompt_loader.py +21 -0
  81. spaik_sdk/prompt/prompt_loader.py +48 -0
  82. spaik_sdk/prompt/prompt_loader_mode.py +14 -0
  83. spaik_sdk/py.typed +1 -0
  84. spaik_sdk/recording/__init__.py +1 -0
  85. spaik_sdk/recording/base_playback.py +90 -0
  86. spaik_sdk/recording/base_recorder.py +50 -0
  87. spaik_sdk/recording/conditional_recorder.py +38 -0
  88. spaik_sdk/recording/impl/__init__.py +1 -0
  89. spaik_sdk/recording/impl/local_playback.py +76 -0
  90. spaik_sdk/recording/impl/local_recorder.py +85 -0
  91. spaik_sdk/recording/langchain_serializer.py +88 -0
  92. spaik_sdk/server/__init__.py +1 -0
  93. spaik_sdk/server/api/routers/__init__.py +0 -0
  94. spaik_sdk/server/api/routers/api_builder.py +149 -0
  95. spaik_sdk/server/api/routers/audio_router_factory.py +201 -0
  96. spaik_sdk/server/api/routers/file_router_factory.py +111 -0
  97. spaik_sdk/server/api/routers/thread_router_factory.py +284 -0
  98. spaik_sdk/server/api/streaming/__init__.py +0 -0
  99. spaik_sdk/server/api/streaming/format_sse_event.py +41 -0
  100. spaik_sdk/server/api/streaming/negotiate_streaming_response.py +8 -0
  101. spaik_sdk/server/api/streaming/streaming_negotiator.py +10 -0
  102. spaik_sdk/server/authorization/__init__.py +0 -0
  103. spaik_sdk/server/authorization/base_authorizer.py +64 -0
  104. spaik_sdk/server/authorization/base_user.py +13 -0
  105. spaik_sdk/server/authorization/dummy_authorizer.py +17 -0
  106. spaik_sdk/server/job_processor/__init__.py +0 -0
  107. spaik_sdk/server/job_processor/base_job_processor.py +8 -0
  108. spaik_sdk/server/job_processor/thread_job_processor.py +32 -0
  109. spaik_sdk/server/pubsub/__init__.py +1 -0
  110. spaik_sdk/server/pubsub/cancellation_publisher.py +7 -0
  111. spaik_sdk/server/pubsub/cancellation_subscriber.py +38 -0
  112. spaik_sdk/server/pubsub/event_publisher.py +13 -0
  113. spaik_sdk/server/pubsub/impl/__init__.py +1 -0
  114. spaik_sdk/server/pubsub/impl/local_cancellation_pubsub.py +48 -0
  115. spaik_sdk/server/pubsub/impl/signalr_publisher.py +36 -0
  116. spaik_sdk/server/queue/__init__.py +1 -0
  117. spaik_sdk/server/queue/agent_job_queue.py +27 -0
  118. spaik_sdk/server/queue/impl/__init__.py +1 -0
  119. spaik_sdk/server/queue/impl/azure_queue.py +24 -0
  120. spaik_sdk/server/response/__init__.py +0 -0
  121. spaik_sdk/server/response/agent_response_generator.py +39 -0
  122. spaik_sdk/server/response/response_generator.py +13 -0
  123. spaik_sdk/server/response/simple_agent_response_generator.py +14 -0
  124. spaik_sdk/server/services/__init__.py +0 -0
  125. spaik_sdk/server/services/thread_converters.py +113 -0
  126. spaik_sdk/server/services/thread_models.py +90 -0
  127. spaik_sdk/server/services/thread_service.py +91 -0
  128. spaik_sdk/server/storage/__init__.py +1 -0
  129. spaik_sdk/server/storage/base_thread_repository.py +51 -0
  130. spaik_sdk/server/storage/impl/__init__.py +0 -0
  131. spaik_sdk/server/storage/impl/in_memory_thread_repository.py +100 -0
  132. spaik_sdk/server/storage/impl/local_file_thread_repository.py +217 -0
  133. spaik_sdk/server/storage/thread_filter.py +166 -0
  134. spaik_sdk/server/storage/thread_metadata.py +53 -0
  135. spaik_sdk/thread/__init__.py +0 -0
  136. spaik_sdk/thread/adapters/__init__.py +0 -0
  137. spaik_sdk/thread/adapters/cli/__init__.py +0 -0
  138. spaik_sdk/thread/adapters/cli/block_display.py +92 -0
  139. spaik_sdk/thread/adapters/cli/display_manager.py +84 -0
  140. spaik_sdk/thread/adapters/cli/live_cli.py +235 -0
  141. spaik_sdk/thread/adapters/event_adapter.py +28 -0
  142. spaik_sdk/thread/adapters/streaming_block_adapter.py +57 -0
  143. spaik_sdk/thread/adapters/sync_adapter.py +76 -0
  144. spaik_sdk/thread/models.py +224 -0
  145. spaik_sdk/thread/thread_container.py +468 -0
  146. spaik_sdk/tools/__init__.py +0 -0
  147. spaik_sdk/tools/impl/__init__.py +0 -0
  148. spaik_sdk/tools/impl/mcp_tool_provider.py +93 -0
  149. spaik_sdk/tools/impl/search_tool_provider.py +18 -0
  150. spaik_sdk/tools/tool_provider.py +131 -0
  151. spaik_sdk/tracing/__init__.py +13 -0
  152. spaik_sdk/tracing/agent_trace.py +72 -0
  153. spaik_sdk/tracing/get_trace_sink.py +15 -0
  154. spaik_sdk/tracing/local_trace_sink.py +23 -0
  155. spaik_sdk/tracing/trace_sink.py +19 -0
  156. spaik_sdk/tracing/trace_sink_mode.py +14 -0
  157. spaik_sdk/utils/__init__.py +0 -0
  158. spaik_sdk/utils/init_logger.py +24 -0
  159. spaik_sdk-0.6.2.dist-info/METADATA +379 -0
  160. spaik_sdk-0.6.2.dist-info/RECORD +161 -0
  161. spaik_sdk-0.6.2.dist-info/WHEEL +4 -0
@@ -0,0 +1,468 @@
+ import copy
+ import time
+ import uuid
+ from typing import Callable, Dict, List, Optional
+
+ from langchain_core.messages import BaseMessage, SystemMessage
+
+ from spaik_sdk.attachments.storage.base_file_storage import BaseFileStorage
+ from spaik_sdk.llm.consumption.token_usage import TokenUsage
+ from spaik_sdk.llm.converters import convert_thread_message_to_langchain, convert_thread_message_to_langchain_multimodal
+ from spaik_sdk.thread.models import (
+     BlockAddedEvent,
+     BlockFullyAddedEvent,
+     MessageAddedEvent,
+     MessageBlock,
+     MessageBlockType,
+     MessageFullyAddedEvent,
+     StreamingEndedEvent,
+     StreamingUpdatedEvent,
+     ThreadEvent,
+     ThreadMessage,
+     ToolCallResponse,
+     ToolCallStartedEvent,
+     ToolResponseReceivedEvent,
+ )
+ from spaik_sdk.utils.init_logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class ThreadContainer:
+     def __init__(self, system_prompt: Optional[str] = None):
+         self.messages: List[ThreadMessage] = []
+         self.streaming_content: Dict[str, str] = {}
+         self.tool_call_responses: Dict[str, ToolCallResponse] = {}
+         self.system_prompt = system_prompt
+         # Single event stream with multiple subscribers
+         self._subscribers: List[Callable[[ThreadEvent], None]] = []
+
+         # Version tracking
+         self._version = 0
+         self._last_activity_time = int(time.time() * 1000)
+         self.thread_id = str(uuid.uuid4())
+         self.job_id = "unknown"
+
+     def subscribe(self, callback: Callable[[ThreadEvent], None]) -> None:
+         """Subscribe to the event stream"""
+         if callback not in self._subscribers:
+             self._subscribers.append(callback)
+
+     def unsubscribe(self, callback: Callable[[ThreadEvent], None]) -> None:
+         """Unsubscribe from the event stream"""
+         if callback in self._subscribers:
+             self._subscribers.remove(callback)
+
+     def _emit_event(self, event: ThreadEvent) -> None:
+         """Emit a typed event to all subscribers"""
+         for callback in self._subscribers:
+             try:
+                 callback(event)
+             except Exception as e:
+                 logger.error(f"Event callback error: {e}")
+
+     def _increment_version(self) -> None:
+         """Increment version and update activity time"""
+         self._version += 1
+         self._last_activity_time = int(time.time() * 1000)
+
+     def get_version(self) -> int:
+         """Get current version for incremental updates"""
+         return self._version
+
+     def get_last_activity_time(self) -> int:
+         """Get timestamp of last activity"""
+         return self._last_activity_time
+
+     def add_streaming_message_chunk(self, block_id: str, content: str) -> None:
+         """Update streaming content by appending new content to existing content for the block_id"""
+         if block_id in self.streaming_content:
+             self.streaming_content[block_id] += content
+         else:
+             self.streaming_content[block_id] = content
+
+         # Emit streaming update event
+         self._emit_event(StreamingUpdatedEvent(block_id=block_id, content=content, total_content=self.streaming_content[block_id]))
+
+         self._increment_version()
+
+     def add_message(self, msg: ThreadMessage) -> None:
+         """Add a new message to the thread"""
+         self.messages.append(msg)
+         self._emit_event(MessageAddedEvent(message=copy.deepcopy(msg)))
+         self._increment_version()
+
+     def add_message_block(self, message_id: str, block: MessageBlock) -> None:
+         """Add a message block to an existing message by message_id"""
+         for message in self.messages:
+             if message.id == message_id:
+                 message.blocks.append(block)
+
+                 # Emit block added event
+                 self._emit_event(BlockAddedEvent(message_id=message_id, block_id=block.id, block=copy.deepcopy(block)))
+
+                 # If it's a tool block, emit tool call started
+                 if block.type == MessageBlockType.TOOL_USE and block.tool_call_id:
+                     tool_name = block.tool_name or "unknown"
+
+                     self._emit_event(
+                         ToolCallStartedEvent(tool_call_id=block.tool_call_id, tool_name=tool_name, message_id=message_id, block_id=block.id)
+                     )
+
+                 self._increment_version()
+                 break
+
+     def add_tool_call_response(self, response: ToolCallResponse) -> None:
+         """Add a tool call response by its ID"""
+         self.tool_call_responses[response.id] = response
+
+         # Find the corresponding block ID
+         block_id = None
+         for message in self.messages:
+             for block in message.blocks:
+                 if block.tool_call_id == response.id:
+                     block.streaming = False
+                     block.tool_call_response = response.response
+                     block.tool_call_error = response.error
+                     block_id = block.id
+                     break
+             if block_id:
+                 break
+
+         # Emit tool response event
+         self._emit_event(
+             ToolResponseReceivedEvent(tool_call_id=response.id, response=response.response, error=response.error, block_id=block_id)
+         )
+
+         self._increment_version()
+
+     def update_tool_use_block_with_response(self, tool_call_id: str, response: str, error: Optional[str] = None) -> None:
+         """Update a tool use block with the tool response and mark it as completed."""
+         logger.debug(f"🔧 DEBUG: Updating tool response for {tool_call_id}")
+         logger.debug(f"🔧 DEBUG: Response: {response[:100]}...")
+         logger.debug(f"🔧 DEBUG: Error: {error}")
+
+         # Add the tool response to our responses dict
+         tool_response = ToolCallResponse(id=tool_call_id, response=response, error=error)
+         self.add_tool_call_response(tool_response)
+
+         # Find the message and block with this tool_call_id and mark it as non-streaming
+         for message in self.messages:
+             for block in message.blocks:
+                 if block.tool_call_id == tool_call_id:
+                     block.streaming = False
+                     logger.debug(f"🔧 DEBUG: Found and updated block {block.id} for tool {tool_call_id}")
+                     self._increment_version()
+                     return
+
+         logger.debug(f"🔧 DEBUG: Could not find block for tool_call_id: {tool_call_id}")
+
+     def add_error_message(self, error_text: str, author_id: str = "system", author_name: str = "system") -> str:
+         """Add an error message and return the message ID"""
+         message_id = str(uuid.uuid4())
+         error_message = ThreadMessage(
+             id=message_id,
+             ai=True,
+             author_id=author_id,
+             author_name=author_name,
+             timestamp=int(time.time() * 1000),
+             blocks=[
+                 MessageBlock(
+                     id=str(uuid.uuid4()),
+                     streaming=False,
+                     type=MessageBlockType.ERROR,
+                     content=error_text,
+                 )
+             ],
+         )
+         self.add_message(error_message)
+         return message_id
+
+     def finalize_streaming_blocks(self, message_id: str, block_ids: List[str]) -> None:
+         """Mark specified blocks as non-streaming (completed)."""
+         completed_blocks = []
+
+         for message in self.messages:
+             if message.id == message_id:
+                 for block in message.blocks:
+                     if block.id in block_ids and block.streaming:
+                         block.streaming = False
+                         # Move content from streaming_content to block.content when streaming finishes
+                         if block.id in self.streaming_content:
+                             block.content = self.streaming_content[block.id]
+                             # lets not do this, access is needed at least for now
+                             # # Optionally remove from streaming_content to save memory
+                             # del self.streaming_content[block.id]
+                         completed_blocks.append(block.id)
+
+                         # Emit block fully added event for each completed block
+                         self._emit_event(BlockFullyAddedEvent(block_id=block.id, message_id=message_id, block=copy.deepcopy(block)))
+                 break
+
+         if completed_blocks:
+             self._increment_version()
+
+             # Check if streaming has ended
+             if not self.is_streaming_active():
+                 self._emit_event(StreamingEndedEvent(message_id=message_id, completed_blocks=completed_blocks))
+
+     def cancel_generation(self) -> None:
+         """Cancel the generation"""
+         logger.info(f"Cancelling generation. Current streaming content: {self.streaming_content}")
+         logger.info(f"Messages: {self.messages}")
+         for message in self.messages:
+             completed_blocks = []
+             for block in message.blocks:
+                 block.streaming = False
+                 # Move content from streaming_content to block.content when streaming finishes
+                 if block.id in self.streaming_content:
+                     block.content = self.streaming_content[block.id]
+                     logger.info(f"🔧 Block {block.id} content: {block.content}")
+                     completed_blocks.append(block.id)
+                     self._emit_event(BlockFullyAddedEvent(block_id=block.id, message_id=message.id, block=copy.deepcopy(block)))
+             self.finalize_streaming_blocks(message.id, completed_blocks)
+             if completed_blocks:
+                 self._increment_version()
+
+                 # Check if streaming has ended
+                 if not self.is_streaming_active():
+                     self._emit_event(StreamingEndedEvent(message_id=message.id, completed_blocks=completed_blocks))
+
+     def complete_generation(self) -> None:
+         """Mark the message as fully added and emit the event"""
+         latest_message = self.get_latest_ai_message()
+         if latest_message:
+             self._emit_event(MessageFullyAddedEvent(message=copy.deepcopy(latest_message)))
+
+     def is_streaming_active(self) -> bool:
+         """Check if any blocks are currently streaming"""
+         for message in self.messages:
+             for block in message.blocks:
+                 if block.streaming:
+                     return True
+         return False
+
+     def get_latest_ai_message(self) -> Optional[ThreadMessage]:
+         """Get the most recent AI message"""
+         for message in reversed(self.messages):
+             if message.ai:
+                 return message
+         return None
+
+     def get_latest_message(self) -> ThreadMessage:
+         """Get the most recent message"""
+         return self.messages[-1]
+
+     def get_message_by_id(self, message_id: str) -> Optional[ThreadMessage]:
+         """Get message by ID"""
+         for message in self.messages:
+             if message.id == message_id:
+                 return message
+         return None
+
+     def get_block_content(self, block: MessageBlock) -> str:
+         """Get content for a specific block"""
+         # First check if block has content directly
+         if block.content is not None:
+             return block.content
+
+         # For streaming blocks, check streaming_content
+         if block.id in self.streaming_content:
+             return self.streaming_content[block.id]
+
+         if block.type == MessageBlockType.TOOL_USE:
+             if block.tool_call_id and block.tool_call_id in self.tool_call_responses:
+                 response = self.tool_call_responses[block.tool_call_id]
+                 return response.response
+
+         return ""
+
+     def get_system_prompt(self) -> str:
+         if self.system_prompt is None:
+             raise ValueError("System prompt is not set")
+         return self.system_prompt
+
+     def get_streaming_blocks(self) -> List[str]:
+         """Get list of currently streaming block IDs"""
+         streaming_blocks = []
+         for message in self.messages:
+             for block in message.blocks:
+                 if block.streaming:
+                     streaming_blocks.append(block.id)
+         return streaming_blocks
+
+     def has_errors(self) -> bool:
+         """Check if there are any error blocks or tool errors"""
+         for message in self.messages:
+             for block in message.blocks:
+                 if block.type == MessageBlockType.ERROR:
+                     return True
+
+         for response in self.tool_call_responses.values():
+             if response.error:
+                 return True
+
+         return False
+
+     def get_final_text_content(self) -> str:
+         """Get clean final text from the latest AI message"""
+         latest_message = self.get_latest_ai_message()
+         if not latest_message:
+             return ""
+
+         text_parts = []
+         for block in latest_message.blocks:
+             if block.type == MessageBlockType.PLAIN and not block.streaming:
+                 content = self.get_block_content(block)
+                 if content:
+                     text_parts.append(content)
+
+         return " ".join(text_parts).strip()
+
+     def _find_message_id_by_block(self, block_id: str) -> Optional[str]:
+         """Find message ID that contains the given block ID"""
+         for message in self.messages:
+             for block in message.blocks:
+                 if block.id == block_id:
+                     return message.id
+         return None
+
+     def get_langchain_messages(self) -> List[BaseMessage]:
+         """Get all messages as LangChain BaseMessages"""
+         messages: List[BaseMessage] = [SystemMessage(content=self.get_system_prompt())]
+         messages.extend([convert_thread_message_to_langchain(msg) for msg in self.messages])
+         return messages
+
+     async def get_langchain_messages_multimodal(self, file_storage: BaseFileStorage, provider_family: str = "openai") -> List[BaseMessage]:
+         """Get all messages as LangChain BaseMessages with multimodal content support"""
+         messages: List[BaseMessage] = [SystemMessage(content=self.get_system_prompt())]
+         for msg in self.messages:
+             converted = await convert_thread_message_to_langchain_multimodal(msg, file_storage, provider_family)
+             messages.append(converted)
+         return messages
+
+     def get_nof_messages_including_system(self) -> int:
+         """Get number of messages including system message"""
+         return len(self.messages) + 1
+
+     def add_consumption_metadata(self, message_id: str, consumption_metadata: TokenUsage) -> None:
+         """Add consumption metadata to a specific message"""
+         for message in self.messages:
+             if message.id == message_id:
+                 message.consumption_metadata = consumption_metadata
+                 self._increment_version()
+                 break
+
+     def get_total_consumption(self) -> TokenUsage:
+         """Calculate total consumption across all messages with consumption metadata"""
+         total_tokens = TokenUsage()
+
+         for message in self.messages:
+             if message.consumption_metadata and isinstance(message.consumption_metadata, TokenUsage):
+                 token_usage = message.consumption_metadata
+                 total_tokens.input_tokens += token_usage.input_tokens
+                 total_tokens.output_tokens += token_usage.output_tokens
+                 total_tokens.total_tokens += token_usage.total_tokens
+                 total_tokens.reasoning_tokens += token_usage.reasoning_tokens
+                 total_tokens.cache_creation_tokens += token_usage.cache_creation_tokens
+                 total_tokens.cache_read_tokens += token_usage.cache_read_tokens
+
+         return total_tokens
+
+     def get_consumption_by_message(self, message_id: str) -> Optional[TokenUsage]:
+         """Get consumption metadata for a specific message"""
+         for message in self.messages:
+             if message.id == message_id and message.consumption_metadata:
+                 return message.consumption_metadata
+         return None
+
+     def get_latest_token_usage(self) -> Optional[TokenUsage]:
+         """Get consumption metadata for the latest message"""
+         latest_message = self.get_latest_ai_message()
+         if latest_message and latest_message.consumption_metadata:
+             return latest_message.consumption_metadata
+         return None
+
+     def __str__(self) -> str:
+         """String representation of the entire thread container"""
+         lines = ["=== THREAD CONTAINER ==="]
+         lines.append(f"Version: {self._version} | Active streaming: {self.is_streaming_active()}")
+
+         lines.append(f"\n📨 MESSAGES ({len(self.messages)}):")
+         for i, msg in enumerate(self.messages):
+             author = "🤖 AI" if msg.ai else f"👤 {msg.author_id}"
+             lines.append(f" [{i}] {author} | {msg.id} | {msg.timestamp}")
+             for j, block in enumerate(msg.blocks):
+                 if block.type == MessageBlockType.ERROR:
+                     stream_indicator = "❌"
+                 elif block.type == MessageBlockType.REASONING:
+                     stream_indicator = "🧠" if not block.streaming else "🤔"
+                 elif block.streaming:
+                     stream_indicator = "🌊"
+                 else:
+                     stream_indicator = "✅"
+                 tool_info = f" | tool_id: {block.tool_call_id}" if block.tool_call_id else ""
+
+                 # Get content preview for the block
+                 content = self.get_block_content(block)
+                 content_preview = content[:50] + "..." if len(content) > 50 else content
+                 content_info = f" | {repr(content_preview)}" if content else " (no content)"
+
+                 lines.append(f" [{j}] {stream_indicator} {block.type.value} | {block.id}{tool_info}{content_info}")
+
+         lines.append(f"\n🌊 STREAMING CONTENT ({len(self.streaming_content)}):")
+         for block_id, content in self.streaming_content.items():
+             preview = content[:50] + "..." if len(content) > 50 else content
+             lines.append(f" {block_id}: {repr(preview)}")
+
+         lines.append(f"\n🔧 TOOL CALL RESPONSES ({len(self.tool_call_responses)}):")
+         for tool_id, response in self.tool_call_responses.items():
+             error_indicator = "❌" if response.error else "✅"
+             preview = response.response[:50] + "..." if len(response.response) > 50 else response.response
+             lines.append(f" {error_indicator} {tool_id}: {repr(preview)}")
+             if response.error:
+                 lines.append(f" Error: {response.error}")
+
+         # Add consumption summary
+         total_consumption = self.get_total_consumption()
+         consumption_messages = sum(1 for msg in self.messages if msg.consumption_metadata)
+         total_messages = len(self.messages)
+         lines.append(f"\n📊 CONSUMPTION SUMMARY ({consumption_messages}/{total_messages} messages):")
+
+         if consumption_messages > 0:
+             lines.append(f" Total tokens: {total_consumption.total_tokens:,}")
+             lines.append(f" Input: {total_consumption.input_tokens:,} | Output: {total_consumption.output_tokens:,}")
+             if total_consumption.reasoning_tokens > 0:
+                 lines.append(f" Reasoning: {total_consumption.reasoning_tokens:,}")
+             if total_consumption.cache_creation_tokens > 0 or total_consumption.cache_read_tokens > 0:
+                 lines.append(
+                     f" Cache create: {total_consumption.cache_creation_tokens:,} | Cache read: {total_consumption.cache_read_tokens:,}"
+                 )
+         else:
+             lines.append(" No consumption data available")
+
+         return "\n".join(lines)
+
+     def print_all(self) -> None:
+         """Print everything to console for debugging"""
+         print(str(self))
+
+     def create_serializable_copy(self) -> "ThreadContainer":
+         """Create a copy of this ThreadContainer that can be safely pickled"""
+         # Create new instance without calling __init__ to avoid subscriber initialization
+         copy = ThreadContainer.__new__(ThreadContainer)
+
+         # Copy all serializable attributes
+         copy.messages = self.messages.copy()
+         copy.streaming_content = self.streaming_content.copy()
+         copy.tool_call_responses = self.tool_call_responses.copy()
+         copy.system_prompt = self.system_prompt
+         copy._version = self._version
+         copy._last_activity_time = self._last_activity_time
+         copy.thread_id = self.thread_id
+         copy.job_id = self.job_id
+
+         # Initialize empty subscribers list (will be empty when loaded)
+         copy._subscribers = []
+
+         return copy
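Note (not part of the diff): ThreadContainer is the SDK's in-memory conversation state; every mutation (adding a message, appending a streaming chunk, recording a tool response) bumps a version counter and fans a typed event out to subscribers. Below is a minimal usage sketch assuming ThreadMessage and MessageBlock accept the keyword fields used in add_error_message above and that a text block may start with content=None while streaming; additional required fields may exist in spaik_sdk.thread.models that this diff does not show.

import time
import uuid

from spaik_sdk.thread.models import MessageBlock, MessageBlockType, ThreadEvent, ThreadMessage
from spaik_sdk.thread.thread_container import ThreadContainer


def on_event(event: ThreadEvent) -> None:
    # Every mutation on the container is mirrored as a typed event.
    print(type(event).__name__)


container = ThreadContainer(system_prompt="You are a helpful assistant.")
container.subscribe(on_event)

# Add an AI message with a single streaming text block ...
message_id = str(uuid.uuid4())
block_id = str(uuid.uuid4())
container.add_message(
    ThreadMessage(
        id=message_id,
        ai=True,
        author_id="assistant",
        author_name="assistant",
        timestamp=int(time.time() * 1000),
        blocks=[MessageBlock(id=block_id, streaming=True, type=MessageBlockType.PLAIN, content=None)],
    )
)

# ... stream chunks into it, then finalize the block.
container.add_streaming_message_chunk(block_id, "Hello, ")
container.add_streaming_message_chunk(block_id, "world!")
container.finalize_streaming_blocks(message_id, [block_id])

print(container.get_final_text_content())  # -> "Hello, world!"
print(container.get_version())             # incremented once per mutation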
File without changes
File without changes
@@ -0,0 +1,93 @@
+ from typing import Dict, List, Literal, TypedDict, Union
+
+ from langchain_core.tools import BaseTool
+ from langchain_mcp_adapters.client import MultiServerMCPClient
+ from langchain_mcp_adapters.tools import load_mcp_tools
+ from mcp import ClientSession, StdioServerParameters
+ from mcp.client.stdio import stdio_client
+
+ from spaik_sdk.tools.tool_provider import ToolProvider
+
+
+ class StdioServerConfig(TypedDict, total=False):
+     transport: Literal["stdio"]
+     command: str
+     args: List[str]
+     env: Dict[str, str]
+
+
+ class HttpServerConfig(TypedDict, total=False):
+     transport: Literal["http"]
+     url: str
+     headers: Dict[str, str]
+
+
+ McpServerConfig = Union[StdioServerConfig, HttpServerConfig]
+
+
+ class McpToolProvider(ToolProvider):
+     def __init__(
+         self,
+         servers: Dict[str, McpServerConfig],
+         tools: List[BaseTool] | None = None,
+     ):
+         self._servers = servers
+         self._tools = tools
+
+     async def load_tools(self) -> List[BaseTool]:
+         if self._tools is not None:
+             return self._tools
+
+         client = MultiServerMCPClient(self._servers)  # type: ignore[arg-type]
+         self._tools = await client.get_tools()
+         return self._tools
+
+     def get_tools(self) -> List[BaseTool]:
+         if self._tools is None:
+             raise RuntimeError(
+                 "MCP tools not loaded. Call `await provider.load_tools()` first, or use `McpToolProvider.create()` async factory method."
+             )
+         return self._tools
+
+     @classmethod
+     async def create(cls, servers: Dict[str, McpServerConfig]) -> "McpToolProvider":
+         provider = cls(servers)
+         await provider.load_tools()
+         return provider
+
+     @staticmethod
+     async def load_from_stdio(
+         command: str,
+         args: List[str] | None = None,
+         env: Dict[str, str] | None = None,
+     ) -> List[BaseTool]:
+         server_params = StdioServerParameters(
+             command=command,
+             args=args or [],
+             env=env,
+         )
+         async with stdio_client(server_params) as (read, write):
+             async with ClientSession(read, write) as session:
+                 await session.initialize()
+                 return await load_mcp_tools(session)
+
+
+ def mcp_server_stdio(
+     command: str,
+     args: List[str] | None = None,
+     env: Dict[str, str] | None = None,
+ ) -> StdioServerConfig:
+     config: StdioServerConfig = {"transport": "stdio", "command": command, "args": args or []}
+     if env:
+         config["env"] = env
+     return config
+
+
+ def mcp_server_http(
+     url: str,
+     headers: Dict[str, str] | None = None,
+ ) -> HttpServerConfig:
+     config: HttpServerConfig = {"transport": "http", "url": url}
+     if headers:
+         config["headers"] = headers
+     return config
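Note (not part of the diff): McpToolProvider exposes MCP servers as LangChain tools either lazily (load_tools) or eagerly via the async create() factory, with mcp_server_stdio / mcp_server_http building the per-server config dicts. A short sketch of the intended call pattern follows; the server names, the filesystem MCP server package, and the URL are illustrative placeholders.

import asyncio

from spaik_sdk.tools.impl.mcp_tool_provider import McpToolProvider, mcp_server_http, mcp_server_stdio


async def main() -> None:
    servers = {
        # Hypothetical servers purely for illustration.
        "filesystem": mcp_server_stdio("npx", args=["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]),
        "internal-api": mcp_server_http("https://mcp.example.com/mcp", headers={"Authorization": "Bearer <token>"}),
    }
    # The async factory loads tools eagerly, so get_tools() is safe afterwards.
    provider = await McpToolProvider.create(servers)
    for tool in provider.get_tools():
        print(tool.name)


asyncio.run(main())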
@@ -0,0 +1,18 @@
+ import os
+
+ from langchain_tavily import TavilySearch
+
+ from spaik_sdk.config.get_credentials_provider import credentials_provider
+ from spaik_sdk.tools.tool_provider import ToolProvider
+
+
+ class SearchToolProvider(ToolProvider):
+     def __init__(self, max_results=10):
+         self._init_env()
+         self.tool = TavilySearch(max_results=max_results)
+
+     def _init_env(self):
+         os.environ["TAVILY_API_KEY"] = credentials_provider.get_provider_key("tavily")
+
+     def get_tools(self):
+         return [self.tool]
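Note (not part of the diff): SearchToolProvider simply exports a preconfigured Tavily search tool, copying the API key from the SDK's credentials provider into TAVILY_API_KEY. A usage sketch, assuming a Tavily key is resolvable by the configured credentials provider; the tool's exact input schema is defined by langchain_tavily, and the dict-style invoke below is an assumption based on its standard "query" field.

from spaik_sdk.tools.impl.search_tool_provider import SearchToolProvider

provider = SearchToolProvider(max_results=5)
search_tool = provider.get_tools()[0]

# Invoke the underlying TavilySearch tool with a search query.
print(search_tool.invoke({"query": "latest LangChain release"}))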