ds-agent-cli 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/ds-agent.js +451 -0
- package/ds_agent/__init__.py +8 -0
- package/package.json +28 -0
- package/requirements.txt +126 -0
- package/setup.py +35 -0
- package/src/__init__.py +7 -0
- package/src/_compress_tool_result.py +118 -0
- package/src/api/__init__.py +4 -0
- package/src/api/app.py +1626 -0
- package/src/cache/__init__.py +5 -0
- package/src/cache/cache_manager.py +561 -0
- package/src/cli.py +2886 -0
- package/src/dynamic_prompts.py +281 -0
- package/src/orchestrator.py +4799 -0
- package/src/progress_manager.py +139 -0
- package/src/reasoning/__init__.py +332 -0
- package/src/reasoning/business_summary.py +431 -0
- package/src/reasoning/data_understanding.py +356 -0
- package/src/reasoning/model_explanation.py +383 -0
- package/src/reasoning/reasoning_trace.py +239 -0
- package/src/registry/__init__.py +3 -0
- package/src/registry/tools_registry.py +3 -0
- package/src/session_memory.py +448 -0
- package/src/session_store.py +370 -0
- package/src/storage/__init__.py +19 -0
- package/src/storage/artifact_store.py +620 -0
- package/src/storage/helpers.py +116 -0
- package/src/storage/huggingface_storage.py +694 -0
- package/src/storage/r2_storage.py +0 -0
- package/src/storage/user_files_service.py +288 -0
- package/src/tools/__init__.py +335 -0
- package/src/tools/advanced_analysis.py +823 -0
- package/src/tools/advanced_feature_engineering.py +708 -0
- package/src/tools/advanced_insights.py +578 -0
- package/src/tools/advanced_preprocessing.py +549 -0
- package/src/tools/advanced_training.py +906 -0
- package/src/tools/agent_tool_mapping.py +326 -0
- package/src/tools/auto_pipeline.py +420 -0
- package/src/tools/autogluon_training.py +1480 -0
- package/src/tools/business_intelligence.py +860 -0
- package/src/tools/cloud_data_sources.py +581 -0
- package/src/tools/code_interpreter.py +390 -0
- package/src/tools/computer_vision.py +614 -0
- package/src/tools/data_cleaning.py +614 -0
- package/src/tools/data_profiling.py +593 -0
- package/src/tools/data_type_conversion.py +268 -0
- package/src/tools/data_wrangling.py +433 -0
- package/src/tools/eda_reports.py +284 -0
- package/src/tools/enhanced_feature_engineering.py +241 -0
- package/src/tools/feature_engineering.py +302 -0
- package/src/tools/matplotlib_visualizations.py +1327 -0
- package/src/tools/model_training.py +520 -0
- package/src/tools/nlp_text_analytics.py +761 -0
- package/src/tools/plotly_visualizations.py +497 -0
- package/src/tools/production_mlops.py +852 -0
- package/src/tools/time_series.py +507 -0
- package/src/tools/tools_registry.py +2133 -0
- package/src/tools/visualization_engine.py +559 -0
- package/src/utils/__init__.py +42 -0
- package/src/utils/error_recovery.py +313 -0
- package/src/utils/parallel_executor.py +402 -0
- package/src/utils/polars_helpers.py +248 -0
- package/src/utils/schema_extraction.py +132 -0
- package/src/utils/semantic_layer.py +392 -0
- package/src/utils/token_budget.py +411 -0
- package/src/utils/validation.py +377 -0
- package/src/workflow_state.py +154 -0
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Global Progress Event Manager for Real-Time SSE Streaming
|
|
3
|
+
|
|
4
|
+
This module provides a singleton ProgressManager that captures all workflow progress
|
|
5
|
+
events and broadcasts them to connected SSE clients in real-time.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
from typing import Dict, List, Any, Optional
|
|
11
|
+
from datetime import datetime
|
|
12
|
+
from collections import defaultdict
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ProgressManager:
    """
    Manages progress events for active analysis sessions.

    Features:
    - Emit events to multiple subscribers simultaneously
    - Store a bounded event history (``MAX_HISTORY`` per session) for
      late-joining clients
    - Keep slow subscribers connected: when a subscriber's queue is full,
      its oldest pending event is dropped to make room for the new one
    - Unregister subscribers whose queue raises unexpectedly
    - End every subscriber's stream when its session is cleared

    Concurrency: ``asyncio.Queue`` is not thread-safe, so ``emit`` and
    ``clear`` must run on the event-loop thread that drives the
    ``subscribe`` generators.
    """

    # Maximum number of events retained per session in the history buffer.
    MAX_HISTORY = 500
    # Maximum number of undelivered events buffered per subscriber.
    QUEUE_MAXSIZE = 100

    def __init__(self):
        # session_id -> one queue per active subscribe() generator
        self._queues: Dict[str, List[asyncio.Queue]] = defaultdict(list)
        # session_id -> most recent emitted events, oldest first
        self._history: Dict[str, List[Dict]] = defaultdict(list)
        # Currently unused by this class.
        self._lock = asyncio.Lock()

    @staticmethod
    def _preview(event: Dict[str, Any], limit: int = 50) -> str:
        """Return a short, None-safe preview of an event's message for logging."""
        message = event.get('message')
        return '' if message is None else str(message)[:limit]

    @staticmethod
    def _put_dropping_oldest(queue: asyncio.Queue, item: Dict[str, Any]) -> bool:
        """
        Put ``item`` on ``queue`` without blocking, evicting the oldest item if full.

        Returns:
            True if an older pending item had to be dropped to make room.
        """
        dropped = False
        if queue.full():
            try:
                queue.get_nowait()
                dropped = True
            except asyncio.QueueEmpty:
                pass
        queue.put_nowait(item)
        return dropped

    def emit(self, session_id: str, event: Dict[str, Any]):
        """
        Emit a progress event to all subscribers.

        A shallow copy of ``event`` is stamped with ``timestamp`` (local time,
        ISO-8601), appended to the session history and queued for every
        subscriber. The caller's dict is not modified.

        Args:
            session_id: Session identifier
            event: Event data (should include 'type' and 'message')
        """
        print(f"[SSE] PROGRESS_MANAGER EMIT: session={session_id}, event_type={event.get('type')}, msg={self._preview(event)}")

        # Copy so that re-emitting a reused dict does not alias history entries.
        event = dict(event)
        event['timestamp'] = datetime.now().isoformat()

        history = self._history[session_id]
        history.append(event)
        # Trim in place to bound per-session memory use.
        if len(history) > self.MAX_HISTORY:
            del history[:-self.MAX_HISTORY]

        print(f"[SSE] History stored, total events for {session_id}: {len(history)}")

        # .get() avoids creating an empty entry in the defaultdict.
        subscribers = self._queues.get(session_id)
        if not subscribers:
            return

        print(f"[SSE] Found {len(subscribers)} subscribers for {session_id}")
        dead_queues = []
        for i, queue in enumerate(subscribers, start=1):
            try:
                if self._put_dropping_oldest(queue, event):
                    print(f"[SSE] WARNING: Queue full for subscriber {i}, dropped its oldest pending event")
                print(f"[SSE] Successfully queued event to subscriber {i}")
            except Exception as e:
                print(f"[SSE] ERROR: Exception queuing event to subscriber {i}: {type(e).__name__}: {e}")
                dead_queues.append(queue)

        # Remove after iterating so the list is not mutated mid-loop.
        for dead_queue in dead_queues:
            if dead_queue in subscribers:
                subscribers.remove(dead_queue)

    async def subscribe(self, session_id: str):
        """
        Subscribe to progress events for a session.

        The subscriber queue is registered when iteration starts and
        unregistered when the generator finishes, is closed, or is cancelled.
        Iteration ends right after a ``session_cleared`` event is yielded
        (sent by ``clear``). Cancellation propagates to the caller.

        Args:
            session_id: Session identifier

        Yields:
            Progress events as they occur
        """
        queue: asyncio.Queue = asyncio.Queue(maxsize=self.QUEUE_MAXSIZE)
        self._queues[session_id].append(queue)

        try:
            while True:
                event = await queue.get()
                print(f"[SSE] YIELDING event to client: type={event.get('type')}, msg={self._preview(event)}")
                yield event
                if event.get('type') == 'session_cleared':
                    # clear() has already unregistered this queue; no further
                    # events can arrive, so stop instead of waiting forever.
                    return
        finally:
            # Runs on normal exit, client disconnect (cancellation) and aclose().
            subscribers = self._queues.get(session_id)
            if subscribers is not None and queue in subscribers:
                subscribers.remove(queue)

    def get_history(self, session_id: str) -> List[Dict]:
        """
        Get all past events for a session.

        Args:
            session_id: Session identifier

        Returns:
            A new list of past events (oldest first). Mutating it does not
            affect the stored history.
        """
        return list(self._history.get(session_id, []))

    def clear(self, session_id: str):
        """
        Clear history and disconnect all subscribers for a session.

        Each subscriber receives a final ``session_cleared`` event, which ends
        its ``subscribe`` stream. If a subscriber queue is full, its oldest
        pending event is dropped so the sentinel still gets through.

        Args:
            session_id: Session identifier
        """
        self._history.pop(session_id, None)
        subscribers = self._queues.pop(session_id, None)
        if not subscribers:
            return
        for queue in subscribers:
            try:
                self._put_dropping_oldest(queue, {'type': 'session_cleared', 'message': 'Session ended'})
            except Exception as e:
                # Best effort: one failing queue must not block the others.
                print(f"[SSE] ERROR: Could not notify subscriber of session clear: {type(e).__name__}: {e}")

    def get_active_sessions(self) -> List[str]:
        """Get list of sessions with active subscribers."""
        return [sid for sid, queues in self._queues.items() if queues]

    def get_subscriber_count(self, session_id: str) -> int:
        """Get number of active subscribers for a session."""
        return len(self._queues.get(session_id, []))
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
# Shared, module-level ProgressManager; importers reuse this one instance.
progress_manager = ProgressManager()
|
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Reasoning Module - Core Abstraction
|
|
3
|
+
|
|
4
|
+
Provides clean separation between:
|
|
5
|
+
- Deterministic data processing (tools)
|
|
6
|
+
- Non-deterministic reasoning (LLM)
|
|
7
|
+
|
|
8
|
+
Design Principles:
|
|
9
|
+
- NO RAW DATA ACCESS - Only summaries/metadata
|
|
10
|
+
- NO TRAINING DECISIONS - Only explanations
|
|
11
|
+
- STRUCTURED I/O - JSON in, JSON + text out
|
|
12
|
+
- CACHEABLE - Deterministic enough to cache
|
|
13
|
+
- REASONING ONLY - No execution, no side effects
|
|
14
|
+
|
|
15
|
+
Architecture:
|
|
16
|
+
Tool → Generates Summary → Reasoning Module → Returns Explanation
|
|
17
|
+
|
|
18
|
+
Tool: "Here's what I found: {stats}"
|
|
19
|
+
Reasoning: "Based on these stats, this means..."
|
|
20
|
+
|
|
21
|
+
Usage:
|
|
22
|
+
from reasoning import get_reasoner
|
|
23
|
+
|
|
24
|
+
reasoner = get_reasoner()
|
|
25
|
+
result = reasoner.explain_data(
|
|
26
|
+
summary={"rows": 1000, "columns": 20, "missing": 50}
|
|
27
|
+
)
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
import os
|
|
31
|
+
from typing import Dict, Any, Optional, Union
|
|
32
|
+
from abc import ABC, abstractmethod
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ReasoningBackend(ABC):
    """
    Interface implemented by every LLM reasoning backend.

    Concrete backends must override both ``generate`` (free text) and
    ``generate_structured`` (parsed JSON); until both are overridden the
    class cannot be instantiated.
    """

    @abstractmethod
    def generate(self, prompt: str, system_prompt: Optional[str] = None,
                 temperature: float = 0.1, max_tokens: int = 2048) -> str:
        """
        Produce a free-text answer for ``prompt``.

        Args:
            prompt: User prompt.
            system_prompt: Optional system context.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
        """
        ...

    @abstractmethod
    def generate_structured(self, prompt: str, schema: Dict[str, Any],
                            system_prompt: Optional[str] = None) -> Dict[str, Any]:
        """
        Produce a parsed JSON answer for ``prompt`` shaped like ``schema``.

        Args:
            prompt: User prompt.
            schema: Expected JSON schema.
            system_prompt: Optional system context.
        """
        ...
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class GeminiBackend(ReasoningBackend):
    """Gemini reasoning backend (google-generativeai SDK)."""

    def __init__(self, api_key: Optional[str] = None, model: str = "gemini-2.0-flash-exp"):
        """
        Configure the Gemini client.

        Args:
            api_key: Google API key; falls back to the GOOGLE_API_KEY env var.
            model: Gemini model name.

        Raises:
            ImportError: If google-generativeai is not installed.
            ValueError: If no API key is available.
        """
        try:
            import google.generativeai as genai
        except ImportError as err:
            raise ImportError(
                "google-generativeai not installed. "
                "Install with: pip install google-generativeai"
            ) from err

        api_key = api_key or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError(
                "Google API key required. Set GOOGLE_API_KEY env var or pass api_key"
            )

        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(
            model,
            generation_config={"temperature": 0.1}
        )
        self.model_name = model

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 2048
    ) -> str:
        """
        Generate reasoning response.

        The system prompt is prepended to the user prompt, separated by a
        blank line, because this call sends a single prompt string.
        Per-call ``temperature``/``max_tokens`` override the model defaults.
        """
        full_prompt = prompt
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"

        response = self.model.generate_content(
            full_prompt,
            generation_config={
                "temperature": temperature,
                "max_output_tokens": max_tokens
            }
        )

        return response.text

    @staticmethod
    def _extract_json(response_text: Optional[str]) -> Dict[str, Any]:
        """
        Parse JSON out of a model response.

        Strategies, in order (each falls through to the next on failure):
        the whole text; the body of each fenced code block (```json or bare
        ```); the first JSON value that decodes starting at any '{'.

        Raises:
            ValueError: If no JSON can be decoded.
        """
        import json
        import re

        text = (response_text or "").strip()
        candidates = [text]
        candidates.extend(
            match.group(1).strip()
            for match in re.finditer(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
        )
        for candidate in candidates:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue

        # raw_decode stops at the end of the first complete value, so trailing
        # prose (even prose containing braces) does not break parsing.
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _ = decoder.raw_decode(text, start)
                return obj
            except json.JSONDecodeError:
                start = text.find("{", start + 1)

        raise ValueError(f"Failed to extract JSON from response: {text[:200]}...")

    def generate_structured(
        self,
        prompt: str,
        schema: Dict[str, Any],
        system_prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Generate structured JSON response.

        The schema is appended to the prompt as an instruction, and the reply
        is parsed with ``_extract_json``.

        Raises:
            ValueError: If the reply contains no decodable JSON.
        """
        import json

        schema_str = json.dumps(schema, indent=2)
        structured_prompt = f"""{prompt}

Respond with valid JSON matching this schema:
{schema_str}

Your response must be valid JSON only, no other text."""

        response_text = self.generate(structured_prompt, system_prompt)
        return self._extract_json(response_text)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class GroqBackend(ReasoningBackend):
    """Groq reasoning backend (groq SDK, chat-completions API)."""

    def __init__(self, api_key: Optional[str] = None, model: str = "llama-3.3-70b-versatile"):
        """
        Configure the Groq client.

        Args:
            api_key: Groq API key; falls back to the GROQ_API_KEY env var.
            model: Groq model name.

        Raises:
            ImportError: If groq is not installed.
            ValueError: If no API key is available.
        """
        try:
            from groq import Groq
        except ImportError as err:
            raise ImportError(
                "groq not installed. "
                "Install with: pip install groq"
            ) from err

        api_key = api_key or os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError(
                "Groq API key required. Set GROQ_API_KEY env var or pass api_key"
            )

        self.client = Groq(api_key=api_key)
        self.model_name = model

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: int = 2048
    ) -> str:
        """
        Generate reasoning response.

        Returns:
            The first choice's message content, or "" when the API returns
            no content, so the declared ``str`` return type always holds.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens
        )

        content = response.choices[0].message.content
        return content if content is not None else ""

    @staticmethod
    def _extract_json(response_text: Optional[str]) -> Dict[str, Any]:
        """
        Parse JSON out of a model response.

        Strategies, in order (each falls through to the next on failure):
        the whole text; the body of each fenced code block (```json or bare
        ```); the first JSON value that decodes starting at any '{'.

        Raises:
            ValueError: If no JSON can be decoded.
        """
        import json
        import re

        text = (response_text or "").strip()
        candidates = [text]
        candidates.extend(
            match.group(1).strip()
            for match in re.finditer(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
        )
        for candidate in candidates:
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue

        # raw_decode stops at the end of the first complete value, so trailing
        # prose (even prose containing braces) does not break parsing.
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _ = decoder.raw_decode(text, start)
                return obj
            except json.JSONDecodeError:
                start = text.find("{", start + 1)

        raise ValueError(f"Failed to extract JSON from response: {text[:200]}...")

    def generate_structured(
        self,
        prompt: str,
        schema: Dict[str, Any],
        system_prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Generate structured JSON response.

        The schema is appended to the prompt as an instruction, and the reply
        is parsed with ``_extract_json``.

        Raises:
            ValueError: If the reply contains no decodable JSON.
        """
        import json

        schema_str = json.dumps(schema, indent=2)
        structured_prompt = f"""{prompt}

Respond with valid JSON matching this schema:
{schema_str}

Your response must be valid JSON only, no other text."""

        response_text = self.generate(structured_prompt, system_prompt)
        return self._extract_json(response_text)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
class ReasoningEngine:
    """
    Main reasoning engine.

    Delegates to a ReasoningBackend (Gemini, Groq, or a caller-supplied
    instance) and exposes free-text and structured (JSON) reasoning.
    """

    def __init__(
        self,
        backend: Optional[ReasoningBackend] = None,
        provider: str = "gemini"
    ):
        """
        Initialize reasoning engine.

        Args:
            backend: Custom backend instance. When given, it is used as-is and
                ``provider`` is only stored on ``self.provider``.
            provider: 'gemini' or 'groq' (case-insensitive, surrounding
                whitespace ignored), used when ``backend`` is None. A falsy
                value such as None falls back to the LLM_PROVIDER environment
                variable, then to 'gemini'.

        Raises:
            ValueError: If the provider is unsupported, or the chosen backend
                has no API key.
            ImportError: If the chosen backend's SDK is not installed.
        """
        if backend is not None:
            self.backend = backend
        else:
            # An empty LLM_PROVIDER is treated the same as an unset one.
            provider = (provider or os.getenv("LLM_PROVIDER") or "gemini").strip().lower()

            if provider == "gemini":
                self.backend = GeminiBackend()
            elif provider == "groq":
                self.backend = GroqBackend()
            else:
                raise ValueError(f"Unsupported provider: {provider}")

        self.provider = provider

    def reason(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.1,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        General-purpose reasoning.

        Args:
            prompt: User prompt
            system_prompt: Optional system context
            temperature: Creativity (0.0 = deterministic, 1.0 = creative)
            max_tokens: Optional output-token limit. When None, the backend's
                own default applies.

        Returns:
            Natural language response
        """
        if max_tokens is None:
            # Same call shape as before max_tokens existed, so backends that
            # only accept three arguments keep working.
            return self.backend.generate(prompt, system_prompt, temperature)
        return self.backend.generate(prompt, system_prompt, temperature, max_tokens)

    def reason_structured(
        self,
        prompt: str,
        schema: Dict[str, Any],
        system_prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Structured reasoning with JSON output.

        Args:
            prompt: User prompt
            schema: Expected JSON schema
            system_prompt: Optional system context

        Returns:
            Parsed JSON response
        """
        return self.backend.generate_structured(prompt, schema, system_prompt)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
# Singleton instance: created lazily by get_reasoner(), reset by reset_reasoner().
_reasoning_engine: Optional[ReasoningEngine] = None


def get_reasoner(
    backend: Optional[ReasoningBackend] = None,
    provider: Optional[str] = None
) -> ReasoningEngine:
    """
    Get the shared reasoning engine, creating it on first use.

    A new engine replaces the cached one when ``backend`` is given, or when
    ``provider`` is given and differs (case-insensitively) from the cached
    engine's ``provider``.

    Args:
        backend: Custom backend instance
        provider: 'gemini' or 'groq'. When None, the cached engine is reused.
            A newly built engine then falls back to LLM_PROVIDER / 'gemini'.

    Returns:
        ReasoningEngine instance
    """
    global _reasoning_engine

    def _norm(name: Optional[str]) -> str:
        # Compare provider names ignoring case and surrounding whitespace.
        return (name or "").strip().lower()

    needs_new = (
        _reasoning_engine is None
        or backend is not None
        or (bool(provider) and _norm(provider) != _norm(getattr(_reasoning_engine, "provider", None)))
    )
    if needs_new:
        _reasoning_engine = ReasoningEngine(backend=backend, provider=provider)

    return _reasoning_engine
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def reset_reasoner():
    """
    Forget the cached reasoning engine.

    The next ``get_reasoner()`` call builds a fresh engine. Intended for tests.
    """
    global _reasoning_engine
    _reasoning_engine = None
|