ds-agent-cli 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. package/bin/ds-agent.js +451 -0
  2. package/ds_agent/__init__.py +8 -0
  3. package/package.json +28 -0
  4. package/requirements.txt +126 -0
  5. package/setup.py +35 -0
  6. package/src/__init__.py +7 -0
  7. package/src/_compress_tool_result.py +118 -0
  8. package/src/api/__init__.py +4 -0
  9. package/src/api/app.py +1626 -0
  10. package/src/cache/__init__.py +5 -0
  11. package/src/cache/cache_manager.py +561 -0
  12. package/src/cli.py +2886 -0
  13. package/src/dynamic_prompts.py +281 -0
  14. package/src/orchestrator.py +4799 -0
  15. package/src/progress_manager.py +139 -0
  16. package/src/reasoning/__init__.py +332 -0
  17. package/src/reasoning/business_summary.py +431 -0
  18. package/src/reasoning/data_understanding.py +356 -0
  19. package/src/reasoning/model_explanation.py +383 -0
  20. package/src/reasoning/reasoning_trace.py +239 -0
  21. package/src/registry/__init__.py +3 -0
  22. package/src/registry/tools_registry.py +3 -0
  23. package/src/session_memory.py +448 -0
  24. package/src/session_store.py +370 -0
  25. package/src/storage/__init__.py +19 -0
  26. package/src/storage/artifact_store.py +620 -0
  27. package/src/storage/helpers.py +116 -0
  28. package/src/storage/huggingface_storage.py +694 -0
  29. package/src/storage/r2_storage.py +0 -0
  30. package/src/storage/user_files_service.py +288 -0
  31. package/src/tools/__init__.py +335 -0
  32. package/src/tools/advanced_analysis.py +823 -0
  33. package/src/tools/advanced_feature_engineering.py +708 -0
  34. package/src/tools/advanced_insights.py +578 -0
  35. package/src/tools/advanced_preprocessing.py +549 -0
  36. package/src/tools/advanced_training.py +906 -0
  37. package/src/tools/agent_tool_mapping.py +326 -0
  38. package/src/tools/auto_pipeline.py +420 -0
  39. package/src/tools/autogluon_training.py +1480 -0
  40. package/src/tools/business_intelligence.py +860 -0
  41. package/src/tools/cloud_data_sources.py +581 -0
  42. package/src/tools/code_interpreter.py +390 -0
  43. package/src/tools/computer_vision.py +614 -0
  44. package/src/tools/data_cleaning.py +614 -0
  45. package/src/tools/data_profiling.py +593 -0
  46. package/src/tools/data_type_conversion.py +268 -0
  47. package/src/tools/data_wrangling.py +433 -0
  48. package/src/tools/eda_reports.py +284 -0
  49. package/src/tools/enhanced_feature_engineering.py +241 -0
  50. package/src/tools/feature_engineering.py +302 -0
  51. package/src/tools/matplotlib_visualizations.py +1327 -0
  52. package/src/tools/model_training.py +520 -0
  53. package/src/tools/nlp_text_analytics.py +761 -0
  54. package/src/tools/plotly_visualizations.py +497 -0
  55. package/src/tools/production_mlops.py +852 -0
  56. package/src/tools/time_series.py +507 -0
  57. package/src/tools/tools_registry.py +2133 -0
  58. package/src/tools/visualization_engine.py +559 -0
  59. package/src/utils/__init__.py +42 -0
  60. package/src/utils/error_recovery.py +313 -0
  61. package/src/utils/parallel_executor.py +402 -0
  62. package/src/utils/polars_helpers.py +248 -0
  63. package/src/utils/schema_extraction.py +132 -0
  64. package/src/utils/semantic_layer.py +392 -0
  65. package/src/utils/token_budget.py +411 -0
  66. package/src/utils/validation.py +377 -0
  67. package/src/workflow_state.py +154 -0
@@ -0,0 +1,411 @@
1
+ """
2
+ Strict Token Budget Management
3
+
4
+ Implements sliding window conversation history, aggressive compression,
5
+ and emergency context truncation to prevent context window overflow.
6
+ """
7
+
8
+ from typing import List, Dict, Any, Optional, Tuple
9
+ import json
10
+ import tiktoken
11
+ from pathlib import Path
12
+
13
+
14
class ConversationMessage:
    """A single chat message annotated with a drop-priority for history pruning."""

    def __init__(self, role: str, content: str, message_type: str = "normal",
                 priority: int = 5, tokens: Optional[int] = None):
        # message_type is one of: system, tool_result, assistant, user, normal
        self.role, self.content = role, content
        self.message_type = message_type
        # priority runs from 1 (dropped first) up to 10 (kept longest)
        self.priority = priority
        # cached token count; None means "not counted yet"
        self.tokens = tokens
        self.timestamp = None

    def to_dict(self) -> Dict[str, str]:
        """Render as an OpenAI-style {"role", "content"} message dict."""
        return {"role": self.role, "content": self.content}
29
+
30
+
31
class TokenBudgetManager:
    """
    Manages conversation history with strict token budget enforcement.

    Features:
    - Accurate token counting using tiktoken (chars/4 estimate as fallback)
    - Priority-based message dropping
    - Sliding window with smart compression of critical tool results
    - Emergency context truncation as a last resort
    - Keeps recent tool results, drops old assistant messages
    """

    def __init__(self, model: str = "gpt-4", max_tokens: int = 128000,
                 reserve_tokens: int = 8000):
        """
        Initialize token budget manager.

        Args:
            model: Model name used to pick a tiktoken encoding
            max_tokens: Maximum context window size
            reserve_tokens: Tokens held back for the model's response
        """
        self.model = model
        self.max_tokens = max_tokens
        self.reserve_tokens = reserve_tokens
        # Budget actually usable for the prompt side of the conversation
        self.available_tokens = max_tokens - reserve_tokens

        # Initialize tokenizer. encoding_for_model raises KeyError for
        # unknown model names; catch Exception (not a bare except) so
        # KeyboardInterrupt/SystemExit still propagate.
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except Exception:
            # Fallback to cl100k_base (GPT-4/GPT-3.5)
            self.encoding = tiktoken.get_encoding("cl100k_base")

        print(f"📊 Token Budget: {self.available_tokens:,} tokens available ({self.max_tokens:,} - {self.reserve_tokens:,} reserve)")

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using tiktoken; estimate ~4 chars/token on failure."""
        try:
            return len(self.encoding.encode(text))
        except Exception:
            # Fallback estimation: ~4 chars per token
            return len(text) // 4

    def count_message_tokens(self, message) -> int:
        """
        Count tokens in a single message (includes per-message role overhead).

        Args:
            message: Either a dict or a Pydantic ChatMessage-like object
                     exposing .role / .content attributes.

        Returns:
            Token estimate: content + role + ~4 tokens of message framing
            (<|role|>content<|endofmessage|>).
        """
        # Handle both dict and Pydantic object formats
        if isinstance(message, dict):
            content = message.get("content", "")
            role = message.get("role", "")
        else:
            content = getattr(message, "content", "")
            role = getattr(message, "role", "")

        content_tokens = self.count_tokens(str(content))
        role_tokens = self.count_tokens(str(role))
        return content_tokens + role_tokens + 4

    def count_messages_tokens(self, messages: List) -> int:
        """Count total tokens across a message list."""
        return sum(self.count_message_tokens(msg) for msg in messages)

    def _hard_truncate(self, text: str, max_tokens: int) -> str:
        """Hard-truncate text to roughly max_tokens using the 4 chars/token estimate."""
        if self.count_tokens(text) <= max_tokens:
            return text
        char_limit = max_tokens * 4
        return text[:char_limit] + "... (truncated)"

    def compress_tool_result(self, tool_result: str, max_tokens: int = 500) -> str:
        """
        Aggressively compress a tool result while keeping key information.

        Keeps: success/failure status, key metrics and numbers, error messages.
        Drops: verbose logs, duplicate information, large data structures.

        Args:
            tool_result: Raw tool output (JSON string or plain text)
            max_tokens: Token budget for the compressed result

        Returns:
            A string within ~max_tokens (chars/token estimate on fallback).
        """
        if self.count_tokens(tool_result) <= max_tokens:
            return tool_result

        try:
            result_dict = json.loads(tool_result)
        except json.JSONDecodeError:
            # Not JSON - keep first 5 and last 5 lines, then hard-truncate
            lines = tool_result.split('\n')
            if len(lines) > 15:
                result = '\n'.join(lines[:5] + ["... (truncated) ..."] + lines[-5:])
            else:
                result = tool_result
            return self._hard_truncate(result, max_tokens)

        # JSON result: rebuild keeping only the essential fields
        compressed = {
            "success": result_dict.get("success", True),
        }

        if "error" in result_dict:
            compressed["error"] = str(result_dict["error"])[:200]

        # Key metrics (numbers, scores, paths)
        for key in ["score", "accuracy", "best_score", "n_rows", "n_cols",
                    "output_path", "best_model", "result_summary"]:
            if key in result_dict:
                compressed[key] = result_dict[key]

        # Keep the raw result only when it is small
        if "result" in result_dict:
            result_str = str(result_dict["result"])
            if len(result_str) < 300:
                compressed["result"] = result_str

        # BUG FIX: the compressed JSON itself could still exceed the budget
        # (e.g. a huge "result_summary" value); enforce the limit either way.
        return self._hard_truncate(json.dumps(compressed, indent=None), max_tokens)

    def prioritize_messages(self, messages: "List[ConversationMessage]") -> "List[ConversationMessage]":
        """
        Assign priorities to messages in-place based on type, role and recency.

        Priority levels:
        - 10: System prompt, recent user messages
        - 9:  Recent tool results
        - 8:  Recent assistant responses
        - 7:  Older user messages
        - 5:  Recent normal messages
        - 3:  Old tool results
        - 2:  Old assistant responses
        - 1:  Very old normal messages
        """
        # "Recent" means within the last 5 messages
        recent_threshold = max(0, len(messages) - 5)

        for i, msg in enumerate(messages):
            recent = i >= recent_threshold
            if msg.message_type == "system":
                msg.priority = 10
            elif msg.role == "user":
                msg.priority = 10 if recent else 7
            elif msg.message_type == "tool_result":
                msg.priority = 9 if recent else 3
            elif msg.role == "assistant":
                msg.priority = 8 if recent else 2
            else:
                msg.priority = 5 if recent else 1

        return messages

    def apply_sliding_window(self, messages: "List[ConversationMessage]",
                             target_tokens: int) -> "List[ConversationMessage]":
        """
        Apply a sliding window so the conversation fits the token budget.

        Strategy:
        1. Always keep the system prompt (first message) when present
        2. Greedily keep messages in priority order until the budget is spent
        3. Compress high-priority tool results that would otherwise be dropped
        4. Return survivors in their original chronological order

        Args:
            messages: List of ConversationMessage objects
            target_tokens: Target token count

        Returns:
            Filtered message list within budget
        """
        if not messages:
            return []

        # Always keep system prompt
        system_msg = messages[0] if messages[0].message_type == "system" else None
        other_messages = messages[1:] if system_msg else messages

        # Prioritize, then sort high-to-low for greedy selection
        other_messages = self.prioritize_messages(other_messages)
        sorted_messages = sorted(other_messages, key=lambda m: m.priority, reverse=True)

        # Cache token counts
        for msg in sorted_messages:
            if msg.tokens is None:
                msg.tokens = self.count_message_tokens(msg.to_dict())

        kept_messages = []
        current_tokens = 0

        if system_msg:
            system_msg.tokens = self.count_message_tokens(system_msg.to_dict())
            kept_messages.append(system_msg)
            current_tokens += system_msg.tokens

        # Greedily add messages by priority until the budget is exhausted
        for msg in sorted_messages:
            if current_tokens + msg.tokens <= target_tokens:
                kept_messages.append(msg)
                current_tokens += msg.tokens
            elif msg.message_type == "tool_result" and msg.priority >= 8:
                # Try compressing critical tool results instead of dropping them
                compressed_content = self.compress_tool_result(msg.content, max_tokens=300)
                compressed_tokens = self.count_tokens(compressed_content)

                if current_tokens + compressed_tokens <= target_tokens:
                    msg.content = compressed_content
                    msg.tokens = compressed_tokens
                    kept_messages.append(msg)
                    current_tokens += compressed_tokens

        # Restore chronological order.
        # BUG FIX: the original matched survivors back by content equality,
        # which duplicated/mis-ordered messages whenever two messages shared
        # the same content, and skipped reordering entirely when there was no
        # system message. Match by object identity instead.
        kept_ids = {id(m) for m in kept_messages}
        ordered = [m for m in messages if id(m) in kept_ids and m is not system_msg]
        kept_messages = ([system_msg] if system_msg else []) + ordered

        print(f"📊 Sliding window: {len(messages)} → {len(kept_messages)} messages ({current_tokens:,} tokens)")

        return kept_messages

    def emergency_truncate(self, messages: List[Dict[str, str]],
                           max_tokens: int) -> List[Dict[str, str]]:
        """
        Emergency truncation when the context is about to overflow.

        Aggressive strategy: keep the first message (system prompt) and the
        last 2 messages; if that still overflows, cut the system prompt down
        to ~1000 characters.

        Args:
            messages: Message list (dicts or Pydantic objects)
            max_tokens: Hard token limit

        Returns:
            Truncated message list (best effort; the count is not re-checked
            after the system prompt is shortened).
        """
        if not messages:
            return []

        print("⚠️ EMERGENCY TRUNCATION: Context overflow imminent")

        # System prompt (first message) plus the last 2 messages
        essential_messages = [messages[0]]
        if len(messages) > 2:
            essential_messages.extend(messages[-2:])
        else:
            essential_messages.extend(messages[1:])

        total_tokens = self.count_messages_tokens(essential_messages)
        if total_tokens <= max_tokens:
            return essential_messages

        # Still too large - truncate system prompt
        print("⚠️ Truncating system prompt to fit budget")
        system_msg = essential_messages[0]

        # Handle both dict and Pydantic object formats
        if isinstance(system_msg, dict):
            system_content = system_msg["content"]
        else:
            system_content = getattr(system_msg, "content", "")

        # Keep first 1000 chars of system prompt
        truncated_system = {
            "role": "system",
            "content": str(system_content)[:1000] + "\n\n... (truncated due to context limit) ..."
        }

        return [truncated_system] + essential_messages[1:]

    def enforce_budget(self, messages: List[Dict[str, str]],
                       system_prompt: Optional[str] = None) -> Tuple[List[Dict[str, str]], int]:
        """
        Main entry point: enforce the token budget on a message list.

        Args:
            messages: List of messages (dicts or Pydantic objects)
            system_prompt: Optional new system prompt to prepend

        Returns:
            (filtered_messages, total_tokens); filtered_messages are plain
            {"role", "content"} dicts when any trimming was required.
        """
        # Add system prompt if provided
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + messages

        current_tokens = self.count_messages_tokens(messages)

        print(f"📊 Token Budget Check: {current_tokens:,} / {self.available_tokens:,} tokens")

        # If within budget, return as-is
        if current_tokens <= self.available_tokens:
            print("✅ Within budget")
            return messages, current_tokens

        print(f"⚠️ Over budget by {current_tokens - self.available_tokens:,} tokens")

        # Normalize into ConversationMessage objects for windowing
        conv_messages = []
        for i, msg in enumerate(messages):
            # Handle both dict and Pydantic object formats
            if isinstance(msg, dict):
                role = msg.get("role", "")
                content = msg.get("content", "")
            else:
                role = getattr(msg, "role", "")
                content = getattr(msg, "content", "")

            msg_type = "system" if i == 0 and role == "system" else "normal"
            # BUG FIX: previously this heuristic also reclassified the system
            # prompt as a tool result when it merely mentioned "tool"/"function",
            # costing it its guaranteed-keep status. Never override "system".
            if msg_type != "system" and (
                    "tool" in str(content).lower() or "function" in str(content).lower()):
                msg_type = "tool_result"

            conv_messages.append(ConversationMessage(
                role=role,
                content=str(content),
                message_type=msg_type
            ))

        # Apply sliding window
        filtered = self.apply_sliding_window(conv_messages, self.available_tokens)

        # Convert back to dict format
        result_messages = [msg.to_dict() for msg in filtered]
        final_tokens = self.count_messages_tokens(result_messages)

        # Emergency truncation if still over
        if final_tokens > self.available_tokens:
            result_messages = self.emergency_truncate(result_messages, self.available_tokens)
            final_tokens = self.count_messages_tokens(result_messages)

        print(f"✅ Budget enforced: {final_tokens:,} tokens ({len(result_messages)} messages)")

        return result_messages, final_tokens
401
+
402
+
403
# Process-wide singleton; created lazily by get_token_manager()
_token_manager = None


def get_token_manager(model: str = "gpt-4", max_tokens: int = 128000) -> TokenBudgetManager:
    """Get or create the global token budget manager.

    Note: model/max_tokens only take effect on the first call; subsequent
    calls return the already-created instance unchanged.
    """
    global _token_manager
    if _token_manager is not None:
        return _token_manager
    _token_manager = TokenBudgetManager(model=model, max_tokens=max_tokens)
    return _token_manager