ds-agent-cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/ds-agent.js +451 -0
- package/ds_agent/__init__.py +8 -0
- package/package.json +28 -0
- package/requirements.txt +126 -0
- package/setup.py +35 -0
- package/src/__init__.py +7 -0
- package/src/_compress_tool_result.py +118 -0
- package/src/api/__init__.py +4 -0
- package/src/api/app.py +1626 -0
- package/src/cache/__init__.py +5 -0
- package/src/cache/cache_manager.py +561 -0
- package/src/cli.py +2886 -0
- package/src/dynamic_prompts.py +281 -0
- package/src/orchestrator.py +4799 -0
- package/src/progress_manager.py +139 -0
- package/src/reasoning/__init__.py +332 -0
- package/src/reasoning/business_summary.py +431 -0
- package/src/reasoning/data_understanding.py +356 -0
- package/src/reasoning/model_explanation.py +383 -0
- package/src/reasoning/reasoning_trace.py +239 -0
- package/src/registry/__init__.py +3 -0
- package/src/registry/tools_registry.py +3 -0
- package/src/session_memory.py +448 -0
- package/src/session_store.py +370 -0
- package/src/storage/__init__.py +19 -0
- package/src/storage/artifact_store.py +620 -0
- package/src/storage/helpers.py +116 -0
- package/src/storage/huggingface_storage.py +694 -0
- package/src/storage/r2_storage.py +0 -0
- package/src/storage/user_files_service.py +288 -0
- package/src/tools/__init__.py +335 -0
- package/src/tools/advanced_analysis.py +823 -0
- package/src/tools/advanced_feature_engineering.py +708 -0
- package/src/tools/advanced_insights.py +578 -0
- package/src/tools/advanced_preprocessing.py +549 -0
- package/src/tools/advanced_training.py +906 -0
- package/src/tools/agent_tool_mapping.py +326 -0
- package/src/tools/auto_pipeline.py +420 -0
- package/src/tools/autogluon_training.py +1480 -0
- package/src/tools/business_intelligence.py +860 -0
- package/src/tools/cloud_data_sources.py +581 -0
- package/src/tools/code_interpreter.py +390 -0
- package/src/tools/computer_vision.py +614 -0
- package/src/tools/data_cleaning.py +614 -0
- package/src/tools/data_profiling.py +593 -0
- package/src/tools/data_type_conversion.py +268 -0
- package/src/tools/data_wrangling.py +433 -0
- package/src/tools/eda_reports.py +284 -0
- package/src/tools/enhanced_feature_engineering.py +241 -0
- package/src/tools/feature_engineering.py +302 -0
- package/src/tools/matplotlib_visualizations.py +1327 -0
- package/src/tools/model_training.py +520 -0
- package/src/tools/nlp_text_analytics.py +761 -0
- package/src/tools/plotly_visualizations.py +497 -0
- package/src/tools/production_mlops.py +852 -0
- package/src/tools/time_series.py +507 -0
- package/src/tools/tools_registry.py +2133 -0
- package/src/tools/visualization_engine.py +559 -0
- package/src/utils/__init__.py +42 -0
- package/src/utils/error_recovery.py +313 -0
- package/src/utils/parallel_executor.py +402 -0
- package/src/utils/polars_helpers.py +248 -0
- package/src/utils/schema_extraction.py +132 -0
- package/src/utils/semantic_layer.py +392 -0
- package/src/utils/token_budget.py +411 -0
- package/src/utils/validation.py +377 -0
- package/src/workflow_state.py +154 -0
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Strict Token Budget Management
|
|
3
|
+
|
|
4
|
+
Implements sliding window conversation history, aggressive compression,
|
|
5
|
+
and emergency context truncation to prevent context window overflow.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import List, Dict, Any, Optional, Tuple
|
|
9
|
+
import json
|
|
10
|
+
import tiktoken
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ConversationMessage:
    """A chat turn annotated with the metadata used when trimming history.

    Attributes (all set in ``__init__``):
        role: Chat role string, e.g. "system", "user", "assistant".
        content: Message text.
        message_type: Category label (system, tool_result, assistant, user,
            normal).
        priority: Retention rank, 1 (dropped first) to 10 (kept longest).
        tokens: Cached token count, or None until someone computes it.
        timestamp: Always initialised to None.
    """

    def __init__(self, role: str, content: str, message_type: str = "normal",
                 priority: int = 5, tokens: Optional[int] = None):
        self.role, self.content = role, content
        # One of: system, tool_result, assistant, user, normal.
        self.message_type = message_type
        # 1 = first candidate for dropping ... 10 = last to be dropped.
        self.priority = priority
        self.tokens = tokens
        self.timestamp = None

    def to_dict(self) -> Dict[str, str]:
        """Return the message as an OpenAI-style ``{"role", "content"}`` dict."""
        return dict(role=self.role, content=self.content)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TokenBudgetManager:
    """
    Manages conversation history with strict token budget enforcement.

    Features:
    - Token counting using tiktoken (with a ~4 chars/token fallback)
    - Priority-based message dropping
    - Sliding window with tool-result compression
    - Emergency context truncation
    - Keeps recent tool results, drops old assistant messages

    Messages may be plain dicts or objects exposing ``role`` / ``content``
    attributes (e.g. Pydantic chat-message models).
    """

    def __init__(self, model: str = "gpt-4", max_tokens: int = 128000,
                 reserve_tokens: int = 8000):
        """
        Initialize token budget manager.

        Args:
            model: Model name used to pick a tiktoken encoding. If tiktoken
                cannot resolve it, ``cl100k_base`` is used instead.
            max_tokens: Maximum context window size.
            reserve_tokens: Tokens to reserve for the response. The usable
                budget is ``max_tokens - reserve_tokens``.
        """
        self.model = model
        self.max_tokens = max_tokens
        self.reserve_tokens = reserve_tokens
        self.available_tokens = max_tokens - reserve_tokens

        # Initialize tokenizer. ``except Exception`` (rather than a bare
        # except) so KeyboardInterrupt / SystemExit are not swallowed.
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except Exception:
            # Fallback to cl100k_base (GPT-4/GPT-3.5)
            self.encoding = tiktoken.get_encoding("cl100k_base")

        print(f"📊 Token Budget: {self.available_tokens:,} tokens available ({self.max_tokens:,} - {self.reserve_tokens:,} reserve)")

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _role_and_content(message) -> Tuple[Any, Any]:
        """Return ``(role, content)`` from a dict or attribute-style message.

        Missing fields come back as ``None``.
        """
        if isinstance(message, dict):
            return message.get("role"), message.get("content")
        return getattr(message, "role", None), getattr(message, "content", None)

    @staticmethod
    def _as_text(value) -> str:
        """Stringify a role/content value, mapping ``None`` to ``""``."""
        return "" if value is None else str(value)

    @staticmethod
    def _with_content(message, content: str) -> Dict[str, Any]:
        """Return a dict copy of ``message`` with its content replaced.

        Dict messages keep all of their other keys; attribute-style messages
        are reduced to ``{"role", "content"}``.
        """
        if isinstance(message, dict):
            return {**message, "content": content}
        return {"role": getattr(message, "role", ""), "content": content}

    def _truncate_to_tokens(self, text: str, max_tokens: int) -> str:
        """Return the longest token-prefix of ``text`` that fits ``max_tokens``.

        Falls back to the ~4 chars/token heuristic if the encoder fails.
        """
        if max_tokens <= 0:
            return ""
        try:
            token_ids = self.encoding.encode(text, disallowed_special=())
            if len(token_ids) <= max_tokens:
                return text
            return self.encoding.decode(token_ids[:max_tokens])
        except Exception:
            return text[:max_tokens * 4]

    # ------------------------------------------------------------------
    # Token counting
    # ------------------------------------------------------------------

    def count_tokens(self, text: str) -> int:
        """Count tokens in ``text`` using the configured tiktoken encoding.

        Special-token markers (e.g. ``<|endoftext|>``) occurring in the text
        are encoded as ordinary text instead of making ``encode`` raise. If
        encoding still fails, falls back to ~4 chars per token.
        """
        if not text:
            return 0
        try:
            return len(self.encoding.encode(text, disallowed_special=()))
        except Exception:
            # Fallback estimation: ~4 chars per token
            return len(text) // 4

    def count_message_tokens(self, message) -> int:
        """
        Count tokens in a message (includes role overhead).

        Args:
            message: Either a dict or an object with ``role`` / ``content``
                attributes (e.g. a Pydantic ChatMessage).

        Returns:
            Content tokens + role tokens + 4 tokens of per-message framing
            (format: <|role|>content<|endofmessage|>). ``None`` content, as
            on assistant turns that only carry tool calls, counts as empty.
        """
        role, content = self._role_and_content(message)
        return (self.count_tokens(self._as_text(content))
                + self.count_tokens(self._as_text(role))
                + 4)

    def count_messages_tokens(self, messages: List) -> int:
        """Count total tokens in message list."""
        return sum(self.count_message_tokens(msg) for msg in messages)

    # ------------------------------------------------------------------
    # Tool-result compression
    # ------------------------------------------------------------------

    @staticmethod
    def _compress_json_result(result_dict: Dict[str, Any]) -> str:
        """Keep only status, error, key metrics and a short ``result`` field."""
        compressed = {"success": result_dict.get("success", True)}

        if "error" in result_dict:
            compressed["error"] = str(result_dict["error"])[:200]

        # Key metrics (numbers, scores, paths)
        for key in ("score", "accuracy", "best_score", "n_rows", "n_cols",
                    "output_path", "best_model", "result_summary"):
            if key in result_dict:
                compressed[key] = result_dict[key]

        # Only keep ``result`` when its string form is small
        if "result" in result_dict:
            result_str = str(result_dict["result"])
            if len(result_str) < 300:
                compressed["result"] = result_str

        return json.dumps(compressed, indent=None)

    def _truncate_text_result(self, text: str, max_tokens: int) -> str:
        """Keep the first/last 5 lines of long output, then hard-cap tokens."""
        lines = text.split('\n')
        if len(lines) > 15:
            text = '\n'.join(lines[:5] + ["... (truncated) ..."] + lines[-5:])

        if self.count_tokens(text) > max_tokens:
            text = self._truncate_to_tokens(text, max_tokens) + "... (truncated)"

        return text

    def compress_tool_result(self, tool_result: str, max_tokens: int = 500) -> str:
        """
        Aggressively compress tool result while keeping key information.

        Keeps:
        - Success/failure status
        - Key metrics and numbers
        - Error messages

        Drops:
        - Verbose logs
        - Duplicate information
        - Large data structures

        JSON objects are reduced to their essential fields. Anything else
        (plain text, JSON arrays or scalars) is truncated as text. If the
        reduced JSON is still larger than ``max_tokens`` it is truncated too.
        """
        if self.count_tokens(tool_result) <= max_tokens:
            return tool_result

        try:
            parsed = json.loads(tool_result)
        except json.JSONDecodeError:
            parsed = None

        if isinstance(parsed, dict):
            compressed = self._compress_json_result(parsed)
            if self.count_tokens(compressed) <= max_tokens:
                return compressed
            # Even the summary is too big (e.g. a huge result_summary).
            return self._truncate_text_result(compressed, max_tokens)

        # Not JSON, or JSON without named fields to extract (list / scalar).
        return self._truncate_text_result(tool_result, max_tokens)

    # ------------------------------------------------------------------
    # History management
    # ------------------------------------------------------------------

    def prioritize_messages(self, messages: List["ConversationMessage"]) -> List["ConversationMessage"]:
        """
        Assign priorities (in place) based on message type, role and recency.

        "Recent" means one of the last 5 entries of ``messages``.

        Priority levels:
        - 10: System messages, recent user messages
        - 9: Recent tool results
        - 8: Recent assistant responses
        - 7: Older user messages
        - 5: Other recent messages
        - 3: Older tool results
        - 2: Older assistant responses
        - 1: Other older messages

        Returns:
            The same list object, with each message's ``priority`` updated.
        """
        recent_threshold = max(0, len(messages) - 5)

        for i, msg in enumerate(messages):
            recent = i >= recent_threshold
            if msg.message_type == "system":
                msg.priority = 10
            elif msg.role == "user":
                msg.priority = 10 if recent else 7
            elif msg.message_type == "tool_result":
                msg.priority = 9 if recent else 3
            elif msg.role == "assistant":
                msg.priority = 8 if recent else 2
            else:
                msg.priority = 5 if recent else 1

        return messages

    def apply_sliding_window(self, messages: List["ConversationMessage"],
                             target_tokens: int) -> List["ConversationMessage"]:
        """
        Apply sliding window to fit within token budget.

        Strategy:
        1. Always keep the system prompt (first message, if typed "system")
        2. Greedily keep the highest-priority messages that fit
        3. Compress recent tool results that do not fit as-is
        4. Return the survivors in their original chronological order

        Args:
            messages: List of ConversationMessage objects
            target_tokens: Target token count

        Returns:
            Filtered message list within budget. Kept tool results may have
            had their ``content`` replaced by a compressed version.
        """
        if not messages:
            return []

        # Always keep system prompt
        system_msg = messages[0] if messages[0].message_type == "system" else None
        other_messages = messages[1:] if system_msg else messages

        other_messages = self.prioritize_messages(other_messages)

        # Pair each message with its chronological index so the kept subset
        # can be restored to conversation order. This uses the index, not
        # content, so duplicates are preserved. sorted() is stable, so equal
        # priorities stay oldest-first.
        ranked = sorted(enumerate(other_messages),
                        key=lambda pair: pair[1].priority, reverse=True)

        for _, msg in ranked:
            if msg.tokens is None:
                msg.tokens = self.count_message_tokens(msg.to_dict())

        kept = []
        current_tokens = 0

        if system_msg:
            system_msg.tokens = self.count_message_tokens(system_msg.to_dict())
            current_tokens += system_msg.tokens

        for index, msg in ranked:
            if current_tokens + msg.tokens <= target_tokens:
                kept.append((index, msg))
                current_tokens += msg.tokens
            elif msg.message_type == "tool_result" and msg.priority >= 8:
                # Try compressing critical (recent) tool results
                compressed_content = self.compress_tool_result(msg.content, max_tokens=300)
                # Count with the same per-message overhead as every other message
                compressed_tokens = self.count_message_tokens(
                    {"role": msg.role, "content": compressed_content})

                if current_tokens + compressed_tokens <= target_tokens:
                    msg.content = compressed_content
                    msg.tokens = compressed_tokens
                    kept.append((index, msg))
                    current_tokens += compressed_tokens

        kept.sort(key=lambda pair: pair[0])
        kept_messages = [msg for _, msg in kept]
        if system_msg:
            kept_messages.insert(0, system_msg)

        print(f"📊 Sliding window: {len(messages)} → {len(kept_messages)} messages ({current_tokens:,} tokens)")

        return kept_messages

    def emergency_truncate(self, messages: List[Dict[str, str]],
                           max_tokens: int) -> List[Dict[str, str]]:
        """
        Emergency truncation when context is about to overflow.

        Aggressive strategy:
        - Keep the first message (normally the system prompt) and the last 2
        - If that is still too large, cut the first message to 1000 chars
          (keeping its original role)
        - If still too large, shrink the remaining messages' content (oldest
          first), and finally the first message, until the total fits

        Args:
            messages: Message list (dicts or attribute-style objects)
            max_tokens: Hard token limit

        Returns:
            Truncated message list. Messages whose content had to change are
            returned as dicts.
        """
        if not messages:
            return []

        print("⚠️ EMERGENCY TRUNCATION: Context overflow imminent")

        # First message plus the last two. A slice start of at least 1 avoids
        # repeating the first message in short lists.
        essential_messages = list(messages[:1]) + list(messages[max(1, len(messages) - 2):])

        if self.count_messages_tokens(essential_messages) <= max_tokens:
            return essential_messages

        # Still too large - truncate system prompt
        print("⚠️ Truncating system prompt to fit budget")
        first = essential_messages[0]
        _, first_content = self._role_and_content(first)
        truncated_first = self._with_content(
            first,
            self._as_text(first_content)[:1000] + "\n\n... (truncated due to context limit) ...",
        )
        # Preserve the original role; only default to "system" when absent.
        if not truncated_first.get("role"):
            truncated_first["role"] = "system"

        result = [truncated_first] + essential_messages[1:]

        # Last resort: shrink the tail (oldest first), then the first
        # message, until the hard limit is met.
        for i in [*range(1, len(result)), 0]:
            excess = self.count_messages_tokens(result) - max_tokens
            if excess <= 0:
                break
            _, content = self._role_and_content(result[i])
            text = self._as_text(content)
            keep = max(0, self.count_tokens(text) - excess)
            result[i] = self._with_content(result[i], self._truncate_to_tokens(text, keep))

        return result

    def enforce_budget(self, messages: List[Dict[str, str]],
                       system_prompt: Optional[str] = None) -> Tuple[List[Dict[str, str]], int]:
        """
        Main entry point: Enforce token budget on message list.

        Args:
            messages: List of messages (dicts or attribute-style objects)
            system_prompt: Optional new system prompt to prepend

        Returns:
            (filtered_messages, total_tokens). Within budget, the input
            messages are returned unchanged. Otherwise dict messages keep
            all their original keys (e.g. ``tool_call_id``, ``tool_calls``),
            with only ``content`` replaced when it was compressed.
            Attribute-style messages become ``{"role", "content"}`` dicts.
        """
        # Add system prompt if provided
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + list(messages)

        current_tokens = self.count_messages_tokens(messages)

        print(f"📊 Token Budget Check: {current_tokens:,} / {self.available_tokens:,} tokens")

        # If within budget, return as-is
        if current_tokens <= self.available_tokens:
            print("✅ Within budget")
            return messages, current_tokens

        print(f"⚠️ Over budget by {current_tokens - self.available_tokens:,} tokens")

        # Wrap each message and remember its source (keyed by wrapper id) so
        # survivors can be mapped back to the original message.
        conv_messages = []
        sources = {}
        for i, msg in enumerate(messages):
            role, content = self._role_and_content(msg)
            text = self._as_text(content)
            lowered = text.lower()

            # The system prompt is never reclassified, even if it mentions
            # tools/functions.
            if i == 0 and role == "system":
                msg_type = "system"
            elif role == "tool" or "tool" in lowered or "function" in lowered:
                msg_type = "tool_result"
            else:
                msg_type = "normal"

            conv_msg = ConversationMessage(
                role=role if role is not None else "",
                content=text,
                message_type=msg_type,
            )
            conv_messages.append(conv_msg)
            sources[id(conv_msg)] = (msg, text)

        # Apply sliding window
        filtered = self.apply_sliding_window(conv_messages, self.available_tokens)

        # Convert back to dict format.
        # NOTE(review): dropping messages can still separate an assistant
        # ``tool_calls`` turn from its tool responses; confirm callers tolerate it.
        result_messages = []
        for conv in filtered:
            source, original_text = sources[id(conv)]
            if isinstance(source, dict):
                if conv.content == original_text:
                    result_messages.append(dict(source))
                else:
                    result_messages.append({**source, "content": conv.content})
            else:
                result_messages.append(conv.to_dict())
        final_tokens = self.count_messages_tokens(result_messages)

        # Emergency truncation if still over
        if final_tokens > self.available_tokens:
            result_messages = self.emergency_truncate(result_messages, self.available_tokens)
            final_tokens = self.count_messages_tokens(result_messages)

        print(f"✅ Budget enforced: {final_tokens:,} tokens ({len(result_messages)} messages)")

        return result_messages, final_tokens
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
# Lazily-created, process-wide TokenBudgetManager shared by all callers.
_token_manager = None


def get_token_manager(model: str = "gpt-4", max_tokens: int = 128000) -> TokenBudgetManager:
    """Return the shared TokenBudgetManager, constructing it on first use.

    Only the arguments of the first call take effect. Later calls return
    the existing instance and ignore ``model`` and ``max_tokens``.
    """
    global _token_manager
    if _token_manager is not None:
        return _token_manager
    _token_manager = TokenBudgetManager(model=model, max_tokens=max_tokens)
    return _token_manager
|