amd-gaia 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/METADATA +223 -223
  2. amd_gaia-0.15.1.dist-info/RECORD +178 -0
  3. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/entry_points.txt +1 -0
  4. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/licenses/LICENSE.md +20 -20
  5. gaia/__init__.py +29 -29
  6. gaia/agents/__init__.py +19 -19
  7. gaia/agents/base/__init__.py +9 -9
  8. gaia/agents/base/agent.py +2177 -2177
  9. gaia/agents/base/api_agent.py +120 -120
  10. gaia/agents/base/console.py +1841 -1841
  11. gaia/agents/base/errors.py +237 -237
  12. gaia/agents/base/mcp_agent.py +86 -86
  13. gaia/agents/base/tools.py +83 -83
  14. gaia/agents/blender/agent.py +556 -556
  15. gaia/agents/blender/agent_simple.py +133 -135
  16. gaia/agents/blender/app.py +211 -211
  17. gaia/agents/blender/app_simple.py +41 -41
  18. gaia/agents/blender/core/__init__.py +16 -16
  19. gaia/agents/blender/core/materials.py +506 -506
  20. gaia/agents/blender/core/objects.py +316 -316
  21. gaia/agents/blender/core/rendering.py +225 -225
  22. gaia/agents/blender/core/scene.py +220 -220
  23. gaia/agents/blender/core/view.py +146 -146
  24. gaia/agents/chat/__init__.py +9 -9
  25. gaia/agents/chat/agent.py +835 -835
  26. gaia/agents/chat/app.py +1058 -1058
  27. gaia/agents/chat/session.py +508 -508
  28. gaia/agents/chat/tools/__init__.py +15 -15
  29. gaia/agents/chat/tools/file_tools.py +96 -96
  30. gaia/agents/chat/tools/rag_tools.py +1729 -1729
  31. gaia/agents/chat/tools/shell_tools.py +436 -436
  32. gaia/agents/code/__init__.py +7 -7
  33. gaia/agents/code/agent.py +549 -549
  34. gaia/agents/code/cli.py +377 -0
  35. gaia/agents/code/models.py +135 -135
  36. gaia/agents/code/orchestration/__init__.py +24 -24
  37. gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
  38. gaia/agents/code/orchestration/checklist_generator.py +713 -713
  39. gaia/agents/code/orchestration/factories/__init__.py +9 -9
  40. gaia/agents/code/orchestration/factories/base.py +63 -63
  41. gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
  42. gaia/agents/code/orchestration/factories/python_factory.py +106 -106
  43. gaia/agents/code/orchestration/orchestrator.py +841 -841
  44. gaia/agents/code/orchestration/project_analyzer.py +391 -391
  45. gaia/agents/code/orchestration/steps/__init__.py +67 -67
  46. gaia/agents/code/orchestration/steps/base.py +188 -188
  47. gaia/agents/code/orchestration/steps/error_handler.py +314 -314
  48. gaia/agents/code/orchestration/steps/nextjs.py +828 -828
  49. gaia/agents/code/orchestration/steps/python.py +307 -307
  50. gaia/agents/code/orchestration/template_catalog.py +469 -469
  51. gaia/agents/code/orchestration/workflows/__init__.py +14 -14
  52. gaia/agents/code/orchestration/workflows/base.py +80 -80
  53. gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
  54. gaia/agents/code/orchestration/workflows/python.py +94 -94
  55. gaia/agents/code/prompts/__init__.py +11 -11
  56. gaia/agents/code/prompts/base_prompt.py +77 -77
  57. gaia/agents/code/prompts/code_patterns.py +2036 -2036
  58. gaia/agents/code/prompts/nextjs_prompt.py +40 -40
  59. gaia/agents/code/prompts/python_prompt.py +109 -109
  60. gaia/agents/code/schema_inference.py +365 -365
  61. gaia/agents/code/system_prompt.py +41 -41
  62. gaia/agents/code/tools/__init__.py +42 -42
  63. gaia/agents/code/tools/cli_tools.py +1138 -1138
  64. gaia/agents/code/tools/code_formatting.py +319 -319
  65. gaia/agents/code/tools/code_tools.py +769 -769
  66. gaia/agents/code/tools/error_fixing.py +1347 -1347
  67. gaia/agents/code/tools/external_tools.py +180 -180
  68. gaia/agents/code/tools/file_io.py +845 -845
  69. gaia/agents/code/tools/prisma_tools.py +190 -190
  70. gaia/agents/code/tools/project_management.py +1016 -1016
  71. gaia/agents/code/tools/testing.py +321 -321
  72. gaia/agents/code/tools/typescript_tools.py +122 -122
  73. gaia/agents/code/tools/validation_parsing.py +461 -461
  74. gaia/agents/code/tools/validation_tools.py +806 -806
  75. gaia/agents/code/tools/web_dev_tools.py +1758 -1758
  76. gaia/agents/code/validators/__init__.py +16 -16
  77. gaia/agents/code/validators/antipattern_checker.py +241 -241
  78. gaia/agents/code/validators/ast_analyzer.py +197 -197
  79. gaia/agents/code/validators/requirements_validator.py +145 -145
  80. gaia/agents/code/validators/syntax_validator.py +171 -171
  81. gaia/agents/docker/__init__.py +7 -7
  82. gaia/agents/docker/agent.py +642 -642
  83. gaia/agents/emr/__init__.py +8 -8
  84. gaia/agents/emr/agent.py +1506 -1506
  85. gaia/agents/emr/cli.py +1322 -1322
  86. gaia/agents/emr/constants.py +475 -475
  87. gaia/agents/emr/dashboard/__init__.py +4 -4
  88. gaia/agents/emr/dashboard/server.py +1974 -1974
  89. gaia/agents/jira/__init__.py +11 -11
  90. gaia/agents/jira/agent.py +894 -894
  91. gaia/agents/jira/jql_templates.py +299 -299
  92. gaia/agents/routing/__init__.py +7 -7
  93. gaia/agents/routing/agent.py +567 -570
  94. gaia/agents/routing/system_prompt.py +75 -75
  95. gaia/agents/summarize/__init__.py +11 -0
  96. gaia/agents/summarize/agent.py +885 -0
  97. gaia/agents/summarize/prompts.py +129 -0
  98. gaia/api/__init__.py +23 -23
  99. gaia/api/agent_registry.py +238 -238
  100. gaia/api/app.py +305 -305
  101. gaia/api/openai_server.py +575 -575
  102. gaia/api/schemas.py +186 -186
  103. gaia/api/sse_handler.py +373 -373
  104. gaia/apps/__init__.py +4 -4
  105. gaia/apps/llm/__init__.py +6 -6
  106. gaia/apps/llm/app.py +173 -169
  107. gaia/apps/summarize/app.py +116 -633
  108. gaia/apps/summarize/html_viewer.py +133 -133
  109. gaia/apps/summarize/pdf_formatter.py +284 -284
  110. gaia/audio/__init__.py +2 -2
  111. gaia/audio/audio_client.py +439 -439
  112. gaia/audio/audio_recorder.py +269 -269
  113. gaia/audio/kokoro_tts.py +599 -599
  114. gaia/audio/whisper_asr.py +432 -432
  115. gaia/chat/__init__.py +16 -16
  116. gaia/chat/app.py +430 -430
  117. gaia/chat/prompts.py +522 -522
  118. gaia/chat/sdk.py +1228 -1225
  119. gaia/cli.py +5481 -5632
  120. gaia/database/__init__.py +10 -10
  121. gaia/database/agent.py +176 -176
  122. gaia/database/mixin.py +290 -290
  123. gaia/database/testing.py +64 -64
  124. gaia/eval/batch_experiment.py +2332 -2332
  125. gaia/eval/claude.py +542 -542
  126. gaia/eval/config.py +37 -37
  127. gaia/eval/email_generator.py +512 -512
  128. gaia/eval/eval.py +3179 -3179
  129. gaia/eval/groundtruth.py +1130 -1130
  130. gaia/eval/transcript_generator.py +582 -582
  131. gaia/eval/webapp/README.md +167 -167
  132. gaia/eval/webapp/package-lock.json +875 -875
  133. gaia/eval/webapp/package.json +20 -20
  134. gaia/eval/webapp/public/app.js +3402 -3402
  135. gaia/eval/webapp/public/index.html +87 -87
  136. gaia/eval/webapp/public/styles.css +3661 -3661
  137. gaia/eval/webapp/server.js +415 -415
  138. gaia/eval/webapp/test-setup.js +72 -72
  139. gaia/llm/__init__.py +9 -2
  140. gaia/llm/base_client.py +60 -0
  141. gaia/llm/exceptions.py +12 -0
  142. gaia/llm/factory.py +70 -0
  143. gaia/llm/lemonade_client.py +3236 -3221
  144. gaia/llm/lemonade_manager.py +294 -294
  145. gaia/llm/providers/__init__.py +9 -0
  146. gaia/llm/providers/claude.py +108 -0
  147. gaia/llm/providers/lemonade.py +120 -0
  148. gaia/llm/providers/openai_provider.py +79 -0
  149. gaia/llm/vlm_client.py +382 -382
  150. gaia/logger.py +189 -189
  151. gaia/mcp/agent_mcp_server.py +245 -245
  152. gaia/mcp/blender_mcp_client.py +138 -138
  153. gaia/mcp/blender_mcp_server.py +648 -648
  154. gaia/mcp/context7_cache.py +332 -332
  155. gaia/mcp/external_services.py +518 -518
  156. gaia/mcp/mcp_bridge.py +811 -550
  157. gaia/mcp/servers/__init__.py +6 -6
  158. gaia/mcp/servers/docker_mcp.py +83 -83
  159. gaia/perf_analysis.py +361 -0
  160. gaia/rag/__init__.py +10 -10
  161. gaia/rag/app.py +293 -293
  162. gaia/rag/demo.py +304 -304
  163. gaia/rag/pdf_utils.py +235 -235
  164. gaia/rag/sdk.py +2194 -2194
  165. gaia/security.py +163 -163
  166. gaia/talk/app.py +289 -289
  167. gaia/talk/sdk.py +538 -538
  168. gaia/testing/__init__.py +87 -87
  169. gaia/testing/assertions.py +330 -330
  170. gaia/testing/fixtures.py +333 -333
  171. gaia/testing/mocks.py +493 -493
  172. gaia/util.py +46 -46
  173. gaia/utils/__init__.py +33 -33
  174. gaia/utils/file_watcher.py +675 -675
  175. gaia/utils/parsing.py +223 -223
  176. gaia/version.py +100 -100
  177. amd_gaia-0.15.0.dist-info/RECORD +0 -168
  178. gaia/agents/code/app.py +0 -266
  179. gaia/llm/llm_client.py +0 -723
  180. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/WHEEL +0 -0
  181. {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/top_level.txt +0 -0
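The headline change in this release is the LLM client refactor: gaia/llm/llm_client.py is removed and replaced by gaia/llm/base_client.py, gaia/llm/factory.py, gaia/llm/exceptions.py, and the gaia/llm/providers/ package (claude.py, lemonade.py, openai_provider.py). Callers such as gaia/chat/sdk.py now obtain a client from the factory (from gaia.llm import create_client) instead of constructing LLMClient directly; per the comment in the new ChatSDK.__init__, the factory also takes over provider detection and validation. Other additions visible in the file list are the summarize agent (gaia/agents/summarize/), a code-agent CLI (gaia/agents/code/cli.py), and gaia/perf_analysis.py.

The sketch below mirrors the create_client call made in the new ChatSDK.__init__ shown further down. The keyword arguments are the ones that call passes; any other parameters, defaults, or provider behavior are assumptions rather than something this diff documents.

```python
# Minimal sketch, assuming only what the new ChatSDK.__init__ shows:
# create_client() selects a provider from the flags and returns an object
# exposing a generate() method.
from gaia.llm import create_client
from gaia.llm.lemonade_client import DEFAULT_MODEL_NAME

client = create_client(
    use_claude=False,   # True would select the Claude provider
    use_openai=False,   # True would select the OpenAI/ChatGPT provider
    model=DEFAULT_MODEL_NAME,                   # local Lemonade model by default
    base_url="http://localhost:8000/api/v1",    # Lemonade server base URL
    system_prompt=None,
)

# generate() is invoked the same way ChatSDK calls it in this diff.
text = client.generate(prompt="Hello!", model=DEFAULT_MODEL_NAME, stream=False)
print(text)
```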
gaia/chat/sdk.py CHANGED
@@ -1,1225 +1,1228 @@
1
- #!/usr/bin/env python3
2
- # Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved.
3
- # SPDX-License-Identifier: MIT
4
-
5
- """
6
- Gaia Chat SDK - Unified text chat integration with conversation history
7
- """
8
-
9
- import json
10
- import logging
11
- from collections import deque
12
- from dataclasses import dataclass
13
- from typing import Any, Dict, List, Optional
14
-
15
- from gaia.chat.prompts import Prompts
16
- from gaia.llm.lemonade_client import DEFAULT_MODEL_NAME
17
- from gaia.llm.llm_client import LLMClient
18
- from gaia.logger import get_logger
19
-
20
-
21
- @dataclass
22
- class ChatConfig:
23
- """Configuration for ChatSDK."""
24
-
25
- model: str = DEFAULT_MODEL_NAME
26
- max_tokens: int = 512
27
- system_prompt: Optional[str] = None
28
- max_history_length: int = 4 # Number of conversation pairs to keep
29
- show_stats: bool = False
30
- logging_level: str = "INFO"
31
- use_claude: bool = False # Use Claude API
32
- use_chatgpt: bool = False # Use ChatGPT/OpenAI API
33
- use_local_llm: bool = (
34
- True # Use local LLM (computed as not use_claude and not use_chatgpt if not explicitly set)
35
- )
36
- claude_model: str = "claude-sonnet-4-20250514" # Claude model when use_claude=True
37
- base_url: str = "http://localhost:8000/api/v1" # Lemonade server base URL
38
- assistant_name: str = "gaia" # Name to use for the assistant in conversations
39
-
40
-
41
- @dataclass
42
- class ChatResponse:
43
- """Response from chat operations."""
44
-
45
- text: str
46
- history: Optional[List[str]] = None
47
- stats: Optional[Dict[str, Any]] = None
48
- is_complete: bool = True
49
-
50
-
51
- class ChatSDK:
52
- """
53
- Gaia Chat SDK - Unified text chat integration with conversation history.
54
-
55
- This SDK provides a simple interface for integrating Gaia's text chat
56
- capabilities with conversation memory into applications.
57
-
58
- Example usage:
59
- ```python
60
- from gaia.chat.sdk import ChatSDK, ChatConfig
61
-
62
- # Create SDK instance
63
- config = ChatConfig(model=DEFAULT_MODEL_NAME, show_stats=True)
64
- chat = ChatSDK(config)
65
-
66
- # Single message
67
- response = await chat.send("Hello, how are you?")
68
- print(response.text)
69
-
70
- # Streaming chat
71
- async for chunk in chat.send_stream("Tell me a story"):
72
- print(chunk.text, end="", flush=True)
73
-
74
- # Get conversation history
75
- history = chat.get_history()
76
- ```
77
- """
78
-
79
- def __init__(self, config: Optional[ChatConfig] = None):
80
- """
81
- Initialize the ChatSDK.
82
-
83
- Args:
84
- config: Configuration options. If None, uses defaults.
85
- """
86
- self.config = config or ChatConfig()
87
- self.log = get_logger(__name__)
88
- self.log.setLevel(getattr(logging, self.config.logging_level))
89
-
90
- # Validate that both providers aren't specified
91
- if self.config.use_claude and self.config.use_chatgpt:
92
- raise ValueError(
93
- "Cannot specify both use_claude and use_chatgpt. Please choose one."
94
- )
95
-
96
- # Initialize LLM client - it will compute use_local automatically
97
- self.llm_client = LLMClient(
98
- use_claude=self.config.use_claude,
99
- use_openai=self.config.use_chatgpt,
100
- claude_model=self.config.claude_model,
101
- base_url=self.config.base_url,
102
- system_prompt=None, # We handle system prompts through Prompts class
103
- )
104
-
105
- # Store conversation history
106
- self.chat_history = deque(maxlen=self.config.max_history_length * 2)
107
-
108
- # RAG support
109
- self.rag = None
110
- self.rag_enabled = False
111
-
112
- self.log.debug("ChatSDK initialized")
113
-
114
- def _format_history_for_context(self) -> str:
115
- """Format chat history for inclusion in LLM context using model-specific formatting."""
116
- history_list = list(self.chat_history)
117
- return Prompts.format_chat_history(
118
- self.config.model,
119
- history_list,
120
- self.config.assistant_name,
121
- self.config.system_prompt,
122
- )
123
-
124
- def _normalize_message_content(self, content: Any) -> str:
125
- """
126
- Convert message content into a string for prompt construction, handling structured payloads.
127
- """
128
- if isinstance(content, str):
129
- return content
130
- if isinstance(content, list):
131
- parts = []
132
- for entry in content:
133
- if isinstance(entry, dict):
134
- if entry.get("type") == "text":
135
- parts.append(entry.get("text", ""))
136
- else:
137
- parts.append(json.dumps(entry))
138
- else:
139
- parts.append(str(entry))
140
- return "\n".join(part for part in parts if part)
141
- if isinstance(content, dict):
142
- return json.dumps(content)
143
- return str(content)
144
-
145
- def _prepare_messages_for_llm(
146
- self, messages: List[Dict[str, Any]]
147
- ) -> List[Dict[str, Any]]:
148
- """
149
- Ensure messages are safe to send to the LLM by appending a continuation
150
- prompt when the last entry is a tool_result, which some models ignore.
151
- """
152
- if not messages:
153
- return []
154
-
155
- prepared = list(messages)
156
- try:
157
- last_role = prepared[-1].get("role")
158
- except Exception:
159
- return prepared
160
-
161
- if last_role == "tool":
162
- prepared.append({"role": "user", "content": "continue"})
163
-
164
- return prepared
165
-
166
- def send_messages(
167
- self,
168
- messages: List[Dict[str, Any]],
169
- system_prompt: Optional[str] = None,
170
- **kwargs,
171
- ) -> ChatResponse:
172
- """
173
- Send a full conversation history and get a response.
174
-
175
- Args:
176
- messages: List of message dicts with 'role' and 'content' keys
177
- system_prompt: Optional system prompt to use (overrides config)
178
- **kwargs: Additional arguments for LLM generation
179
-
180
- Returns:
181
- ChatResponse with the complete response
182
- """
183
- try:
184
- messages = self._prepare_messages_for_llm(messages)
185
-
186
- # Convert messages to chat history format
187
- chat_history = []
188
-
189
- for msg in messages:
190
- role = msg.get("role", "")
191
- content = self._normalize_message_content(msg.get("content", ""))
192
-
193
- if role == "user":
194
- chat_history.append(f"user: {content}")
195
- elif role == "assistant":
196
- chat_history.append(f"assistant: {content}")
197
- elif role == "tool":
198
- tool_name = msg.get("name", "tool")
199
- chat_history.append(f"assistant: [tool:{tool_name}] {content}")
200
- # Skip system messages since they're passed separately
201
-
202
- # Use provided system prompt or fall back to config default
203
- effective_system_prompt = system_prompt or self.config.system_prompt
204
-
205
- # Format according to model type
206
- formatted_prompt = Prompts.format_chat_history(
207
- model=self.config.model,
208
- chat_history=chat_history,
209
- assistant_name="assistant",
210
- system_prompt=effective_system_prompt,
211
- )
212
-
213
- # Debug logging
214
- self.log.debug(f"Formatted prompt length: {len(formatted_prompt)} chars")
215
- self.log.debug(
216
- f"System prompt used: {effective_system_prompt[:100] if effective_system_prompt else 'None'}..."
217
- )
218
-
219
- # Set appropriate stop tokens based on model
220
- model_lower = self.config.model.lower() if self.config.model else ""
221
- if "qwen" in model_lower:
222
- kwargs.setdefault("stop", ["<|im_end|>", "<|im_start|>"])
223
- elif "llama" in model_lower:
224
- kwargs.setdefault("stop", ["<|eot_id|>", "<|start_header_id|>"])
225
-
226
- # Use generate with formatted prompt
227
- response = self.llm_client.generate(
228
- prompt=formatted_prompt,
229
- model=self.config.model,
230
- stream=False,
231
- **kwargs,
232
- )
233
-
234
- # Prepare response data
235
- stats = None
236
- if self.config.show_stats:
237
- stats = self.get_stats()
238
-
239
- return ChatResponse(text=response, stats=stats, is_complete=True)
240
-
241
- except ConnectionError as e:
242
- # Re-raise connection errors with additional context
243
- self.log.error(f"LLM connection error in send_messages: {e}")
244
- raise ConnectionError(f"Failed to connect to LLM server: {e}") from e
245
- except Exception as e:
246
- self.log.error(f"Error in send_messages: {e}")
247
- raise
248
-
249
- def send_messages_stream(
250
- self,
251
- messages: List[Dict[str, Any]],
252
- system_prompt: Optional[str] = None,
253
- **kwargs,
254
- ):
255
- """
256
- Send a full conversation history and get a streaming response.
257
-
258
- Args:
259
- messages: List of message dicts with 'role' and 'content' keys
260
- system_prompt: Optional system prompt to use (overrides config)
261
- **kwargs: Additional arguments for LLM generation
262
-
263
- Yields:
264
- ChatResponse chunks as they arrive
265
- """
266
- try:
267
- messages = self._prepare_messages_for_llm(messages)
268
-
269
- # Convert messages to chat history format
270
- chat_history = []
271
-
272
- for msg in messages:
273
- role = msg.get("role", "")
274
- content = self._normalize_message_content(msg.get("content", ""))
275
-
276
- if role == "user":
277
- chat_history.append(f"user: {content}")
278
- elif role == "assistant":
279
- chat_history.append(f"assistant: {content}")
280
- elif role == "tool":
281
- tool_name = msg.get("name", "tool")
282
- chat_history.append(f"assistant: [tool:{tool_name}] {content}")
283
- # Skip system messages since they're passed separately
284
-
285
- # Use provided system prompt or fall back to config default
286
- effective_system_prompt = system_prompt or self.config.system_prompt
287
-
288
- # Format according to model type
289
- formatted_prompt = Prompts.format_chat_history(
290
- model=self.config.model,
291
- chat_history=chat_history,
292
- assistant_name="assistant",
293
- system_prompt=effective_system_prompt,
294
- )
295
-
296
- # Debug logging
297
- self.log.debug(f"Formatted prompt length: {len(formatted_prompt)} chars")
298
- self.log.debug(
299
- f"System prompt used: {effective_system_prompt[:100] if effective_system_prompt else 'None'}..."
300
- )
301
-
302
- # Set appropriate stop tokens based on model
303
- model_lower = self.config.model.lower() if self.config.model else ""
304
- if "qwen" in model_lower:
305
- kwargs.setdefault("stop", ["<|im_end|>", "<|im_start|>"])
306
- elif "llama" in model_lower:
307
- kwargs.setdefault("stop", ["<|eot_id|>", "<|start_header_id|>"])
308
-
309
- # Use generate with formatted prompt for streaming
310
- full_response = ""
311
- for chunk in self.llm_client.generate(
312
- prompt=formatted_prompt, model=self.config.model, stream=True, **kwargs
313
- ):
314
- full_response += chunk
315
- yield ChatResponse(text=chunk, is_complete=False)
316
-
317
- # Send final response with stats
318
- # Always get stats for token tracking (show_stats controls display, not collection)
319
- stats = self.get_stats()
320
-
321
- yield ChatResponse(text="", stats=stats, is_complete=True)
322
-
323
- except ConnectionError as e:
324
- # Re-raise connection errors with additional context
325
- self.log.error(f"LLM connection error in send_messages_stream: {e}")
326
- raise ConnectionError(
327
- f"Failed to connect to LLM server (streaming): {e}"
328
- ) from e
329
- except Exception as e:
330
- self.log.error(f"Error in send_messages_stream: {e}")
331
- raise
332
-
333
- def send(self, message: str, *, no_history: bool = False, **kwargs) -> ChatResponse:
334
- """
335
- Send a message and get a complete response with conversation history.
336
-
337
- Args:
338
- message: The message to send
339
- no_history: When True, bypass stored chat history and send only this prompt
340
- **kwargs: Additional arguments for LLM generation
341
-
342
- Returns:
343
- ChatResponse with the complete response and updated history
344
- """
345
- try:
346
- if not message.strip():
347
- raise ValueError("Message cannot be empty")
348
-
349
- # Enhance message with RAG context if enabled
350
- enhanced_message, _rag_metadata = self._enhance_with_rag(message.strip())
351
-
352
- if no_history:
353
- # Build a prompt using only the current enhanced message
354
- full_prompt = Prompts.format_chat_history(
355
- model=self.config.model,
356
- chat_history=[f"user: {enhanced_message}"],
357
- assistant_name=self.config.assistant_name,
358
- system_prompt=self.config.system_prompt,
359
- )
360
- else:
361
- # Add user message to history (use original message for history)
362
- self.chat_history.append(f"user: {message.strip()}")
363
-
364
- # Prepare prompt with conversation context (use enhanced message for LLM)
365
- # Temporarily replace the last message with enhanced version for formatting
366
- if self.rag_enabled and enhanced_message != message.strip():
367
- # Save original and replace with enhanced version
368
- original_last = self.chat_history.pop()
369
- self.chat_history.append(f"user: {enhanced_message}")
370
- full_prompt = self._format_history_for_context()
371
- # Restore original for history
372
- self.chat_history.pop()
373
- self.chat_history.append(original_last)
374
- else:
375
- full_prompt = self._format_history_for_context()
376
-
377
- # Generate response
378
- generate_kwargs = dict(kwargs)
379
- if "max_tokens" not in generate_kwargs:
380
- generate_kwargs["max_tokens"] = self.config.max_tokens
381
-
382
- # Note: Retry logic is now handled at the LLM client level
383
- response = self.llm_client.generate(
384
- full_prompt,
385
- model=self.config.model,
386
- **generate_kwargs,
387
- )
388
-
389
- # Add assistant message to history when tracking conversation
390
- if not no_history:
391
- self.chat_history.append(f"{self.config.assistant_name}: {response}")
392
-
393
- # Prepare response data
394
- stats = None
395
- if self.config.show_stats:
396
- stats = self.get_stats()
397
-
398
- history = (
399
- list(self.chat_history)
400
- if kwargs.get("include_history", False)
401
- else None
402
- )
403
-
404
- return ChatResponse(
405
- text=response, history=history, stats=stats, is_complete=True
406
- )
407
-
408
- except Exception as e:
409
- self.log.error(f"Error in send: {e}")
410
- raise
411
-
412
- def send_stream(self, message: str, **kwargs):
413
- """
414
- Send a message and get a streaming response with conversation history.
415
-
416
- Args:
417
- message: The message to send
418
- **kwargs: Additional arguments for LLM generation
419
-
420
- Yields:
421
- ChatResponse chunks as they arrive
422
- """
423
- try:
424
- if not message.strip():
425
- raise ValueError("Message cannot be empty")
426
-
427
- # Enhance message with RAG context if enabled
428
- enhanced_message, _rag_metadata = self._enhance_with_rag(message.strip())
429
-
430
- # Add user message to history (use original message for history)
431
- self.chat_history.append(f"user: {message.strip()}")
432
-
433
- # Prepare prompt with conversation context (use enhanced message for LLM)
434
- # Temporarily replace the last message with enhanced version for formatting
435
- if self.rag_enabled and enhanced_message != message.strip():
436
- # Save original and replace with enhanced version
437
- original_last = self.chat_history.pop()
438
- self.chat_history.append(f"user: {enhanced_message}")
439
- full_prompt = self._format_history_for_context()
440
- # Restore original for history
441
- self.chat_history.pop()
442
- self.chat_history.append(original_last)
443
- else:
444
- full_prompt = self._format_history_for_context()
445
-
446
- # Generate streaming response
447
- generate_kwargs = dict(kwargs)
448
- if "max_tokens" not in generate_kwargs:
449
- generate_kwargs["max_tokens"] = self.config.max_tokens
450
-
451
- full_response = ""
452
- for chunk in self.llm_client.generate(
453
- full_prompt, model=self.config.model, stream=True, **generate_kwargs
454
- ):
455
- full_response += chunk
456
- yield ChatResponse(text=chunk, is_complete=False)
457
-
458
- # Add complete assistant message to history
459
- self.chat_history.append(f"{self.config.assistant_name}: {full_response}")
460
-
461
- # Send final response with stats and history if requested
462
- stats = None
463
- if self.config.show_stats:
464
- stats = self.get_stats()
465
-
466
- history = (
467
- list(self.chat_history)
468
- if kwargs.get("include_history", False)
469
- else None
470
- )
471
-
472
- yield ChatResponse(text="", history=history, stats=stats, is_complete=True)
473
-
474
- except Exception as e:
475
- self.log.error(f"Error in send_stream: {e}")
476
- raise
477
-
478
- def get_history(self) -> List[str]:
479
- """
480
- Get the current conversation history.
481
-
482
- Returns:
483
- List of conversation entries in "role: message" format
484
- """
485
- return list(self.chat_history)
486
-
487
- def clear_history(self) -> None:
488
- """Clear the conversation history."""
489
- self.chat_history.clear()
490
- self.log.debug("Chat history cleared")
491
-
492
- def get_formatted_history(self) -> List[Dict[str, str]]:
493
- """
494
- Get conversation history in structured format.
495
-
496
- Returns:
497
- List of dictionaries with 'role' and 'message' keys
498
- """
499
- formatted = []
500
- assistant_prefix = f"{self.config.assistant_name}: "
501
-
502
- for entry in self.chat_history:
503
- if entry.startswith("user: "):
504
- role, message = "user", entry[6:]
505
- formatted.append({"role": role, "message": message})
506
- elif entry.startswith(assistant_prefix):
507
- role, message = (
508
- self.config.assistant_name,
509
- entry[len(assistant_prefix) :],
510
- )
511
- formatted.append({"role": role, "message": message})
512
- elif ": " in entry:
513
- # Fallback for any other format
514
- role, message = entry.split(": ", 1)
515
- formatted.append({"role": role, "message": message})
516
- else:
517
- formatted.append({"role": "unknown", "message": entry})
518
- return formatted
519
-
520
- def get_stats(self) -> Dict[str, Any]:
521
- """
522
- Get performance statistics.
523
-
524
- Returns:
525
- Dictionary of performance stats
526
- """
527
- try:
528
- return self.llm_client.get_performance_stats() or {}
529
- except Exception as e:
530
- self.log.warning(f"Failed to get stats: {e}")
531
- return {}
532
-
533
- def get_system_prompt(self) -> Optional[str]:
534
- """
535
- Get the current system prompt.
536
-
537
- Returns:
538
- Current system prompt or None if not set
539
- """
540
- return self.config.system_prompt
541
-
542
- def set_system_prompt(self, system_prompt: Optional[str]) -> None:
543
- """
544
- Set the system prompt for future conversations.
545
-
546
- Args:
547
- system_prompt: New system prompt to use, or None to clear it
548
- """
549
- self.config.system_prompt = system_prompt
550
- self.log.debug(
551
- f"System prompt updated: {system_prompt[:100] if system_prompt else 'None'}..."
552
- )
553
-
554
- def display_stats(self, stats: Optional[Dict[str, Any]] = None) -> None:
555
- """
556
- Display performance statistics in a formatted way.
557
-
558
- Args:
559
- stats: Optional stats dictionary. If None, gets current stats.
560
- """
561
- if stats is None:
562
- stats = self.get_stats()
563
-
564
- if stats:
565
- print("\n" + "=" * 30)
566
- print("Performance Statistics:")
567
- print("=" * 30)
568
- for key, value in stats.items():
569
- if isinstance(value, float):
570
- if "time" in key.lower():
571
- print(f" {key}: {value:.3f}s")
572
- elif "tokens_per_second" in key.lower():
573
- print(f" {key}: {value:.2f} tokens/s")
574
- else:
575
- print(f" {key}: {value:.4f}")
576
- elif isinstance(value, int):
577
- if "tokens" in key.lower():
578
- print(f" {key}: {value:,} tokens")
579
- else:
580
- print(f" {key}: {value}")
581
- else:
582
- print(f" {key}: {value}")
583
- print("=" * 30)
584
- else:
585
- print("No statistics available.")
586
-
587
- async def start_interactive_session(self) -> None:
588
- """
589
- Start an interactive chat session with conversation history.
590
-
591
- This provides a full CLI-style interactive experience with commands
592
- for managing conversation history and viewing statistics.
593
- """
594
- print("=" * 50)
595
- print("Interactive Chat Session Started")
596
- print(f"Using model: {self.config.model}")
597
- print("Type 'quit', 'exit', or 'bye' to end the conversation")
598
- print("Commands:")
599
- print(" /clear - clear conversation history")
600
- print(" /history - show conversation history")
601
- print(" /stats - show performance statistics")
602
- print(" /help - show this help message")
603
- print("=" * 50)
604
-
605
- while True:
606
- try:
607
- user_input = input("\nYou: ").strip()
608
-
609
- if user_input.lower() in ["quit", "exit", "bye"]:
610
- print("\nGoodbye!")
611
- break
612
- elif user_input.lower() == "/clear":
613
- self.clear_history()
614
- print("Conversation history cleared.")
615
- continue
616
- elif user_input.lower() == "/history":
617
- history = self.get_formatted_history()
618
- if not history:
619
- print("No conversation history.")
620
- else:
621
- print("\n" + "=" * 30)
622
- print("Conversation History:")
623
- print("=" * 30)
624
- for entry in history:
625
- print(f"{entry['role'].title()}: {entry['message']}")
626
- print("=" * 30)
627
- continue
628
- elif user_input.lower() == "/stats":
629
- self.display_stats()
630
- continue
631
- elif user_input.lower() == "/help":
632
- print("\n" + "=" * 40)
633
- print("Available Commands:")
634
- print("=" * 40)
635
- print(" /clear - clear conversation history")
636
- print(" /history - show conversation history")
637
- print(" /stats - show performance statistics")
638
- print(" /help - show this help message")
639
- print("\nTo exit: type 'quit', 'exit', or 'bye'")
640
- print("=" * 40)
641
- continue
642
- elif not user_input:
643
- print("Please enter a message.")
644
- continue
645
-
646
- print(f"\n{self.config.assistant_name.title()}: ", end="", flush=True)
647
-
648
- # Generate and stream response
649
- for chunk in self.send_stream(user_input):
650
- if not chunk.is_complete:
651
- print(chunk.text, end="", flush=True)
652
- else:
653
- # Show stats if configured and available
654
- if self.config.show_stats and chunk.stats:
655
- self.display_stats(chunk.stats)
656
- print() # Add newline after response
657
-
658
- except KeyboardInterrupt:
659
- print("\n\nChat interrupted. Goodbye!")
660
- break
661
- except Exception as e:
662
- print(f"\nError: {e}")
663
- raise
664
-
665
- def update_config(self, **kwargs) -> None:
666
- """
667
- Update configuration dynamically.
668
-
669
- Args:
670
- **kwargs: Configuration parameters to update
671
- """
672
- # Update our config
673
- for key, value in kwargs.items():
674
- if hasattr(self.config, key):
675
- setattr(self.config, key, value)
676
-
677
- # Handle special cases
678
- if "max_history_length" in kwargs:
679
- # Create new deque with updated maxlen
680
- old_history = list(self.chat_history)
681
- new_maxlen = kwargs["max_history_length"] * 2
682
- self.chat_history = deque(old_history, maxlen=new_maxlen)
683
-
684
- if "system_prompt" in kwargs:
685
- # System prompt is handled through Prompts class, not directly
686
- pass
687
-
688
- if "assistant_name" in kwargs:
689
- # Assistant name change affects history display but not underlying storage
690
- # since we dynamically parse the history based on current assistant_name
691
- pass
692
-
693
- @property
694
- def history_length(self) -> int:
695
- """Get the current number of conversation entries."""
696
- return len(self.chat_history)
697
-
698
- @property
699
- def conversation_pairs(self) -> int:
700
- """Get the number of conversation pairs (user + assistant)."""
701
- return len(self.chat_history) // 2
702
-
703
- def enable_rag(self, documents: Optional[List[str]] = None, **rag_kwargs):
704
- """
705
- Enable RAG (Retrieval-Augmented Generation) for document-based chat.
706
-
707
- Args:
708
- documents: List of PDF file paths to index
709
- **rag_kwargs: Additional RAG configuration options
710
- """
711
- try:
712
- from gaia.rag.sdk import RAGSDK, RAGConfig
713
- except ImportError:
714
- raise ImportError(
715
- 'RAG dependencies not installed. Install with: uv pip install -e ".[rag]"'
716
- )
717
-
718
- # Create RAG config matching chat config
719
- rag_config = RAGConfig(
720
- model=self.config.model,
721
- show_stats=self.config.show_stats,
722
- use_local_llm=self.config.use_local_llm,
723
- **rag_kwargs,
724
- )
725
-
726
- self.rag = RAGSDK(rag_config)
727
- self.rag_enabled = True
728
-
729
- # Index documents if provided
730
- if documents:
731
- for doc_path in documents:
732
- self.log.info(f"Indexing document: {doc_path}")
733
- result = self.rag.index_document(doc_path)
734
-
735
- if result:
736
- self.log.info(f"Successfully indexed: {doc_path}")
737
- else:
738
- self.log.warning(f"Failed to index document: {doc_path}")
739
-
740
- self.log.info(
741
- f"RAG enabled with {len(documents) if documents else 0} documents"
742
- )
743
-
744
- def disable_rag(self):
745
- """Disable RAG functionality."""
746
- self.rag = None
747
- self.rag_enabled = False
748
- self.log.info("RAG disabled")
749
-
750
- def add_document(self, document_path: str) -> bool:
751
- """
752
- Add a document to the RAG index.
753
-
754
- Args:
755
- document_path: Path to PDF file to index
756
-
757
- Returns:
758
- True if indexing succeeded
759
- """
760
- if not self.rag_enabled or not self.rag:
761
- raise ValueError("RAG not enabled. Call enable_rag() first.")
762
-
763
- return self.rag.index_document(document_path)
764
-
765
- def _estimate_tokens(self, text: str) -> int:
766
- """
767
- Estimate the number of tokens in text.
768
- Uses rough approximation of 4 characters per token.
769
-
770
- Args:
771
- text: The text to estimate tokens for
772
-
773
- Returns:
774
- Estimated token count
775
- """
776
- # Rough approximation: ~4 characters per token for English text
777
- # This is conservative to avoid overrunning context
778
- return len(text) // 4
779
-
780
- def summarize_conversation_history(self, max_history_tokens: int) -> Optional[str]:
781
- """
782
- Summarize conversation history when it exceeds the token budget.
783
-
784
- Args:
785
- max_history_tokens: Maximum allowed tokens for stored history
786
-
787
- Returns:
788
- The generated summary (when summarization occurred) or None
789
- """
790
- if max_history_tokens <= 0:
791
- raise ValueError("max_history_tokens must be positive")
792
-
793
- history_entries = list(self.chat_history)
794
- if not history_entries:
795
- return None
796
-
797
- history_text = "\n".join(history_entries)
798
- history_tokens = self._estimate_tokens(history_text)
799
-
800
- if history_tokens <= max_history_tokens:
801
- print(
802
- "History tokens are less than max history tokens, so no summarization is needed"
803
- )
804
- return None
805
-
806
- print(
807
- "History tokens are greater than max history tokens, so summarization is needed"
808
- )
809
-
810
- self.log.info(
811
- "Conversation history (~%d tokens) exceeds budget (%d). Summarizing...",
812
- history_tokens,
813
- max_history_tokens,
814
- )
815
-
816
- summary_prompt = (
817
- "Summarize the following conversation between a user and the GAIA web "
818
- "development agent. Preserve:\n"
819
- "- The app requirements and inferred schema/data models\n"
820
- "- Key implementation details already completed\n"
821
- "- Outstanding issues, validation failures, or TODOs (quote error/warning text verbatim)\n"
822
- "- Any constraints or preferences the user emphasized\n\n"
823
- "Write the summary in under 400 tokens, using concise paragraphs, and include the exact text of any warnings/errors so future fixes have full context."
824
- )
825
- full_prompt = (
826
- f"{summary_prompt}\n\nConversation History:\n{history_text}\n\nSummary:"
827
- )
828
-
829
- try:
830
- summary = self.llm_client.generate(
831
- full_prompt,
832
- model=self.config.model,
833
- max_tokens=min(self.config.max_tokens, 2048),
834
- timeout=1200,
835
- )
836
- except Exception as exc: # pylint: disable=broad-exception-caught
837
- self.log.error("Failed to summarize conversation history: %s", exc)
838
- return None
839
-
840
- summary = summary.strip()
841
- if not summary:
842
- self.log.warning("Summarization returned empty content; keeping history.")
843
- return None
844
-
845
- self.chat_history.clear()
846
- self.chat_history.append(
847
- f"{self.config.assistant_name}: Conversation summary so far:\n{summary}"
848
- )
849
- return summary
850
-
851
- def _truncate_rag_context(self, context: str, max_tokens: int) -> str:
852
- """
853
- Truncate RAG context to fit within token budget.
854
-
855
- Args:
856
- context: The RAG context to truncate
857
- max_tokens: Maximum tokens allowed
858
-
859
- Returns:
860
- Truncated context with ellipsis if needed
861
- """
862
- estimated_tokens = self._estimate_tokens(context)
863
-
864
- if estimated_tokens <= max_tokens:
865
- return context
866
-
867
- # Calculate how many characters we can keep
868
- target_chars = max_tokens * 4 # Using same 4:1 ratio
869
-
870
- # Truncate and add ellipsis
871
- truncated = context[: target_chars - 20] # Leave room for ellipsis
872
- truncated += "\n... [context truncated for length]"
873
-
874
- self.log.warning(
875
- f"RAG context truncated from ~{estimated_tokens} to ~{max_tokens} tokens"
876
- )
877
- return truncated
878
-
879
- def _enhance_with_rag(self, message: str) -> tuple:
880
- """
881
- Enhance user message with relevant document context using RAG.
882
-
883
- Args:
884
- message: Original user message
885
-
886
- Returns:
887
- Tuple of (enhanced_message, metadata_dict)
888
- """
889
- if not self.rag_enabled or not self.rag:
890
- return message, None
891
-
892
- try:
893
- # Query RAG for relevant context with metadata
894
- rag_response = self.rag.query(message, include_metadata=True)
895
-
896
- if rag_response.chunks:
897
- # Build context with source information
898
- context_parts = []
899
- if rag_response.chunk_metadata:
900
- for i, (chunk, metadata) in enumerate(
901
- zip(rag_response.chunks, rag_response.chunk_metadata)
902
- ):
903
- context_parts.append(
904
- f"Context {i+1} (from {metadata['source_file']}, relevance: {metadata['relevance_score']:.2f}):\n{chunk}"
905
- )
906
- else:
907
- context_parts = [
908
- f"Context {i+1}:\n{chunk}"
909
- for i, chunk in enumerate(rag_response.chunks)
910
- ]
911
-
912
- context = "\n\n".join(context_parts)
913
-
914
- # Check token limits
915
- message_tokens = self._estimate_tokens(message)
916
- template_tokens = 150 # Template text overhead
917
- response_tokens = self.config.max_tokens
918
- history_tokens = self._estimate_tokens(str(self.chat_history))
919
-
920
- # Conservative context size for models
921
- model_context_size = 32768
922
- available_for_rag = (
923
- model_context_size
924
- - message_tokens
925
- - template_tokens
926
- - response_tokens
927
- - history_tokens
928
- )
929
-
930
- # Ensure minimum RAG context
931
- if available_for_rag < 500:
932
- self.log.warning(
933
- f"Limited space for RAG context: {available_for_rag} tokens"
934
- )
935
- available_for_rag = 500
936
-
937
- # Truncate context if needed
938
- context = self._truncate_rag_context(context, available_for_rag)
939
-
940
- # Build enhanced message
941
- enhanced_message = f"""Based on the provided documents, please answer the following question. Use the context below to inform your response.
942
-
943
- Context from documents:
944
- {context}
945
-
946
- User question: {message}
947
-
948
- Note: When citing information, please mention which context number it came from."""
949
-
950
- # Prepare metadata for return
951
- metadata = {
952
- "rag_used": True,
953
- "chunks_retrieved": len(rag_response.chunks),
954
- "estimated_context_tokens": self._estimate_tokens(context),
955
- "available_tokens": available_for_rag,
956
- "context_truncated": (
957
- len(context) < sum(len(c) for c in rag_response.chunks)
958
- if rag_response.chunks
959
- else False
960
- ),
961
- }
962
-
963
- # Add query metadata if available
964
- if rag_response.query_metadata:
965
- metadata["query_metadata"] = rag_response.query_metadata
966
-
967
- self.log.debug(
968
- f"Enhanced message with {len(rag_response.chunks)} chunks from "
969
- f"{len(set(rag_response.source_files)) if rag_response.source_files else 0} documents, "
970
- f"~{metadata['estimated_context_tokens']} context tokens"
971
- )
972
- return enhanced_message, metadata
973
- else:
974
- self.log.debug("No relevant document context found")
975
- return message, {"rag_used": True, "chunks_retrieved": 0}
976
-
977
- except Exception as e:
978
- self.log.warning(
979
- f"RAG enhancement failed: {e}, falling back to direct query"
980
- )
981
- return message, {"rag_used": False, "error": str(e)}
982
-
983
-
984
- class SimpleChat:
985
- """
986
- Ultra-simple interface for quick chat integration.
987
-
988
- Example usage:
989
- ```python
990
- from gaia.chat.sdk import SimpleChat
991
-
992
- chat = SimpleChat()
993
-
994
- # Simple question-answer
995
- response = await chat.ask("What's the weather like?")
996
- print(response)
997
-
998
- # Chat with memory
999
- response1 = await chat.ask("My name is John")
1000
- response2 = await chat.ask("What's my name?") # Remembers previous context
1001
- ```
1002
- """
1003
-
1004
- def __init__(
1005
- self,
1006
- system_prompt: Optional[str] = None,
1007
- model: Optional[str] = None,
1008
- assistant_name: Optional[str] = None,
1009
- ):
1010
- """
1011
- Initialize SimpleChat with minimal configuration.
1012
-
1013
- Args:
1014
- system_prompt: Optional system prompt for the AI
1015
- model: Model to use (defaults to DEFAULT_MODEL_NAME)
1016
- assistant_name: Name to use for the assistant (defaults to "assistant")
1017
- """
1018
- config = ChatConfig(
1019
- model=model or DEFAULT_MODEL_NAME,
1020
- system_prompt=system_prompt,
1021
- assistant_name=assistant_name or "gaia",
1022
- show_stats=False,
1023
- logging_level="WARNING", # Minimal logging
1024
- )
1025
- self._sdk = ChatSDK(config)
1026
-
1027
- def ask(self, question: str) -> str:
1028
- """
1029
- Ask a question and get a text response with conversation memory.
1030
-
1031
- Args:
1032
- question: The question to ask
1033
-
1034
- Returns:
1035
- The AI's response as a string
1036
- """
1037
- response = self._sdk.send(question)
1038
- return response.text
1039
-
1040
- def ask_stream(self, question: str):
1041
- """
1042
- Ask a question and get a streaming response with conversation memory.
1043
-
1044
- Args:
1045
- question: The question to ask
1046
-
1047
- Yields:
1048
- Response chunks as they arrive
1049
- """
1050
- for chunk in self._sdk.send_stream(question):
1051
- if not chunk.is_complete:
1052
- yield chunk.text
1053
-
1054
- def clear_memory(self) -> None:
1055
- """Clear the conversation memory."""
1056
- self._sdk.clear_history()
1057
-
1058
- def get_conversation(self) -> List[Dict[str, str]]:
1059
- """Get the conversation history in a readable format."""
1060
- return self._sdk.get_formatted_history()
1061
-
1062
-
1063
- class ChatSession:
1064
- """
1065
- Session-based chat interface for managing multiple separate conversations.
1066
-
1067
- Example usage:
1068
- ```python
1069
- from gaia.chat.sdk import ChatSession
1070
-
1071
- # Create session manager
1072
- sessions = ChatSession()
1073
-
1074
- # Create different chat sessions
1075
- work_chat = sessions.create_session("work", system_prompt="You are a professional assistant")
1076
- personal_chat = sessions.create_session("personal", system_prompt="You are a friendly companion")
1077
-
1078
- # Chat in different contexts
1079
- work_response = await work_chat.ask("Draft an email to my team")
1080
- personal_response = await personal_chat.ask("What's a good recipe for dinner?")
1081
- ```
1082
- """
1083
-
1084
- def __init__(self, default_config: Optional[ChatConfig] = None):
1085
- """Initialize the session manager."""
1086
- self.default_config = default_config or ChatConfig()
1087
- self.sessions: Dict[str, ChatSDK] = {}
1088
- self.log = get_logger(__name__)
1089
-
1090
- def create_session(
1091
- self, session_id: str, config: Optional[ChatConfig] = None, **config_kwargs
1092
- ) -> ChatSDK:
1093
- """
1094
- Create a new chat session.
1095
-
1096
- Args:
1097
- session_id: Unique identifier for the session
1098
- config: Optional configuration (uses default if not provided)
1099
- **config_kwargs: Configuration parameters to override
1100
-
1101
- Returns:
1102
- ChatSDK instance for the session
1103
- """
1104
- if config is None:
1105
- # Create config from defaults with overrides
1106
- config = ChatConfig(
1107
- model=config_kwargs.get("model", self.default_config.model),
1108
- max_tokens=config_kwargs.get(
1109
- "max_tokens", self.default_config.max_tokens
1110
- ),
1111
- system_prompt=config_kwargs.get(
1112
- "system_prompt", self.default_config.system_prompt
1113
- ),
1114
- max_history_length=config_kwargs.get(
1115
- "max_history_length", self.default_config.max_history_length
1116
- ),
1117
- show_stats=config_kwargs.get(
1118
- "show_stats", self.default_config.show_stats
1119
- ),
1120
- logging_level=config_kwargs.get(
1121
- "logging_level", self.default_config.logging_level
1122
- ),
1123
- use_claude=config_kwargs.get(
1124
- "use_claude", self.default_config.use_claude
1125
- ),
1126
- use_chatgpt=config_kwargs.get(
1127
- "use_chatgpt", self.default_config.use_chatgpt
1128
- ),
1129
- assistant_name=config_kwargs.get(
1130
- "assistant_name", self.default_config.assistant_name
1131
- ),
1132
- )
1133
-
1134
- session = ChatSDK(config)
1135
- self.sessions[session_id] = session
1136
- self.log.debug(f"Created chat session: {session_id}")
1137
- return session
1138
-
1139
- def get_session(self, session_id: str) -> Optional[ChatSDK]:
1140
- """Get an existing session by ID."""
1141
- return self.sessions.get(session_id)
1142
-
1143
- def delete_session(self, session_id: str) -> bool:
1144
- """Delete a session."""
1145
- if session_id in self.sessions:
1146
- del self.sessions[session_id]
1147
- self.log.debug(f"Deleted chat session: {session_id}")
1148
- return True
1149
- return False
1150
-
1151
- def list_sessions(self) -> List[str]:
1152
- """List all active session IDs."""
1153
- return list(self.sessions.keys())
1154
-
1155
- def clear_all_sessions(self) -> None:
1156
- """Clear all sessions."""
1157
- self.sessions.clear()
1158
- self.log.debug("Cleared all chat sessions")
1159
-
1160
-
1161
- # Convenience functions for one-off usage
1162
- def quick_chat(
1163
- message: str,
1164
- system_prompt: Optional[str] = None,
1165
- model: Optional[str] = None,
1166
- assistant_name: Optional[str] = None,
1167
- ) -> str:
1168
- """
1169
- Quick one-off text chat without conversation memory.
1170
-
1171
- Args:
1172
- message: Message to send
1173
- system_prompt: Optional system prompt
1174
- model: Optional model to use
1175
- assistant_name: Name to use for the assistant
1176
-
1177
- Returns:
1178
- AI response
1179
- """
1180
- config = ChatConfig(
1181
- model=model or DEFAULT_MODEL_NAME,
1182
- system_prompt=system_prompt,
1183
- assistant_name=assistant_name or "gaia",
1184
- show_stats=False,
1185
- logging_level="WARNING",
1186
- max_history_length=2, # Small history for quick chat
1187
- )
1188
- sdk = ChatSDK(config)
1189
- response = sdk.send(message)
1190
- return response.text
1191
-
1192
-
1193
- def quick_chat_with_memory(
1194
- messages: List[str],
1195
- system_prompt: Optional[str] = None,
1196
- model: Optional[str] = None,
1197
- assistant_name: Optional[str] = None,
1198
- ) -> List[str]:
1199
- """
1200
- Quick multi-turn chat with conversation memory.
1201
-
1202
- Args:
1203
- messages: List of messages to send sequentially
1204
- system_prompt: Optional system prompt
1205
- model: Optional model to use
1206
- assistant_name: Name to use for the assistant
1207
-
1208
- Returns:
1209
- List of AI responses
1210
- """
1211
- config = ChatConfig(
1212
- model=model or DEFAULT_MODEL_NAME,
1213
- system_prompt=system_prompt,
1214
- assistant_name=assistant_name or "gaia",
1215
- show_stats=False,
1216
- logging_level="WARNING",
1217
- )
1218
- sdk = ChatSDK(config)
1219
-
1220
- responses = []
1221
- for message in messages:
1222
- response = sdk.send(message)
1223
- responses.append(response.text)
1224
-
1225
- return responses
1
+ #!/usr/bin/env python3
2
+ # Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ """
6
+ Gaia Chat SDK - Unified text chat integration with conversation history
7
+ """
8
+
9
+ import json
10
+ import logging
11
+ from collections import deque
12
+ from dataclasses import dataclass
13
+ from typing import Any, Dict, List, Optional
14
+
15
+ from gaia.chat.prompts import Prompts
16
+ from gaia.llm import create_client
17
+ from gaia.llm.lemonade_client import DEFAULT_MODEL_NAME
18
+ from gaia.logger import get_logger
19
+
20
+
21
+ @dataclass
22
+ class ChatConfig:
23
+ """Configuration for ChatSDK."""
24
+
25
+ model: str = DEFAULT_MODEL_NAME
26
+ max_tokens: int = 512
27
+ system_prompt: Optional[str] = None
28
+ max_history_length: int = 4 # Number of conversation pairs to keep
29
+ show_stats: bool = False
30
+ logging_level: str = "INFO"
31
+ use_claude: bool = False # Use Claude API
32
+ use_chatgpt: bool = False # Use ChatGPT/OpenAI API
33
+ use_local_llm: bool = (
34
+ True # Use local LLM (computed as not use_claude and not use_chatgpt if not explicitly set)
35
+ )
36
+ claude_model: str = "claude-sonnet-4-20250514" # Claude model when use_claude=True
37
+ base_url: str = "http://localhost:8000/api/v1" # Lemonade server base URL
38
+ assistant_name: str = "gaia" # Name to use for the assistant in conversations
39
+
40
+
41
+ @dataclass
42
+ class ChatResponse:
43
+ """Response from chat operations."""
44
+
45
+ text: str
46
+ history: Optional[List[str]] = None
47
+ stats: Optional[Dict[str, Any]] = None
48
+ is_complete: bool = True
49
+
50
+
51
+ class ChatSDK:
52
+ """
53
+ Gaia Chat SDK - Unified text chat integration with conversation history.
54
+
55
+ This SDK provides a simple interface for integrating Gaia's text chat
56
+ capabilities with conversation memory into applications.
57
+
58
+ Example usage:
59
+ ```python
60
+ from gaia.chat.sdk import ChatSDK, ChatConfig
61
+
62
+ # Create SDK instance
63
+ config = ChatConfig(model=DEFAULT_MODEL_NAME, show_stats=True)
64
+ chat = ChatSDK(config)
65
+
66
+ # Single message
67
+ response = await chat.send("Hello, how are you?")
68
+ print(response.text)
69
+
70
+ # Streaming chat
71
+ async for chunk in chat.send_stream("Tell me a story"):
72
+ print(chunk.text, end="", flush=True)
73
+
74
+ # Get conversation history
75
+ history = chat.get_history()
76
+ ```
77
+ """
78
+
79
+ def __init__(self, config: Optional[ChatConfig] = None):
80
+ """
81
+ Initialize the ChatSDK.
82
+
83
+ Args:
84
+ config: Configuration options. If None, uses defaults.
85
+ """
86
+ self.config = config or ChatConfig()
87
+ self.log = get_logger(__name__)
88
+ self.log.setLevel(getattr(logging, self.config.logging_level))
89
+
90
+ # Initialize LLM client - factory auto-detects provider and validates
91
+ self.llm_client = create_client(
92
+ use_claude=self.config.use_claude,
93
+ use_openai=self.config.use_chatgpt,
94
+ model=(
95
+ self.config.claude_model
96
+ if self.config.use_claude
97
+ else self.config.model
98
+ ),
99
+ base_url=self.config.base_url,
100
+ system_prompt=self.config.system_prompt,
101
+ )
102
+
103
+ # Store conversation history
104
+ self.chat_history = deque(maxlen=self.config.max_history_length * 2)
105
+
106
+ # RAG support
107
+ self.rag = None
108
+ self.rag_enabled = False
109
+
110
+ self.log.debug("ChatSDK initialized")
111
+
112
+ def _format_history_for_context(self) -> str:
113
+ """Format chat history for inclusion in LLM context using model-specific formatting."""
114
+ history_list = list(self.chat_history)
115
+ return Prompts.format_chat_history(
116
+ self.config.model,
117
+ history_list,
118
+ self.config.assistant_name,
119
+ self.config.system_prompt,
120
+ )
121
+
122
+ def _normalize_message_content(self, content: Any) -> str:
123
+ """
124
+ Convert message content into a string for prompt construction, handling structured payloads.
125
+ """
126
+ if isinstance(content, str):
127
+ return content
128
+ if isinstance(content, list):
129
+ parts = []
130
+ for entry in content:
131
+ if isinstance(entry, dict):
132
+ if entry.get("type") == "text":
133
+ parts.append(entry.get("text", ""))
134
+ else:
135
+ parts.append(json.dumps(entry))
136
+ else:
137
+ parts.append(str(entry))
138
+ return "\n".join(part for part in parts if part)
139
+ if isinstance(content, dict):
140
+ return json.dumps(content)
141
+ return str(content)
142
+
143
+ def _prepare_messages_for_llm(
144
+ self, messages: List[Dict[str, Any]]
145
+ ) -> List[Dict[str, Any]]:
146
+ """
147
+ Ensure messages are safe to send to the LLM by appending a continuation
148
+ prompt when the last entry is a tool_result, which some models ignore.
149
+ """
150
+ if not messages:
151
+ return []
152
+
153
+ prepared = list(messages)
154
+ try:
155
+ last_role = prepared[-1].get("role")
156
+ except Exception:
157
+ return prepared
158
+
159
+ if last_role == "tool":
160
+ prepared.append({"role": "user", "content": "continue"})
161
+
162
+ return prepared
163
+
164
+ def send_messages(
165
+ self,
166
+ messages: List[Dict[str, Any]],
167
+ system_prompt: Optional[str] = None,
168
+ **kwargs,
169
+ ) -> ChatResponse:
170
+ """
171
+ Send a full conversation history and get a response.
172
+
173
+ Args:
174
+ messages: List of message dicts with 'role' and 'content' keys
175
+ system_prompt: Optional system prompt to use (overrides config)
176
+ **kwargs: Additional arguments for LLM generation
177
+
178
+ Returns:
179
+ ChatResponse with the complete response
180
+ """
181
+ try:
182
+ messages = self._prepare_messages_for_llm(messages)
183
+
184
+ # Convert messages to chat history format
185
+ chat_history = []
186
+
187
+ for msg in messages:
188
+ role = msg.get("role", "")
189
+ content = self._normalize_message_content(msg.get("content", ""))
190
+
191
+ if role == "user":
192
+ chat_history.append(f"user: {content}")
193
+ elif role == "assistant":
194
+ chat_history.append(f"assistant: {content}")
195
+ elif role == "tool":
196
+ tool_name = msg.get("name", "tool")
197
+ chat_history.append(f"assistant: [tool:{tool_name}] {content}")
198
+ # Skip system messages since they're passed separately
199
+
200
+ # Use provided system prompt or fall back to config default
201
+ effective_system_prompt = system_prompt or self.config.system_prompt
202
+
203
+ # Format according to model type
204
+ formatted_prompt = Prompts.format_chat_history(
205
+ model=self.config.model,
206
+ chat_history=chat_history,
207
+ assistant_name="assistant",
208
+ system_prompt=effective_system_prompt,
209
+ )
210
+
211
+ # Debug logging
212
+ self.log.debug(f"Formatted prompt length: {len(formatted_prompt)} chars")
213
+ self.log.debug(
214
+ f"System prompt used: {effective_system_prompt[:100] if effective_system_prompt else 'None'}..."
215
+ )
216
+
217
+ # Set appropriate stop tokens based on model
218
+ model_lower = self.config.model.lower() if self.config.model else ""
219
+ if "qwen" in model_lower:
220
+ kwargs.setdefault("stop", ["<|im_end|>", "<|im_start|>"])
221
+ elif "llama" in model_lower:
222
+ kwargs.setdefault("stop", ["<|eot_id|>", "<|start_header_id|>"])
223
+
224
+ # Use generate with formatted prompt
225
+ response = self.llm_client.generate(
226
+ prompt=formatted_prompt,
227
+ model=self.config.model,
228
+ stream=False,
229
+ **kwargs,
230
+ )
231
+
232
+ # Prepare response data
233
+ stats = None
234
+ if self.config.show_stats:
235
+ stats = self.get_stats()
236
+
237
+ return ChatResponse(text=response, stats=stats, is_complete=True)
238
+
239
+ except ConnectionError as e:
240
+ # Re-raise connection errors with additional context
241
+ self.log.error(f"LLM connection error in send_messages: {e}")
242
+ raise ConnectionError(f"Failed to connect to LLM server: {e}") from e
243
+ except Exception as e:
244
+ self.log.error(f"Error in send_messages: {e}")
245
+ raise
246
+
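A hedged usage sketch for `send_messages`; the import path follows the docstrings elsewhere in this file, and the model name and message contents are illustrative:

```python
from gaia.chat.sdk import ChatSDK, ChatConfig  # path as used in this file's docstrings

sdk = ChatSDK(ChatConfig(model="Llama-3.2-3B-Instruct"))  # illustrative model name

history = [
    {"role": "user", "content": "Remember that my favorite color is green."},
    {"role": "assistant", "content": "Noted: your favorite color is green."},
    {"role": "user", "content": "What color should I paint my office?"},
]

# The system_prompt argument overrides config.system_prompt for this call only.
response = sdk.send_messages(
    history, system_prompt="You are a terse interior-design assistant."
)
print(response.text)
print(response.is_complete)  # True for non-streaming calls
```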
247
+ def send_messages_stream(
248
+ self,
249
+ messages: List[Dict[str, Any]],
250
+ system_prompt: Optional[str] = None,
251
+ **kwargs,
252
+ ):
253
+ """
254
+ Send a full conversation history and get a streaming response.
255
+
256
+ Args:
257
+ messages: List of message dicts with 'role' and 'content' keys
258
+ system_prompt: Optional system prompt to use (overrides config)
259
+ **kwargs: Additional arguments for LLM generation
260
+
261
+ Yields:
262
+ ChatResponse chunks as they arrive
263
+ """
264
+ try:
265
+ messages = self._prepare_messages_for_llm(messages)
266
+
267
+ # Convert messages to chat history format
268
+ chat_history = []
269
+
270
+ for msg in messages:
271
+ role = msg.get("role", "")
272
+ content = self._normalize_message_content(msg.get("content", ""))
273
+
274
+ if role == "user":
275
+ chat_history.append(f"user: {content}")
276
+ elif role == "assistant":
277
+ chat_history.append(f"assistant: {content}")
278
+ elif role == "tool":
279
+ tool_name = msg.get("name", "tool")
280
+ chat_history.append(f"assistant: [tool:{tool_name}] {content}")
281
+ # Skip system messages since they're passed separately
282
+
283
+ # Use provided system prompt or fall back to config default
284
+ effective_system_prompt = system_prompt or self.config.system_prompt
285
+
286
+ # Format according to model type
287
+ formatted_prompt = Prompts.format_chat_history(
288
+ model=self.config.model,
289
+ chat_history=chat_history,
290
+ assistant_name="assistant",
291
+ system_prompt=effective_system_prompt,
292
+ )
293
+
294
+ # Debug logging
295
+ self.log.debug(f"Formatted prompt length: {len(formatted_prompt)} chars")
296
+ self.log.debug(
297
+ f"System prompt used: {effective_system_prompt[:100] if effective_system_prompt else 'None'}..."
298
+ )
299
+
300
+ # Set appropriate stop tokens based on model
301
+ model_lower = self.config.model.lower() if self.config.model else ""
302
+ if "qwen" in model_lower:
303
+ kwargs.setdefault("stop", ["<|im_end|>", "<|im_start|>"])
304
+ elif "llama" in model_lower:
305
+ kwargs.setdefault("stop", ["<|eot_id|>", "<|start_header_id|>"])
306
+
307
+ # Use generate with formatted prompt for streaming
308
+ full_response = ""
309
+ for chunk in self.llm_client.generate(
310
+ prompt=formatted_prompt, model=self.config.model, stream=True, **kwargs
311
+ ):
312
+ full_response += chunk
313
+ yield ChatResponse(text=chunk, is_complete=False)
314
+
315
+ # Send final response with stats
316
+ # Always get stats for token tracking (show_stats controls display, not collection)
317
+ stats = self.get_stats()
318
+
319
+ yield ChatResponse(text="", stats=stats, is_complete=True)
320
+
321
+ except ConnectionError as e:
322
+ # Re-raise connection errors with additional context
323
+ self.log.error(f"LLM connection error in send_messages_stream: {e}")
324
+ raise ConnectionError(
325
+ f"Failed to connect to LLM server (streaming): {e}"
326
+ ) from e
327
+ except Exception as e:
328
+ self.log.error(f"Error in send_messages_stream: {e}")
329
+ raise
330
+
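Continuing the sketch above, the streaming variant yields partial chunks and then a final empty chunk that carries the collected stats:

```python
# Chunks arrive with is_complete=False; the final chunk has empty text,
# is_complete=True, and the performance stats gathered by the LLM client.
pieces, final_stats = [], None
for chunk in sdk.send_messages_stream(history, system_prompt="Answer in one sentence."):
    if chunk.is_complete:
        final_stats = chunk.stats
    else:
        pieces.append(chunk.text)

print("".join(pieces))
print(final_stats)
```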
331
+ def send(self, message: str, *, no_history: bool = False, **kwargs) -> ChatResponse:
332
+ """
333
+ Send a message and get a complete response with conversation history.
334
+
335
+ Args:
336
+ message: The message to send
337
+ no_history: When True, bypass stored chat history and send only this prompt
338
+ **kwargs: Additional arguments for LLM generation
339
+
340
+ Returns:
341
+ ChatResponse with the complete response and updated history
342
+ """
343
+ try:
344
+ if not message.strip():
345
+ raise ValueError("Message cannot be empty")
346
+
347
+ # Enhance message with RAG context if enabled
348
+ enhanced_message, _rag_metadata = self._enhance_with_rag(message.strip())
349
+
350
+ if no_history:
351
+ # Build a prompt using only the current enhanced message
352
+ full_prompt = Prompts.format_chat_history(
353
+ model=self.config.model,
354
+ chat_history=[f"user: {enhanced_message}"],
355
+ assistant_name=self.config.assistant_name,
356
+ system_prompt=self.config.system_prompt,
357
+ )
358
+ else:
359
+ # Add user message to history (use original message for history)
360
+ self.chat_history.append(f"user: {message.strip()}")
361
+
362
+ # Prepare prompt with conversation context (use enhanced message for LLM)
363
+ # Temporarily replace the last message with enhanced version for formatting
364
+ if self.rag_enabled and enhanced_message != message.strip():
365
+ # Save original and replace with enhanced version
366
+ original_last = self.chat_history.pop()
367
+ self.chat_history.append(f"user: {enhanced_message}")
368
+ full_prompt = self._format_history_for_context()
369
+ # Restore original for history
370
+ self.chat_history.pop()
371
+ self.chat_history.append(original_last)
372
+ else:
373
+ full_prompt = self._format_history_for_context()
374
+
375
+ # Generate response
376
+ generate_kwargs = dict(kwargs)
377
+ if "max_tokens" not in generate_kwargs:
378
+ generate_kwargs["max_tokens"] = self.config.max_tokens
379
+
380
+ # Note: Retry logic is now handled at the LLM client level
381
+ response = self.llm_client.generate(
382
+ full_prompt,
383
+ model=self.config.model,
384
+ **generate_kwargs,
385
+ )
386
+
387
+ # Add assistant message to history when tracking conversation
388
+ if not no_history:
389
+ self.chat_history.append(f"{self.config.assistant_name}: {response}")
390
+
391
+ # Prepare response data
392
+ stats = None
393
+ if self.config.show_stats:
394
+ stats = self.get_stats()
395
+
396
+ history = (
397
+ list(self.chat_history)
398
+ if kwargs.get("include_history", False)
399
+ else None
400
+ )
401
+
402
+ return ChatResponse(
403
+ text=response, history=history, stats=stats, is_complete=True
404
+ )
405
+
406
+ except Exception as e:
407
+ self.log.error(f"Error in send: {e}")
408
+ raise
409
+
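A sketch of stateful versus stateless use of `send`, reusing the same assumed `sdk` instance (prompts are illustrative):

```python
# send() records both sides of the exchange in chat history by default,
sdk.send("My name is Ada.")
reply = sdk.send("What is my name?")  # earlier turns are included in the prompt
print(reply.text)

# while no_history=True sends only the current prompt and records nothing.
one_off = sdk.send("Translate 'hello' to French.", no_history=True)
print(one_off.text)
print(sdk.history_length)  # unchanged by the no_history call
```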
410
+ def send_stream(self, message: str, **kwargs):
411
+ """
412
+ Send a message and get a streaming response with conversation history.
413
+
414
+ Args:
415
+ message: The message to send
416
+ **kwargs: Additional arguments for LLM generation
417
+
418
+ Yields:
419
+ ChatResponse chunks as they arrive
420
+ """
421
+ try:
422
+ if not message.strip():
423
+ raise ValueError("Message cannot be empty")
424
+
425
+ # Enhance message with RAG context if enabled
426
+ enhanced_message, _rag_metadata = self._enhance_with_rag(message.strip())
427
+
428
+ # Add user message to history (use original message for history)
429
+ self.chat_history.append(f"user: {message.strip()}")
430
+
431
+ # Prepare prompt with conversation context (use enhanced message for LLM)
432
+ # Temporarily replace the last message with enhanced version for formatting
433
+ if self.rag_enabled and enhanced_message != message.strip():
434
+ # Save original and replace with enhanced version
435
+ original_last = self.chat_history.pop()
436
+ self.chat_history.append(f"user: {enhanced_message}")
437
+ full_prompt = self._format_history_for_context()
438
+ # Restore original for history
439
+ self.chat_history.pop()
440
+ self.chat_history.append(original_last)
441
+ else:
442
+ full_prompt = self._format_history_for_context()
443
+
444
+ # Generate streaming response
445
+ generate_kwargs = dict(kwargs)
446
+ if "max_tokens" not in generate_kwargs:
447
+ generate_kwargs["max_tokens"] = self.config.max_tokens
448
+
449
+ full_response = ""
450
+ for chunk in self.llm_client.generate(
451
+ full_prompt, model=self.config.model, stream=True, **generate_kwargs
452
+ ):
453
+ full_response += chunk
454
+ yield ChatResponse(text=chunk, is_complete=False)
455
+
456
+ # Add complete assistant message to history
457
+ self.chat_history.append(f"{self.config.assistant_name}: {full_response}")
458
+
459
+ # Send final response with stats and history if requested
460
+ stats = None
461
+ if self.config.show_stats:
462
+ stats = self.get_stats()
463
+
464
+ history = (
465
+ list(self.chat_history)
466
+ if kwargs.get("include_history", False)
467
+ else None
468
+ )
469
+
470
+ yield ChatResponse(text="", history=history, stats=stats, is_complete=True)
471
+
472
+ except Exception as e:
473
+ self.log.error(f"Error in send_stream: {e}")
474
+ raise
475
+
476
+ def get_history(self) -> List[str]:
477
+ """
478
+ Get the current conversation history.
479
+
480
+ Returns:
481
+ List of conversation entries in "role: message" format
482
+ """
483
+ return list(self.chat_history)
484
+
485
+ def clear_history(self) -> None:
486
+ """Clear the conversation history."""
487
+ self.chat_history.clear()
488
+ self.log.debug("Chat history cleared")
489
+
490
+ def get_formatted_history(self) -> List[Dict[str, str]]:
491
+ """
492
+ Get conversation history in structured format.
493
+
494
+ Returns:
495
+ List of dictionaries with 'role' and 'message' keys
496
+ """
497
+ formatted = []
498
+ assistant_prefix = f"{self.config.assistant_name}: "
499
+
500
+ for entry in self.chat_history:
501
+ if entry.startswith("user: "):
502
+ role, message = "user", entry[6:]
503
+ formatted.append({"role": role, "message": message})
504
+ elif entry.startswith(assistant_prefix):
505
+ role, message = (
506
+ self.config.assistant_name,
507
+ entry[len(assistant_prefix) :],
508
+ )
509
+ formatted.append({"role": role, "message": message})
510
+ elif ": " in entry:
511
+ # Fallback for any other format
512
+ role, message = entry.split(": ", 1)
513
+ formatted.append({"role": role, "message": message})
514
+ else:
515
+ formatted.append({"role": "unknown", "message": entry})
516
+ return formatted
517
+
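To make the storage format concrete: history is kept as flat `role: message` strings and parsed back into dictionaries on demand. The reply text and assistant name below are illustrative:

```python
sdk.clear_history()
sdk.send("Hello there!")

print(sdk.get_history())
# e.g. ["user: Hello there!", "gaia: Hi! How can I help?"]

print(sdk.get_formatted_history())
# e.g. [{"role": "user", "message": "Hello there!"},
#       {"role": "gaia", "message": "Hi! How can I help?"}]
```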
518
+ def get_stats(self) -> Dict[str, Any]:
519
+ """
520
+ Get performance statistics.
521
+
522
+ Returns:
523
+ Dictionary of performance stats
524
+ """
525
+ try:
526
+ return self.llm_client.get_performance_stats() or {}
527
+ except Exception as e:
528
+ self.log.warning(f"Failed to get stats: {e}")
529
+ return {}
530
+
531
+ def get_system_prompt(self) -> Optional[str]:
532
+ """
533
+ Get the current system prompt.
534
+
535
+ Returns:
536
+ Current system prompt or None if not set
537
+ """
538
+ return self.config.system_prompt
539
+
540
+ def set_system_prompt(self, system_prompt: Optional[str]) -> None:
541
+ """
542
+ Set the system prompt for future conversations.
543
+
544
+ Args:
545
+ system_prompt: New system prompt to use, or None to clear it
546
+ """
547
+ self.config.system_prompt = system_prompt
548
+ self.log.debug(
549
+ f"System prompt updated: {system_prompt[:100] if system_prompt else 'None'}..."
550
+ )
551
+
552
+ def display_stats(self, stats: Optional[Dict[str, Any]] = None) -> None:
553
+ """
554
+ Display performance statistics in a formatted way.
555
+
556
+ Args:
557
+ stats: Optional stats dictionary. If None, gets current stats.
558
+ """
559
+ if stats is None:
560
+ stats = self.get_stats()
561
+
562
+ if stats:
563
+ print("\n" + "=" * 30)
564
+ print("Performance Statistics:")
565
+ print("=" * 30)
566
+ for key, value in stats.items():
567
+ if isinstance(value, float):
568
+ if "time" in key.lower():
569
+ print(f" {key}: {value:.3f}s")
570
+ elif "tokens_per_second" in key.lower():
571
+ print(f" {key}: {value:.2f} tokens/s")
572
+ else:
573
+ print(f" {key}: {value:.4f}")
574
+ elif isinstance(value, int):
575
+ if "tokens" in key.lower():
576
+ print(f" {key}: {value:,} tokens")
577
+ else:
578
+ print(f" {key}: {value}")
579
+ else:
580
+ print(f" {key}: {value}")
581
+ print("=" * 30)
582
+ else:
583
+ print("No statistics available.")
584
+
585
+ async def start_interactive_session(self) -> None:
586
+ """
587
+ Start an interactive chat session with conversation history.
588
+
589
+ This provides a full CLI-style interactive experience with commands
590
+ for managing conversation history and viewing statistics.
591
+ """
592
+ print("=" * 50)
593
+ print("Interactive Chat Session Started")
594
+ print(f"Using model: {self.config.model}")
595
+ print("Type 'quit', 'exit', or 'bye' to end the conversation")
596
+ print("Commands:")
597
+ print(" /clear - clear conversation history")
598
+ print(" /history - show conversation history")
599
+ print(" /stats - show performance statistics")
600
+ print(" /help - show this help message")
601
+ print("=" * 50)
602
+
603
+ while True:
604
+ try:
605
+ user_input = input("\nYou: ").strip()
606
+
607
+ if user_input.lower() in ["quit", "exit", "bye"]:
608
+ print("\nGoodbye!")
609
+ break
610
+ elif user_input.lower() == "/clear":
611
+ self.clear_history()
612
+ print("Conversation history cleared.")
613
+ continue
614
+ elif user_input.lower() == "/history":
615
+ history = self.get_formatted_history()
616
+ if not history:
617
+ print("No conversation history.")
618
+ else:
619
+ print("\n" + "=" * 30)
620
+ print("Conversation History:")
621
+ print("=" * 30)
622
+ for entry in history:
623
+ print(f"{entry['role'].title()}: {entry['message']}")
624
+ print("=" * 30)
625
+ continue
626
+ elif user_input.lower() == "/stats":
627
+ self.display_stats()
628
+ continue
629
+ elif user_input.lower() == "/help":
630
+ print("\n" + "=" * 40)
631
+ print("Available Commands:")
632
+ print("=" * 40)
633
+ print(" /clear - clear conversation history")
634
+ print(" /history - show conversation history")
635
+ print(" /stats - show performance statistics")
636
+ print(" /help - show this help message")
637
+ print("\nTo exit: type 'quit', 'exit', or 'bye'")
638
+ print("=" * 40)
639
+ continue
640
+ elif not user_input:
641
+ print("Please enter a message.")
642
+ continue
643
+
644
+ print(f"\n{self.config.assistant_name.title()}: ", end="", flush=True)
645
+
646
+ # Generate and stream response
647
+ for chunk in self.send_stream(user_input):
648
+ if not chunk.is_complete:
649
+ print(chunk.text, end="", flush=True)
650
+ else:
651
+ # Show stats if configured and available
652
+ if self.config.show_stats and chunk.stats:
653
+ self.display_stats(chunk.stats)
654
+ print() # Add newline after response
655
+
656
+ except KeyboardInterrupt:
657
+ print("\n\nChat interrupted. Goodbye!")
658
+ break
659
+ except Exception as e:
660
+ print(f"\nError: {e}")
661
+ raise
662
+
663
+ def update_config(self, **kwargs) -> None:
664
+ """
665
+ Update configuration dynamically.
666
+
667
+ Args:
668
+ **kwargs: Configuration parameters to update
669
+ """
670
+ # Update our config
671
+ for key, value in kwargs.items():
672
+ if hasattr(self.config, key):
673
+ setattr(self.config, key, value)
674
+
675
+ # Handle special cases
676
+ if "max_history_length" in kwargs:
677
+ # Create new deque with updated maxlen
678
+ old_history = list(self.chat_history)
679
+ new_maxlen = kwargs["max_history_length"] * 2
680
+ self.chat_history = deque(old_history, maxlen=new_maxlen)
681
+
682
+ if "system_prompt" in kwargs:
683
+ # System prompt is handled through Prompts class, not directly
684
+ pass
685
+
686
+ if "assistant_name" in kwargs:
687
+ # Assistant name change affects history display but not underlying storage
688
+ # since we dynamically parse the history based on current assistant_name
689
+ pass
690
+
691
+ @property
692
+ def history_length(self) -> int:
693
+ """Get the current number of conversation entries."""
694
+ return len(self.chat_history)
695
+
696
+ @property
697
+ def conversation_pairs(self) -> int:
698
+ """Get the number of conversation pairs (user + assistant)."""
699
+ return len(self.chat_history) // 2
700
+
701
+ def enable_rag(self, documents: Optional[List[str]] = None, **rag_kwargs):
702
+ """
703
+ Enable RAG (Retrieval-Augmented Generation) for document-based chat.
704
+
705
+ Args:
706
+ documents: List of PDF file paths to index
707
+ **rag_kwargs: Additional RAG configuration options
708
+ """
709
+ try:
710
+ from gaia.rag.sdk import RAGSDK, RAGConfig
711
+ except ImportError:
712
+ raise ImportError(
713
+ 'RAG dependencies not installed. Install with: uv pip install -e ".[rag]"'
714
+ )
715
+
716
+ # Create RAG config matching chat config
717
+ rag_config = RAGConfig(
718
+ model=self.config.model,
719
+ show_stats=self.config.show_stats,
720
+ use_local_llm=self.config.use_local_llm,
721
+ **rag_kwargs,
722
+ )
723
+
724
+ self.rag = RAGSDK(rag_config)
725
+ self.rag_enabled = True
726
+
727
+ # Index documents if provided
728
+ if documents:
729
+ for doc_path in documents:
730
+ self.log.info(f"Indexing document: {doc_path}")
731
+ result = self.rag.index_document(doc_path)
732
+
733
+ if result:
734
+ self.log.info(f"Successfully indexed: {doc_path}")
735
+ else:
736
+ self.log.warning(f"Failed to index document: {doc_path}")
737
+
738
+ self.log.info(
739
+ f"RAG enabled with {len(documents) if documents else 0} documents"
740
+ )
741
+
742
+ def disable_rag(self):
743
+ """Disable RAG functionality."""
744
+ self.rag = None
745
+ self.rag_enabled = False
746
+ self.log.info("RAG disabled")
747
+
748
+ def add_document(self, document_path: str) -> bool:
749
+ """
750
+ Add a document to the RAG index.
751
+
752
+ Args:
753
+ document_path: Path to PDF file to index
754
+
755
+ Returns:
756
+ True if indexing succeeded
757
+ """
758
+ if not self.rag_enabled or not self.rag:
759
+ raise ValueError("RAG not enabled. Call enable_rag() first.")
760
+
761
+ return self.rag.index_document(document_path)
762
+
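A hedged sketch of the RAG workflow built from `enable_rag`, `add_document`, and `send`; the file paths are placeholders and the optional RAG extra must be installed:

```python
# Index one or more PDFs up front, then add more later as needed.
sdk.enable_rag(documents=["./docs/user_guide.pdf"])
sdk.add_document("./docs/release_notes.pdf")

# With RAG enabled, send() retrieves relevant chunks and prepends them to the
# prompt (see _enhance_with_rag below); stored history keeps the original question.
answer = sdk.send("What changed in the latest release?")
print(answer.text)

sdk.disable_rag()  # later messages go to the model without document context
```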
763
+ def _estimate_tokens(self, text: str) -> int:
764
+ """
765
+ Estimate the number of tokens in text.
766
+ Uses rough approximation of 4 characters per token.
767
+
768
+ Args:
769
+ text: The text to estimate tokens for
770
+
771
+ Returns:
772
+ Estimated token count
773
+ """
774
+ # Rough approximation: ~4 characters per token for English text
775
+ # This is conservative to avoid overrunning context
776
+ return len(text) // 4
777
+
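For orientation, the 4-characters-per-token heuristic works out as follows (counts are deliberately rough):

```python
prompt = "Summarize the attached design document in three bullet points."
print(len(prompt))       # 62 characters
print(len(prompt) // 4)  # ~15 estimated tokens, erring on the conservative side
```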
778
+ def summarize_conversation_history(self, max_history_tokens: int) -> Optional[str]:
779
+ """
780
+ Summarize conversation history when it exceeds the token budget.
781
+
782
+ Args:
783
+ max_history_tokens: Maximum allowed tokens for stored history
784
+
785
+ Returns:
786
+ The generated summary (when summarization occurred) or None
787
+ """
788
+ if max_history_tokens <= 0:
789
+ raise ValueError("max_history_tokens must be positive")
790
+
791
+ history_entries = list(self.chat_history)
792
+ if not history_entries:
793
+ return None
794
+
795
+ history_text = "\n".join(history_entries)
796
+ history_tokens = self._estimate_tokens(history_text)
797
+
798
+ if history_tokens <= max_history_tokens:
799
+ print(
800
+ "History tokens are less than max history tokens, so no summarization is needed"
801
+ )
802
+ return None
803
+
804
+ print(
805
+ "History tokens are greater than max history tokens, so summarization is needed"
806
+ )
807
+
808
+ self.log.info(
809
+ "Conversation history (~%d tokens) exceeds budget (%d). Summarizing...",
810
+ history_tokens,
811
+ max_history_tokens,
812
+ )
813
+
814
+ summary_prompt = (
815
+ "Summarize the following conversation between a user and the GAIA web "
816
+ "development agent. Preserve:\n"
817
+ "- The app requirements and inferred schema/data models\n"
818
+ "- Key implementation details already completed\n"
819
+ "- Outstanding issues, validation failures, or TODOs (quote error/warning text verbatim)\n"
820
+ "- Any constraints or preferences the user emphasized\n\n"
821
+ "Write the summary in under 400 tokens, using concise paragraphs, and include the exact text of any warnings/errors so future fixes have full context.\n\n"
822
+ "You have full access to the prior conversation history above; summarize it directly without restating the entire transcript."
823
+ )
824
+
825
+ # Use ChatSDK's send() so history formatting/ordering is handled consistently
826
+ # by the same path used for normal chat turns.
827
+ original_history = list(self.chat_history)
828
+ try:
829
+ chat_response = self.send(
830
+ summary_prompt,
831
+ max_tokens=min(self.config.max_tokens, 2048),
832
+ timeout=1200,
833
+ )
834
+ except Exception as exc: # pylint: disable=broad-exception-caught
835
+ self.log.error("Failed to summarize conversation history: %s", exc)
836
+ # Restore history to avoid dropping context on failure
837
+ self.chat_history.clear()
838
+ self.chat_history.extend(original_history)
839
+ return None
840
+
841
+ summary = chat_response.text.strip() if chat_response else ""
842
+ if not summary:
843
+ self.log.warning("Summarization returned empty content; keeping history.")
844
+ self.chat_history.clear()
845
+ self.chat_history.extend(original_history)
846
+ return None
847
+
848
+ self.chat_history.clear()
849
+ self.chat_history.append(
850
+ f"{self.config.assistant_name}: Conversation summary so far:\n{summary}"
851
+ )
852
+ return summary
853
+
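A sketch of budget-driven summarization; the budget value is illustrative, and on failure the original history is restored and None is returned:

```python
MAX_HISTORY_TOKENS = 2000  # illustrative budget

summary = sdk.summarize_conversation_history(MAX_HISTORY_TOKENS)
if summary is None:
    print("History within budget or summarization failed; nothing was replaced.")
else:
    # History now holds a single assistant entry wrapping the summary text.
    print(sdk.get_history())
```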
854
+ def _truncate_rag_context(self, context: str, max_tokens: int) -> str:
855
+ """
856
+ Truncate RAG context to fit within token budget.
857
+
858
+ Args:
859
+ context: The RAG context to truncate
860
+ max_tokens: Maximum tokens allowed
861
+
862
+ Returns:
863
+ Truncated context with ellipsis if needed
864
+ """
865
+ estimated_tokens = self._estimate_tokens(context)
866
+
867
+ if estimated_tokens <= max_tokens:
868
+ return context
869
+
870
+ # Calculate how many characters we can keep
871
+ target_chars = max_tokens * 4 # Using same 4:1 ratio
872
+
873
+ # Truncate and add ellipsis
874
+ truncated = context[: target_chars - 20] # Leave room for ellipsis
875
+ truncated += "\n... [context truncated for length]"
876
+
877
+ self.log.warning(
878
+ f"RAG context truncated from ~{estimated_tokens} to ~{max_tokens} tokens"
879
+ )
880
+ return truncated
881
+
882
+ def _enhance_with_rag(self, message: str) -> tuple:
883
+ """
884
+ Enhance user message with relevant document context using RAG.
885
+
886
+ Args:
887
+ message: Original user message
888
+
889
+ Returns:
890
+ Tuple of (enhanced_message, metadata_dict)
891
+ """
892
+ if not self.rag_enabled or not self.rag:
893
+ return message, None
894
+
895
+ try:
896
+ # Query RAG for relevant context with metadata
897
+ rag_response = self.rag.query(message, include_metadata=True)
898
+
899
+ if rag_response.chunks:
900
+ # Build context with source information
901
+ context_parts = []
902
+ if rag_response.chunk_metadata:
903
+ for i, (chunk, metadata) in enumerate(
904
+ zip(rag_response.chunks, rag_response.chunk_metadata)
905
+ ):
906
+ context_parts.append(
907
+ f"Context {i+1} (from {metadata['source_file']}, relevance: {metadata['relevance_score']:.2f}):\n{chunk}"
908
+ )
909
+ else:
910
+ context_parts = [
911
+ f"Context {i+1}:\n{chunk}"
912
+ for i, chunk in enumerate(rag_response.chunks)
913
+ ]
914
+
915
+ context = "\n\n".join(context_parts)
916
+
917
+ # Check token limits
918
+ message_tokens = self._estimate_tokens(message)
919
+ template_tokens = 150 # Template text overhead
920
+ response_tokens = self.config.max_tokens
921
+ history_tokens = self._estimate_tokens(str(self.chat_history))
922
+
923
+ # Conservative context size for models
924
+ model_context_size = 32768
925
+ available_for_rag = (
926
+ model_context_size
927
+ - message_tokens
928
+ - template_tokens
929
+ - response_tokens
930
+ - history_tokens
931
+ )
932
+
933
+ # Ensure minimum RAG context
934
+ if available_for_rag < 500:
935
+ self.log.warning(
936
+ f"Limited space for RAG context: {available_for_rag} tokens"
937
+ )
938
+ available_for_rag = 500
939
+
940
+ # Truncate context if needed
941
+ context = self._truncate_rag_context(context, available_for_rag)
942
+
943
+ # Build enhanced message
944
+ enhanced_message = f"""Based on the provided documents, please answer the following question. Use the context below to inform your response.
945
+
946
+ Context from documents:
947
+ {context}
948
+
949
+ User question: {message}
950
+
951
+ Note: When citing information, please mention which context number it came from."""
952
+
953
+ # Prepare metadata for return
954
+ metadata = {
955
+ "rag_used": True,
956
+ "chunks_retrieved": len(rag_response.chunks),
957
+ "estimated_context_tokens": self._estimate_tokens(context),
958
+ "available_tokens": available_for_rag,
959
+ "context_truncated": (
960
+ len(context) < sum(len(c) for c in rag_response.chunks)
961
+ if rag_response.chunks
962
+ else False
963
+ ),
964
+ }
965
+
966
+ # Add query metadata if available
967
+ if rag_response.query_metadata:
968
+ metadata["query_metadata"] = rag_response.query_metadata
969
+
970
+ self.log.debug(
971
+ f"Enhanced message with {len(rag_response.chunks)} chunks from "
972
+ f"{len(set(rag_response.source_files)) if rag_response.source_files else 0} documents, "
973
+ f"~{metadata['estimated_context_tokens']} context tokens"
974
+ )
975
+ return enhanced_message, metadata
976
+ else:
977
+ self.log.debug("No relevant document context found")
978
+ return message, {"rag_used": True, "chunks_retrieved": 0}
979
+
980
+ except Exception as e:
981
+ self.log.warning(
982
+ f"RAG enhancement failed: {e}, falling back to direct query"
983
+ )
984
+ return message, {"rag_used": False, "error": str(e)}
985
+
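To make the context budgeting above concrete, here is the same arithmetic with made-up token counts; the 32768 window and the 500-token floor come from the code, the rest are illustrative:

```python
model_context_size = 32768  # conservative window assumed by _enhance_with_rag
message_tokens = 50         # estimated from the user question
template_tokens = 150       # fixed template overhead
response_tokens = 512       # config.max_tokens reserved for the answer
history_tokens = 1200       # estimated from the stored chat history

available_for_rag = (
    model_context_size
    - message_tokens
    - template_tokens
    - response_tokens
    - history_tokens
)
print(available_for_rag)  # 30856 tokens left for retrieved chunks
# If this ever drops below 500, the code clamps it to 500 and truncates the context.
```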
986
+
987
+ class SimpleChat:
988
+ """
989
+ Ultra-simple interface for quick chat integration.
990
+
991
+ Example usage:
992
+ ```python
993
+ from gaia.chat.sdk import SimpleChat
994
+
995
+ chat = SimpleChat()
996
+
997
+ # Simple question-answer
998
+ response = chat.ask("What's the weather like?")
999
+ print(response)
1000
+
1001
+ # Chat with memory
1002
+ response1 = chat.ask("My name is John")
1003
+ response2 = chat.ask("What's my name?")  # Remembers previous context
1004
+ ```
1005
+ """
1006
+
1007
+ def __init__(
1008
+ self,
1009
+ system_prompt: Optional[str] = None,
1010
+ model: Optional[str] = None,
1011
+ assistant_name: Optional[str] = None,
1012
+ ):
1013
+ """
1014
+ Initialize SimpleChat with minimal configuration.
1015
+
1016
+ Args:
1017
+ system_prompt: Optional system prompt for the AI
1018
+ model: Model to use (defaults to DEFAULT_MODEL_NAME)
1019
+ assistant_name: Name to use for the assistant (defaults to "gaia")
1020
+ """
1021
+ config = ChatConfig(
1022
+ model=model or DEFAULT_MODEL_NAME,
1023
+ system_prompt=system_prompt,
1024
+ assistant_name=assistant_name or "gaia",
1025
+ show_stats=False,
1026
+ logging_level="WARNING", # Minimal logging
1027
+ )
1028
+ self._sdk = ChatSDK(config)
1029
+
1030
+ def ask(self, question: str) -> str:
1031
+ """
1032
+ Ask a question and get a text response with conversation memory.
1033
+
1034
+ Args:
1035
+ question: The question to ask
1036
+
1037
+ Returns:
1038
+ The AI's response as a string
1039
+ """
1040
+ response = self._sdk.send(question)
1041
+ return response.text
1042
+
1043
+ def ask_stream(self, question: str):
1044
+ """
1045
+ Ask a question and get a streaming response with conversation memory.
1046
+
1047
+ Args:
1048
+ question: The question to ask
1049
+
1050
+ Yields:
1051
+ Response chunks as they arrive
1052
+ """
1053
+ for chunk in self._sdk.send_stream(question):
1054
+ if not chunk.is_complete:
1055
+ yield chunk.text
1056
+
1057
+ def clear_memory(self) -> None:
1058
+ """Clear the conversation memory."""
1059
+ self._sdk.clear_history()
1060
+
1061
+ def get_conversation(self) -> List[Dict[str, str]]:
1062
+ """Get the conversation history in a readable format."""
1063
+ return self._sdk.get_formatted_history()
1064
+
1065
+
1066
+ class ChatSession:
1067
+ """
1068
+ Session-based chat interface for managing multiple separate conversations.
1069
+
1070
+ Example usage:
1071
+ ```python
1072
+ from gaia.chat.sdk import ChatSession
1073
+
1074
+ # Create session manager
1075
+ sessions = ChatSession()
1076
+
1077
+ # Create different chat sessions
1078
+ work_chat = sessions.create_session("work", system_prompt="You are a professional assistant")
1079
+ personal_chat = sessions.create_session("personal", system_prompt="You are a friendly companion")
1080
+
1081
+ # Chat in different contexts
1082
+ work_response = work_chat.send("Draft an email to my team").text
1082
+ personal_response = personal_chat.send("What's a good recipe for dinner?").text
1084
+ ```
1085
+ """
1086
+
1087
+ def __init__(self, default_config: Optional[ChatConfig] = None):
1088
+ """Initialize the session manager."""
1089
+ self.default_config = default_config or ChatConfig()
1090
+ self.sessions: Dict[str, ChatSDK] = {}
1091
+ self.log = get_logger(__name__)
1092
+
1093
+ def create_session(
1094
+ self, session_id: str, config: Optional[ChatConfig] = None, **config_kwargs
1095
+ ) -> ChatSDK:
1096
+ """
1097
+ Create a new chat session.
1098
+
1099
+ Args:
1100
+ session_id: Unique identifier for the session
1101
+ config: Optional configuration (uses default if not provided)
1102
+ **config_kwargs: Configuration parameters to override
1103
+
1104
+ Returns:
1105
+ ChatSDK instance for the session
1106
+ """
1107
+ if config is None:
1108
+ # Create config from defaults with overrides
1109
+ config = ChatConfig(
1110
+ model=config_kwargs.get("model", self.default_config.model),
1111
+ max_tokens=config_kwargs.get(
1112
+ "max_tokens", self.default_config.max_tokens
1113
+ ),
1114
+ system_prompt=config_kwargs.get(
1115
+ "system_prompt", self.default_config.system_prompt
1116
+ ),
1117
+ max_history_length=config_kwargs.get(
1118
+ "max_history_length", self.default_config.max_history_length
1119
+ ),
1120
+ show_stats=config_kwargs.get(
1121
+ "show_stats", self.default_config.show_stats
1122
+ ),
1123
+ logging_level=config_kwargs.get(
1124
+ "logging_level", self.default_config.logging_level
1125
+ ),
1126
+ use_claude=config_kwargs.get(
1127
+ "use_claude", self.default_config.use_claude
1128
+ ),
1129
+ use_chatgpt=config_kwargs.get(
1130
+ "use_chatgpt", self.default_config.use_chatgpt
1131
+ ),
1132
+ assistant_name=config_kwargs.get(
1133
+ "assistant_name", self.default_config.assistant_name
1134
+ ),
1135
+ )
1136
+
1137
+ session = ChatSDK(config)
1138
+ self.sessions[session_id] = session
1139
+ self.log.debug(f"Created chat session: {session_id}")
1140
+ return session
1141
+
1142
+ def get_session(self, session_id: str) -> Optional[ChatSDK]:
1143
+ """Get an existing session by ID."""
1144
+ return self.sessions.get(session_id)
1145
+
1146
+ def delete_session(self, session_id: str) -> bool:
1147
+ """Delete a session."""
1148
+ if session_id in self.sessions:
1149
+ del self.sessions[session_id]
1150
+ self.log.debug(f"Deleted chat session: {session_id}")
1151
+ return True
1152
+ return False
1153
+
1154
+ def list_sessions(self) -> List[str]:
1155
+ """List all active session IDs."""
1156
+ return list(self.sessions.keys())
1157
+
1158
+ def clear_all_sessions(self) -> None:
1159
+ """Clear all sessions."""
1160
+ self.sessions.clear()
1161
+ self.log.debug("Cleared all chat sessions")
1162
+
1163
+
1164
+ # Convenience functions for one-off usage
1165
+ def quick_chat(
1166
+ message: str,
1167
+ system_prompt: Optional[str] = None,
1168
+ model: Optional[str] = None,
1169
+ assistant_name: Optional[str] = None,
1170
+ ) -> str:
1171
+ """
1172
+ Quick one-off text chat without conversation memory.
1173
+
1174
+ Args:
1175
+ message: Message to send
1176
+ system_prompt: Optional system prompt
1177
+ model: Optional model to use
1178
+ assistant_name: Name to use for the assistant
1179
+
1180
+ Returns:
1181
+ AI response
1182
+ """
1183
+ config = ChatConfig(
1184
+ model=model or DEFAULT_MODEL_NAME,
1185
+ system_prompt=system_prompt,
1186
+ assistant_name=assistant_name or "gaia",
1187
+ show_stats=False,
1188
+ logging_level="WARNING",
1189
+ max_history_length=2, # Small history for quick chat
1190
+ )
1191
+ sdk = ChatSDK(config)
1192
+ response = sdk.send(message)
1193
+ return response.text
1194
+
1195
+
1196
+ def quick_chat_with_memory(
1197
+ messages: List[str],
1198
+ system_prompt: Optional[str] = None,
1199
+ model: Optional[str] = None,
1200
+ assistant_name: Optional[str] = None,
1201
+ ) -> List[str]:
1202
+ """
1203
+ Quick multi-turn chat with conversation memory.
1204
+
1205
+ Args:
1206
+ messages: List of messages to send sequentially
1207
+ system_prompt: Optional system prompt
1208
+ model: Optional model to use
1209
+ assistant_name: Name to use for the assistant
1210
+
1211
+ Returns:
1212
+ List of AI responses
1213
+ """
1214
+ config = ChatConfig(
1215
+ model=model or DEFAULT_MODEL_NAME,
1216
+ system_prompt=system_prompt,
1217
+ assistant_name=assistant_name or "gaia",
1218
+ show_stats=False,
1219
+ logging_level="WARNING",
1220
+ )
1221
+ sdk = ChatSDK(config)
1222
+
1223
+ responses = []
1224
+ for message in messages:
1225
+ response = sdk.send(message)
1226
+ responses.append(response.text)
1227
+
1228
+ return responses