DeepFabric 4.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,1304 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import random
|
|
4
|
+
import uuid
|
|
5
|
+
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
from .builders import ConversationBuilder
|
|
11
|
+
from .constants import DEFAULT_SAMPLE_RETRIES
|
|
12
|
+
from .exceptions import DataSetGeneratorError
|
|
13
|
+
from .progress import ProgressReporter
|
|
14
|
+
from .schemas import (
|
|
15
|
+
AgentContext,
|
|
16
|
+
AgentStep,
|
|
17
|
+
ChatMessage,
|
|
18
|
+
Conversation,
|
|
19
|
+
PendingToolCall,
|
|
20
|
+
ReasoningStep,
|
|
21
|
+
ReasoningTrace,
|
|
22
|
+
ToolCall,
|
|
23
|
+
ToolContext,
|
|
24
|
+
ToolDefinition,
|
|
25
|
+
ToolExecution,
|
|
26
|
+
generate_tool_call_id,
|
|
27
|
+
)
|
|
28
|
+
from .spin import SpinClient, SpinSession
|
|
29
|
+
from .stream_simulator import simulate_stream
|
|
30
|
+
from .utils import is_validation_error
|
|
31
|
+
|
|
32
|
+
if TYPE_CHECKING:
|
|
33
|
+
from .generator import DataSetGeneratorConfig
|
|
34
|
+
from .llm import LLMClient
|
|
35
|
+
from .schemas import ToolRegistry
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _convert_steps_to_reasoning(
    steps: list["AgentStep"],
    final_action_text: str = "Ready to respond",
) -> list[ReasoningStep]:
    """Translate AgentStep records into numbered ReasoningStep records.

    Args:
        steps: Agent steps to translate.
        final_action_text: Action label used when a step is marked final.

    Returns:
        One ReasoningStep per input step, numbered from 1.
    """
    converted: list[ReasoningStep] = []
    for number, agent_step in enumerate(steps, start=1):
        # Derive the "action" label: tool invocations take priority, then the
        # final-step marker; otherwise the step has no action.
        if agent_step.tool_calls:
            described = "; ".join(
                f"{tc.function_name}({tc.arguments})" for tc in agent_step.tool_calls
            )
        elif agent_step.is_final:
            described = final_action_text
        else:
            described = None
        converted.append(
            ReasoningStep(
                step_number=number,
                thought=agent_step.thought,
                action=described,
            )
        )
    return converted
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class UserQuestion(BaseModel):
    """Structured output schema for a generated user question."""

    # Just the question text; the generation prompt forbids extra commentary.
    content: str = Field(
        min_length=10,
        description="The user's question or request text - just the question itself, nothing else",
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class Scenario(BaseModel):
    """Structured output schema for a multi-turn scenario description."""

    # Short free-text description; length floor guards against empty output.
    description: str = Field(
        min_length=20,
        description="Brief scenario description requiring multiple turns",
    )
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class AgentResponse(BaseModel):
    """Structured output schema for the agent's reply to the user."""

    content: str = Field(
        min_length=10,
        description="The agent's response text - clear and concise",
    )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ToolOutput(BaseModel):
    """Wrapper for a tool execution's textual output."""

    result: str = Field(min_length=1, description="The tool's output/result")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class ConclusionDecision(BaseModel):
    """Structured output schema for the conclude-or-continue decision."""

    should_conclude: bool = Field(
        description="True if conversation task is complete, False if more turns needed"
    )
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class StepWithResults(BaseModel):
    """One ReAct step bound to the tool results it produced.

    Keeping the pairing explicit preserves the step-by-step structure that
    conversation formatting depends on.
    """

    step: AgentStep = Field(description="The original step with thought and pending tool calls")
    results: list[ToolExecution] = Field(
        description="Tool execution results for this step",
        default_factory=list,
    )
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class AgentTurnData(BaseModel):
    """Typed container for one turn of an agent conversation.

    Guarantees type safety when assembling multi-turn conversations, and keeps
    each step bound to its results so the ReAct ordering survives.
    """

    user_message: ChatMessage = Field(description="User's message for this turn")
    steps_with_results: list[StepWithResults] = Field(
        description="ReAct steps with their execution results, preserving step-by-step order"
    )
    agent_response: ChatMessage = Field(description="Agent's final response for this turn")

    @property
    def reasoning_steps(self) -> list[ReasoningStep]:
        """Convert steps to ReasoningSteps for backward compatibility."""
        return _convert_steps_to_reasoning([swr.step for swr in self.steps_with_results])

    @property
    def tool_calls(self) -> list[ToolExecution]:
        """Get all tool executions for backward compatibility."""
        # Flatten per-step results into a single chronological list.
        return [execution for swr in self.steps_with_results for execution in swr.results]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class SingleTurnAgentBuilder(ConversationBuilder):
|
|
153
|
+
"""Builder for single-turn agent conversations with tool calling.
|
|
154
|
+
|
|
155
|
+
Generates conversations using a multi-step process:
|
|
156
|
+
1. Generate user question
|
|
157
|
+
2. Generate agent reasoning + tool calls
|
|
158
|
+
3. Execute tools via Spin (or simulate if no Spin endpoint)
|
|
159
|
+
4. Generate agent's final response
|
|
160
|
+
|
|
161
|
+
This produces realistic tool-calling training data.
|
|
162
|
+
"""
|
|
163
|
+
|
|
164
|
+
def __init__(
|
|
165
|
+
self,
|
|
166
|
+
llm: "LLMClient",
|
|
167
|
+
config: "DataSetGeneratorConfig",
|
|
168
|
+
tool_registry: "ToolRegistry",
|
|
169
|
+
progress_reporter: ProgressReporter | None = None,
|
|
170
|
+
):
|
|
171
|
+
"""Initialize with required tool registry.
|
|
172
|
+
|
|
173
|
+
Args:
|
|
174
|
+
llm: LLM client for generation
|
|
175
|
+
config: Generator configuration
|
|
176
|
+
tool_registry: Tool registry (required for agent builders)
|
|
177
|
+
progress_reporter: Optional progress reporter for streaming feedback
|
|
178
|
+
"""
|
|
179
|
+
super().__init__(llm, config, tool_registry, progress_reporter)
|
|
180
|
+
# Store as non-optional for type checker
|
|
181
|
+
self.tool_registry: ToolRegistry = tool_registry
|
|
182
|
+
|
|
183
|
+
# Spin integration for real tool execution
|
|
184
|
+
self._spin_client: SpinClient | None = None
|
|
185
|
+
self._spin_session: SpinSession | None = None
|
|
186
|
+
|
|
187
|
+
# Track seen tool signatures to skip duplicates
|
|
188
|
+
self._seen_tool_signatures: set[str] = set()
|
|
189
|
+
|
|
190
|
+
# Initialize Spin client if endpoint is configured
|
|
191
|
+
spin_endpoint = getattr(config, "spin_endpoint", None)
|
|
192
|
+
tool_execute_path = getattr(config, "tool_execute_path", None)
|
|
193
|
+
if spin_endpoint:
|
|
194
|
+
self._spin_client = SpinClient(
|
|
195
|
+
endpoint=spin_endpoint,
|
|
196
|
+
tool_execute_path=tool_execute_path,
|
|
197
|
+
)
|
|
198
|
+
if tool_execute_path:
|
|
199
|
+
logger.info(
|
|
200
|
+
"Spin execution enabled: %s (execute path: %s)",
|
|
201
|
+
spin_endpoint,
|
|
202
|
+
tool_execute_path,
|
|
203
|
+
)
|
|
204
|
+
else:
|
|
205
|
+
logger.info("Spin execution enabled: %s", spin_endpoint)
|
|
206
|
+
|
|
207
|
+
    async def generate(self, topic_prompt: str, error_feedback: str | None = None) -> Conversation:
        """Generate single-turn agent conversation with tools using ReAct loop.

        Uses a think-act-observe loop where each step's tool calls are based on
        observations from previous steps. This prevents the agent from making
        decisions (like writes) before observing results (like reads).

        Args:
            topic_prompt: Topic or scenario to generate conversation about
            error_feedback: Optional error message from a previous failed attempt

        Returns:
            Complete Conversation with tool calling

        Raises:
            ValueError: If generation fails at any step
        """
        try:
            # Initialize Spin session if configured (no-op without a Spin client)
            await self._ensure_spin_session()

            # Step 1: Generate user question
            user_message = await self._generate_user_question(topic_prompt)

            # Step 2: ReAct loop - think, act, observe
            all_steps: list[AgentStep] = []
            all_tool_results: list[ToolExecution] = []
            max_steps = getattr(self.config, "max_agent_steps", 5)

            # Reset duplicate tracking for this conversation
            self._seen_tool_signatures.clear()

            for step_num in range(max_steps):
                if self.progress_reporter:
                    self.progress_reporter.emit_step_start(f"ReAct step {step_num + 1}/{max_steps}")

                # Generate next step based on observations so far.
                # error_feedback only applies to the very first step: it describes
                # the previous *generation attempt*, not a mid-loop failure.
                step = await self._generate_next_step(
                    user_message,
                    all_steps,
                    all_tool_results,
                    error_feedback if step_num == 0 else None,
                )
                all_steps.append(step)

                # Check if agent is done (explicit final flag, or nothing to call)
                if step.is_final or not step.tool_calls:
                    if self.progress_reporter:
                        self.progress_reporter.emit_step_complete(
                            f"Agent decided to conclude after {step_num + 1} steps"
                        )
                    break

                # Execute THIS step's tools before generating the next step,
                # so the following step can react to these observations.
                step_results = await self._execute_step_tools(step.tool_calls)
                all_tool_results.extend(step_results)

                if self.progress_reporter:
                    self.progress_reporter.emit_step_complete(
                        f"Executed {len(step.tool_calls)} tool(s) in step {step_num + 1}"
                    )

            # Step 3: Generate agent's final response based on all observations
            agent_response = await self._generate_agent_conclusion(
                user_message, all_steps, all_tool_results
            )

            # Assemble into Conversation
            return self._build_conversation(
                user_message, all_steps, all_tool_results, agent_response, topic_prompt
            )
        finally:
            # Always cleanup Spin session, even if generation raised
            await self._cleanup_spin_session()
|
|
281
|
+
|
|
282
|
+
async def _ensure_spin_session(self) -> None:
|
|
283
|
+
"""Initialize Spin session if configured."""
|
|
284
|
+
if self._spin_client is None:
|
|
285
|
+
return
|
|
286
|
+
|
|
287
|
+
# Create new session
|
|
288
|
+
session_id = str(uuid.uuid4())
|
|
289
|
+
self._spin_session = SpinSession(self._spin_client, session_id)
|
|
290
|
+
|
|
291
|
+
# Seed initial state if configured
|
|
292
|
+
scenario_seed = getattr(self.config, "scenario_seed", None)
|
|
293
|
+
if scenario_seed and isinstance(scenario_seed, dict):
|
|
294
|
+
files = scenario_seed.get("files", {})
|
|
295
|
+
if files:
|
|
296
|
+
success = await self._spin_session.seed_files(files)
|
|
297
|
+
if success:
|
|
298
|
+
logger.debug("Seeded %d files for session %s", len(files), session_id)
|
|
299
|
+
else:
|
|
300
|
+
logger.warning("Failed to seed some files for session %s", session_id)
|
|
301
|
+
|
|
302
|
+
async def _cleanup_spin_session(self) -> None:
|
|
303
|
+
"""Clean up Spin session after generation."""
|
|
304
|
+
if self._spin_session is not None:
|
|
305
|
+
await self._spin_session.cleanup()
|
|
306
|
+
self._spin_session = None
|
|
307
|
+
|
|
308
|
+
    async def _generate_user_question(self, topic_prompt: str) -> ChatMessage:
        """Generate the user's question for this scenario.

        Args:
            topic_prompt: The scenario topic

        Returns:
            User message (typed ChatMessage)
        """
        prompt = f"""Generate a short, natural user question for this scenario:
{topic_prompt}

Requirements:
- Just the user's question - no reasoning, no explanations, no examples
- Should require using tools to answer
- 1-2 sentences maximum
- Natural, conversational tone

Example format: "Can you tell me the weather in Paris tomorrow and suggest what to wear?"

Generate only the user's question:"""

        # Always use non-streaming for reliable structured output
        response = await self.llm.generate_async(
            prompt=prompt,
            schema=UserQuestion,
            max_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )

        # Fire-and-forget: simulate streaming for TUI preview (non-blocking)
        simulate_stream(
            self.progress_reporter,
            response.model_dump_json(),
            source="user_question",
        )

        return ChatMessage(role="user", content=response.content)
|
|
346
|
+
|
|
347
|
+
async def _generate_next_step(
|
|
348
|
+
self,
|
|
349
|
+
user_message: ChatMessage,
|
|
350
|
+
previous_steps: list[AgentStep],
|
|
351
|
+
previous_results: list[ToolExecution],
|
|
352
|
+
error_feedback: str | None = None,
|
|
353
|
+
) -> AgentStep:
|
|
354
|
+
"""Generate the next ReAct step based on observations so far.
|
|
355
|
+
|
|
356
|
+
This is the core of the ReAct loop - the agent decides its next action
|
|
357
|
+
based on what it has already observed from previous tool executions.
|
|
358
|
+
|
|
359
|
+
Args:
|
|
360
|
+
user_message: The original user question
|
|
361
|
+
previous_steps: Steps taken so far (thoughts + tool calls)
|
|
362
|
+
previous_results: Results from executed tools
|
|
363
|
+
error_feedback: Optional error from previous generation attempt
|
|
364
|
+
|
|
365
|
+
Returns:
|
|
366
|
+
AgentStep with thought, optional tool calls, and is_final flag
|
|
367
|
+
"""
|
|
368
|
+
max_retries = getattr(self.config, "sample_retries", DEFAULT_SAMPLE_RETRIES)
|
|
369
|
+
last_error: Exception | None = None
|
|
370
|
+
current_feedback = error_feedback
|
|
371
|
+
|
|
372
|
+
for attempt in range(max_retries + 1):
|
|
373
|
+
try:
|
|
374
|
+
return await self._generate_next_step_impl(
|
|
375
|
+
user_message, previous_steps, previous_results, current_feedback
|
|
376
|
+
)
|
|
377
|
+
except Exception as e:
|
|
378
|
+
last_error = e
|
|
379
|
+
if is_validation_error(e) and attempt < max_retries:
|
|
380
|
+
current_feedback = str(e)
|
|
381
|
+
if self.progress_reporter:
|
|
382
|
+
self.progress_reporter.emit_step_start(
|
|
383
|
+
f"Retrying step generation (attempt {attempt + 2}/{max_retries + 1})"
|
|
384
|
+
)
|
|
385
|
+
continue
|
|
386
|
+
raise
|
|
387
|
+
|
|
388
|
+
raise last_error # type: ignore[misc]
|
|
389
|
+
|
|
390
|
+
    async def _generate_next_step_impl(
        self,
        user_message: ChatMessage,
        previous_steps: list[AgentStep],
        previous_results: list[ToolExecution],
        error_feedback: str | None = None,
    ) -> AgentStep:
        """Implementation of next step generation.

        Builds a prompt containing the user request, the available tools, and
        the history of prior actions/observations, then asks the LLM for a
        structured AgentStep.
        """
        tools_info = self._format_tools_for_prompt()
        history = self._format_step_history(previous_steps, previous_results)

        prompt_parts = [
            "## User Request",
            user_message.content or "",
            "",
            "## Available Tools",
            tools_info,
            "",
            "## Previous Actions & Results",
            history if history else "None yet - this is your first action.",
            "",
            "## Instructions",
            "Based on what you've observed so far, decide your next action:",
            "- If you need more information, specify tool_calls for THIS step only",
            "- If you have enough information to answer, set is_final=true and leave tool_calls empty",
            "- IMPORTANT: Do NOT call write/modify operations until you've confirmed current state via read operations",
            "- Tool arguments must use concrete values (no placeholders like '<user_input>' or null)",
            "",
            "What is your next step?",
        ]

        if error_feedback:
            # insert(-1, ...) places the failure note just before the final
            # "What is your next step?" line.
            prompt_parts.insert(
                -1,
                f"\n## Previous Attempt Failed\nError: {error_feedback}\nPlease fix this issue.\n",
            )

        prompt = "\n".join(prompt_parts)

        # Always use non-streaming for reliable structured output
        response = await self.llm.generate_async(
            prompt=prompt,
            schema=AgentStep,
            max_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )

        # Fire-and-forget: simulate streaming for TUI preview (non-blocking)
        simulate_stream(
            self.progress_reporter,
            response.model_dump_json(),
            source="agent_step",
        )

        return response
|
|
445
|
+
|
|
446
|
+
    def _format_step_history(self, steps: list[AgentStep], results: list[ToolExecution]) -> str:
        """Format previous steps and results for the prompt.

        Args:
            steps: Previous AgentSteps taken
            results: Tool execution results (with actual outputs)

        Returns:
            Formatted string showing the progression of actions and observations
        """
        if not steps:
            return ""

        history_parts = []
        # Results are matched to tool calls positionally, in order.
        # NOTE(review): _execute_step_tools skips duplicate calls without
        # producing a result, so when duplicates occurred this positional
        # matching can attribute a result to the wrong call — confirm whether
        # matching by (function_name, arguments) is needed here.
        result_idx = 0

        for step_num, step in enumerate(steps, 1):
            history_parts.append(f"### Step {step_num}")
            history_parts.append(f"Thought: {step.thought}")

            if step.tool_calls:
                history_parts.append("Tool calls:")
                for tool_call in step.tool_calls:
                    history_parts.append(f"  - {tool_call.function_name}({tool_call.arguments})")
                    # Match with result if available
                    if result_idx < len(results):
                        result = results[result_idx]
                        history_parts.append(f"    Result: {result.result}")
                        result_idx += 1

            history_parts.append("")

        return "\n".join(history_parts)
|
|
479
|
+
|
|
480
|
+
def _get_tool_signature(self, pending_call: PendingToolCall) -> str:
|
|
481
|
+
"""Generate a signature for deduplication.
|
|
482
|
+
|
|
483
|
+
Args:
|
|
484
|
+
pending_call: The pending tool call
|
|
485
|
+
|
|
486
|
+
Returns:
|
|
487
|
+
Signature string combining tool name and arguments
|
|
488
|
+
"""
|
|
489
|
+
try:
|
|
490
|
+
# Normalize arguments by parsing and re-dumping JSON to handle
|
|
491
|
+
# differences in whitespace and key order.
|
|
492
|
+
args = json.loads(pending_call.arguments)
|
|
493
|
+
normalized_args = json.dumps(args, sort_keys=True)
|
|
494
|
+
return f"{pending_call.function_name}:{normalized_args}" # noqa: TRY300
|
|
495
|
+
except json.JSONDecodeError:
|
|
496
|
+
# Fallback for any case where arguments are not valid JSON,
|
|
497
|
+
# though this should be caught by Pydantic validation.
|
|
498
|
+
return f"{pending_call.function_name}:{pending_call.arguments}"
|
|
499
|
+
|
|
500
|
+
async def _execute_step_tools(self, tool_calls: list[PendingToolCall]) -> list[ToolExecution]:
|
|
501
|
+
"""Execute tool calls for a single ReAct step.
|
|
502
|
+
|
|
503
|
+
Skips duplicate tool calls (same tool + same arguments) that were
|
|
504
|
+
already executed in a previous step of this conversation.
|
|
505
|
+
|
|
506
|
+
Args:
|
|
507
|
+
tool_calls: Pending tool calls from the current step (without results)
|
|
508
|
+
|
|
509
|
+
Returns:
|
|
510
|
+
List of ToolExecutions with results populated from Spin
|
|
511
|
+
"""
|
|
512
|
+
completed_executions = []
|
|
513
|
+
|
|
514
|
+
for pending_call in tool_calls:
|
|
515
|
+
# Check for duplicate
|
|
516
|
+
signature = self._get_tool_signature(pending_call)
|
|
517
|
+
if signature in self._seen_tool_signatures:
|
|
518
|
+
logger.debug("Skipping duplicate tool call: %s", pending_call.function_name)
|
|
519
|
+
continue
|
|
520
|
+
|
|
521
|
+
# Mark as seen
|
|
522
|
+
self._seen_tool_signatures.add(signature)
|
|
523
|
+
|
|
524
|
+
# Get tool definition from registry
|
|
525
|
+
tool_def = self.tool_registry.get_tool(pending_call.function_name)
|
|
526
|
+
if not tool_def:
|
|
527
|
+
# Return error as result instead of raising
|
|
528
|
+
completed_executions.append(
|
|
529
|
+
pending_call.to_tool_execution(
|
|
530
|
+
f"Error: Tool '{pending_call.function_name}' not found in registry"
|
|
531
|
+
)
|
|
532
|
+
)
|
|
533
|
+
continue
|
|
534
|
+
|
|
535
|
+
# Execute tool via Spin
|
|
536
|
+
result = await self._generate_tool_output(tool_def, pending_call)
|
|
537
|
+
completed_executions.append(pending_call.to_tool_execution(result.result))
|
|
538
|
+
|
|
539
|
+
return completed_executions
|
|
540
|
+
|
|
541
|
+
async def _generate_tool_output(
|
|
542
|
+
self,
|
|
543
|
+
tool_def: ToolDefinition,
|
|
544
|
+
pending_call: PendingToolCall,
|
|
545
|
+
error_feedback: str | None = None, # noqa: ARG002 - kept for interface compatibility
|
|
546
|
+
) -> ToolOutput:
|
|
547
|
+
"""Execute tool via Spin and return real output.
|
|
548
|
+
|
|
549
|
+
Args:
|
|
550
|
+
tool_def: Tool definition from registry
|
|
551
|
+
pending_call: Pending tool call with arguments (no result yet)
|
|
552
|
+
error_feedback: Unused - kept for interface compatibility with retry logic
|
|
553
|
+
|
|
554
|
+
Returns:
|
|
555
|
+
ToolOutput with real execution result
|
|
556
|
+
|
|
557
|
+
Raises:
|
|
558
|
+
DataSetGeneratorError: If Spin is not configured or execution fails
|
|
559
|
+
"""
|
|
560
|
+
# Require Spin for tool execution
|
|
561
|
+
if self._spin_session is None:
|
|
562
|
+
raise DataSetGeneratorError(
|
|
563
|
+
"Spin endpoint not configured. Tool execution requires a Spin service. "
|
|
564
|
+
"Set 'spin_endpoint' in your tools configuration."
|
|
565
|
+
)
|
|
566
|
+
|
|
567
|
+
# Parse arguments from JSON string
|
|
568
|
+
try:
|
|
569
|
+
args: dict[str, Any] = json.loads(pending_call.arguments)
|
|
570
|
+
except json.JSONDecodeError as e:
|
|
571
|
+
return ToolOutput(result=f"Error: Invalid JSON arguments: {e}")
|
|
572
|
+
|
|
573
|
+
# Execute via Spin
|
|
574
|
+
result = await self._spin_session.execute_tool(
|
|
575
|
+
tool_name=tool_def.name,
|
|
576
|
+
arguments=args,
|
|
577
|
+
component=tool_def.component,
|
|
578
|
+
)
|
|
579
|
+
|
|
580
|
+
if result.success:
|
|
581
|
+
if self.progress_reporter:
|
|
582
|
+
self.progress_reporter.emit_tool_execution(
|
|
583
|
+
tool_def.name, success=True, arguments=args
|
|
584
|
+
)
|
|
585
|
+
return ToolOutput(result=result.result)
|
|
586
|
+
|
|
587
|
+
# Return error as tool output (this is valid training data for error handling)
|
|
588
|
+
error_msg = result.result
|
|
589
|
+
if result.error_type:
|
|
590
|
+
error_msg = f"Error ({result.error_type}): {result.result}"
|
|
591
|
+
if self.progress_reporter:
|
|
592
|
+
self.progress_reporter.emit_tool_execution(
|
|
593
|
+
tool_def.name,
|
|
594
|
+
success=False,
|
|
595
|
+
arguments=args,
|
|
596
|
+
error_type=result.error_type or "error",
|
|
597
|
+
)
|
|
598
|
+
return ToolOutput(result=error_msg)
|
|
599
|
+
|
|
600
|
+
    async def _generate_agent_conclusion(
        self,
        user_message: ChatMessage,
        steps: list[AgentStep],  # noqa: ARG002 - kept for potential future use
        tool_results: list[ToolExecution],
        context: str = "",
    ) -> ChatMessage:
        """Generate agent's final response interpreting tool results.

        Args:
            user_message: Original user question
            steps: All ReAct steps taken (with thoughts)
            tool_results: All tool execution results
            context: Previous conversation context (for multi-turn)

        Returns:
            Agent's final response message
        """
        # Format tool results summary (one "Tool/Result" pair per execution)
        results_text = (
            "\n".join([f"Tool: {r.function_name}\nResult: {r.result}" for r in tool_results])
            if tool_results
            else "No tools were executed."
        )

        # Format available tools for context
        tools_info = self._format_tools_for_prompt()

        # Build context section if provided (multi-turn conversations)
        context_section = ""
        if context:
            context_section = f"Previous conversation:\n{context}\n\n"

        prompt = f"""{self.config.dataset_system_prompt}

Available tools:
{tools_info}

{context_section}User request: {user_message.content}

You executed these tools:
{results_text}

Based on these results, provide a clear, helpful response to the user.
Remember: You have access to the tools listed above and have used them in this conversation."""

        # Always use non-streaming for reliable structured output
        response = await self.llm.generate_async(
            prompt=prompt,
            schema=AgentResponse,
            max_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )

        # Fire-and-forget: simulate streaming for TUI preview (non-blocking)
        simulate_stream(
            self.progress_reporter,
            response.model_dump_json(),
            source="agent_response",
        )

        return ChatMessage(role="assistant", content=response.content)
|
|
662
|
+
|
|
663
|
+
    def _build_conversation(
        self,
        user_message: ChatMessage,
        steps: list[AgentStep],
        tool_results: list[ToolExecution],
        agent_response: ChatMessage,
        _topic_prompt: str = "",
    ) -> Conversation:
        """Assemble all components into a Conversation.

        Preserves ReAct step-by-step structure: each step's tool calls become
        a separate assistant message followed by tool responses. This ensures
        training data shows the agent making decisions AFTER observing results.

        Args:
            user_message: User's question
            steps: All ReAct steps (thoughts + tool calls)
            tool_results: All tool execution results
            agent_response: Agent's final response
            _topic_prompt: Topic used to generate this conversation (unused, for interface)

        Returns:
            Complete Conversation object
        """
        messages: list[ChatMessage] = []

        # Add user message
        messages.append(user_message)

        # Process each ReAct step separately to preserve step-by-step structure
        # This is critical: agent should see results from step N before deciding step N+1
        # result_idx is a cursor into the flat tool_results list; it advances once
        # per tool call that has a matching result, pairing calls to results in order.
        result_idx = 0

        for step in steps:
            # Skip steps with no tool calls (e.g., final "is_final=true" step)
            if not step.tool_calls:
                continue

            # Build tool_calls for THIS step only
            step_tool_calls: list[ToolCall] = []
            step_tool_call_ids: list[str] = []

            for _pending_call in step.tool_calls:
                tool_call_id = generate_tool_call_id()
                step_tool_call_ids.append(tool_call_id)
                # Get the matching result
                # NOTE(review): if tool_results is shorter than the total number of
                # pending calls, the surplus calls get an id but no ToolCall entry;
                # because results were exhausted in order, ids[i] still aligns with
                # step_tool_calls[i] for every entry that was appended.
                if result_idx < len(tool_results):
                    result = tool_results[result_idx]
                    step_tool_calls.append(result.to_tool_call(tool_call_id))
                    result_idx += 1

            # Assistant message with tool_calls for this step
            if step_tool_calls:
                messages.append(
                    ChatMessage(
                        role="assistant",
                        content="",
                        tool_calls=step_tool_calls,
                    )
                )

            # Tool response messages for this step
            # We need to re-iterate to get matching results
            # result_base_idx rewinds the cursor to the first result consumed by
            # this step, so each tool message pairs with its assistant tool_call id.
            result_base_idx = result_idx - len(step_tool_calls)
            for idx, _tc in enumerate(step_tool_calls):
                res_idx = result_base_idx + idx
                if res_idx < len(tool_results):
                    messages.append(
                        ChatMessage(
                            role="tool",
                            content=tool_results[res_idx].result,
                            tool_call_id=step_tool_call_ids[idx],
                        )
                    )

        # Add final assistant response with the answer
        messages.append(agent_response)

        # Build tool context (executions only - tools are in top-level 'tools' field)
        tool_context = ToolContext(
            executions=tool_results,
        )

        # Build reasoning trace from AgentSteps
        reasoning_steps = _convert_steps_to_reasoning(steps, "Ready to respond to user")

        reasoning_trace = ReasoningTrace(
            style=self.config.reasoning_style or "agent",  # type: ignore
            content=reasoning_steps,
        )

        # Build agent context
        agent_context = AgentContext(mode="single_turn")

        # Build metadata
        metadata = {
            "conversation_type": "chain_of_thought",
            "react_steps": len(steps),
        }

        # Insert system message if configured
        self._insert_system_message_if_configured(messages)

        # Convert tools to OpenAI format
        tools_openai = [tool.to_openai() for tool in self.tool_registry.tools]

        return Conversation(
            messages=messages,
            reasoning=reasoning_trace,
            tool_context=tool_context,
            tools=tools_openai,
            agent_context=agent_context,
            question=user_message.content or "",
            final_answer=agent_response.content or "",
            metadata=metadata,
        )
|
|
779
|
+
|
|
780
|
+
def _format_tools_for_prompt(self) -> str:
|
|
781
|
+
"""Format available tools for inclusion in prompts.
|
|
782
|
+
|
|
783
|
+
Provides detailed tool information including parameter descriptions
|
|
784
|
+
and whether parameters are required, helping the LLM generate
|
|
785
|
+
correct tool calls.
|
|
786
|
+
|
|
787
|
+
Returns:
|
|
788
|
+
Formatted string describing available tools
|
|
789
|
+
"""
|
|
790
|
+
tool_descriptions = []
|
|
791
|
+
for tool in self.tool_registry.tools:
|
|
792
|
+
# Build parameter details
|
|
793
|
+
if tool.parameters:
|
|
794
|
+
param_lines = []
|
|
795
|
+
for p in tool.parameters:
|
|
796
|
+
req_marker = "(required)" if p.required else "(optional)"
|
|
797
|
+
param_lines.append(f" - {p.name}: {p.type} {req_marker} - {p.description}")
|
|
798
|
+
params_section = "\n".join(param_lines)
|
|
799
|
+
tool_descriptions.append(
|
|
800
|
+
f"### {tool.name}\n"
|
|
801
|
+
f"{tool.description}\n"
|
|
802
|
+
f"Parameters:\n{params_section}\n"
|
|
803
|
+
f"Returns: {tool.returns}"
|
|
804
|
+
)
|
|
805
|
+
else:
|
|
806
|
+
# No parameters - make this explicit
|
|
807
|
+
tool_descriptions.append(
|
|
808
|
+
f"### {tool.name}\n"
|
|
809
|
+
f"{tool.description}\n"
|
|
810
|
+
f"Parameters: None (use empty object {{}})\n"
|
|
811
|
+
f"Returns: {tool.returns}"
|
|
812
|
+
)
|
|
813
|
+
|
|
814
|
+
return "\n\n".join(tool_descriptions)
|
|
815
|
+
|
|
816
|
+
def _insert_system_message_if_configured(self, messages: list[ChatMessage]) -> None:
|
|
817
|
+
"""Insert system message at the beginning of messages if configured.
|
|
818
|
+
|
|
819
|
+
Args:
|
|
820
|
+
messages: List of messages to potentially prepend system message to
|
|
821
|
+
"""
|
|
822
|
+
if self.config.sys_msg:
|
|
823
|
+
messages.insert(
|
|
824
|
+
0,
|
|
825
|
+
ChatMessage(role="system", content=self.config.dataset_system_prompt or ""),
|
|
826
|
+
)
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
class MultiTurnAgentBuilder(SingleTurnAgentBuilder):
    """Builder for multi-turn agent conversations.

    Extends SingleTurnAgentBuilder to generate conversations with multiple
    user-agent interaction turns. Each turn can involve different tools
    and builds on previous context.
    """

    async def generate(
        self,
        topic_prompt: str,
        error_feedback: str | None = None,  # noqa: ARG002
    ) -> Conversation:
        """Generate multi-turn agent conversation using ReAct loop.

        Args:
            topic_prompt: Topic or scenario to generate conversation about
            error_feedback: Unused, kept for interface consistency with ConversationBuilder

        Returns:
            Complete multi-turn Conversation

        Raises:
            ValueError: If generation fails or config is invalid
        """
        try:
            # Initialize Spin session if configured
            await self._ensure_spin_session()

            # Determine number of turns (from config range)
            num_turns = random.randint(self.config.min_turns, self.config.max_turns)  # noqa: S311 # nosec

            # Track conversation context
            turns: list[AgentTurnData] = []
            all_messages: list[ChatMessage] = []

            # Reset duplicate tracking for this conversation
            self._seen_tool_signatures.clear()

            # Generate scenario overview
            scenario = await self._generate_scenario(topic_prompt, num_turns)

            for turn_idx in range(num_turns):
                # Generate this turn using ReAct loop
                turn_data = await self._generate_turn(turn_idx, scenario, all_messages)
                turns.append(turn_data)

                # Accumulate messages for context
                all_messages.extend(
                    [
                        turn_data.user_message,
                        turn_data.agent_response,
                    ]
                )

                # Count total tool calls so far
                # NOTE(review): assumes AgentTurnData exposes a `tool_calls`
                # aggregate (not visible in this file section) — confirm on the model
                total_tool_calls = sum(len(t.tool_calls) for t in turns)

                # Check if we should conclude early
                # (only once the configured minimum number of turns is reached)
                if turn_idx >= self.config.min_turns - 1 and await self._should_conclude_early(
                    all_messages, scenario, turn_idx + 1, total_tool_calls
                ):
                    break

            # Assemble into complete conversation
            return self._build_multi_turn_conversation(turns, scenario, topic_prompt)
        finally:
            # Always cleanup Spin session
            await self._cleanup_spin_session()

    async def _generate_scenario(self, topic_prompt: str, num_turns: int) -> str:
        """Generate a multi-turn scenario description.

        Args:
            topic_prompt: Original topic
            num_turns: Number of turns to plan for

        Returns:
            Scenario description that requires multiple interactions
        """
        tools_info = self._format_tools_for_prompt()

        prompt = (
            f"Generate a realistic scenario for this topic that requires {num_turns} user-agent interaction turns:\n"
            f"{topic_prompt}\n\n"
            f"Available tools:\n"
            f"{tools_info}\n\n"
            f"The scenario MUST:\n"
            f"- Require at least {num_turns} distinct tool calls across different turns\n"
            f"- Have tool dependencies (e.g., read before modify, search before create, fetch before analyze)\n"
            f"- Build progressively - each turn depends on results from previous turns\n"
            f"- NOT be completable in a single turn\n\n"
            f"Example structure for a {num_turns}-turn scenario:\n"
            f"- Turn 1: User asks to find/read/search something\n"
            f"- Turn 2: User asks to modify/create based on what was found\n"
            f"- Turn 3+: User asks to verify, take action, or build further on previous results\n\n"
            f"Keep it brief (2-3 sentences) but ensure multi-step complexity with clear tool dependencies."
        )

        # Always use non-streaming for reliable structured output
        response = await self.llm.generate_async(
            prompt=prompt,
            schema=Scenario,
            max_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )

        # Fire-and-forget: simulate streaming for TUI preview (non-blocking)
        simulate_stream(
            self.progress_reporter,
            response.model_dump_json(),
            source="scenario_gen",
        )

        return response.description

    async def _generate_turn(
        self,
        turn_idx: int,
        scenario: str,
        previous_messages: list[ChatMessage],
    ) -> AgentTurnData:
        """Generate a single turn of the conversation using ReAct loop.

        Args:
            turn_idx: Index of this turn (0-based)
            scenario: Overall scenario description
            previous_messages: Messages from previous turns

        Returns:
            Complete turn data with step-by-step structure preserved
        """
        # Build context from previous messages
        context_text = self._format_message_context(previous_messages)

        # Generate user message for this turn
        user_message = await self._generate_turn_user_message(turn_idx, scenario, context_text)

        # ReAct loop for this turn - preserve step structure
        steps_with_results: list[StepWithResults] = []
        all_steps: list[AgentStep] = []  # For passing to next step generation
        all_tool_results: list[ToolExecution] = []  # For passing to next step generation
        # Fall back to 5 steps when the config doesn't define max_agent_steps
        max_steps = getattr(self.config, "max_agent_steps", 5)

        for step_num in range(max_steps):
            if self.progress_reporter:
                self.progress_reporter.emit_step_start(
                    f"Turn {turn_idx + 1}, ReAct step {step_num + 1}/{max_steps}"
                )

            # Generate next step based on observations so far
            step = await self._generate_next_step_with_context(
                user_message,
                all_steps,
                all_tool_results,
                context_text,
            )
            all_steps.append(step)

            # Check if agent is done
            if step.is_final or not step.tool_calls:
                # Add final step with no results
                steps_with_results.append(StepWithResults(step=step, results=[]))
                if self.progress_reporter:
                    self.progress_reporter.emit_step_complete(
                        f"Turn {turn_idx + 1}: Agent concluded after {step_num + 1} steps"
                    )
                break

            # Execute THIS step's tools via Spin
            step_results = await self._execute_step_tools(step.tool_calls)

            # Store step with its results
            steps_with_results.append(StepWithResults(step=step, results=step_results))
            all_tool_results.extend(step_results)

            if self.progress_reporter:
                self.progress_reporter.emit_step_complete(
                    f"Turn {turn_idx + 1}: Executed {len(step.tool_calls)} tool(s)"
                )

        # Generate agent response based on all observations
        agent_response = await self._generate_agent_conclusion(
            user_message, all_steps, all_tool_results, context=context_text
        )

        return AgentTurnData(
            user_message=user_message,
            steps_with_results=steps_with_results,
            agent_response=agent_response,
        )

    async def _generate_next_step_with_context(
        self,
        user_message: ChatMessage,
        previous_steps: list[AgentStep],
        previous_results: list[ToolExecution],
        conversation_context: str,
    ) -> AgentStep:
        """Generate the next ReAct step with conversation context.

        Similar to _generate_next_step but includes multi-turn conversation context.

        Args:
            user_message: Current turn's user request
            previous_steps: Steps already taken within this turn
            previous_results: Tool results observed within this turn
            conversation_context: Formatted text of earlier turns

        Returns:
            The structured AgentStep decided by the LLM
        """
        tools_info = self._format_tools_for_prompt()
        history = self._format_step_history(previous_steps, previous_results)

        prompt_parts = [
            "## Conversation Context",
            conversation_context if conversation_context else "(No previous conversation)",
            "",
            "## Current User Request",
            user_message.content or "",
            "",
            "## Available Tools",
            tools_info,
            "",
            "## Previous Actions & Results (this turn)",
            history if history else "None yet - this is your first action for this turn.",
            "",
            "## Instructions",
            "Based on the conversation context and what you've observed so far, decide your next action:",
            "- If you need more information, specify tool_calls for THIS step only",
            "- If you have enough information to answer, set is_final=true and leave tool_calls empty",
            "- IMPORTANT: Do NOT call write/modify operations until you've confirmed current state via read operations",
            "- Tool arguments must use concrete values (no placeholders like '<user_input>' or null)",
            "",
            "What is your next step?",
        ]

        prompt = "\n".join(prompt_parts)

        # Always use non-streaming for reliable structured output
        response = await self.llm.generate_async(
            prompt=prompt,
            schema=AgentStep,
            max_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )

        # Fire-and-forget: simulate streaming for TUI preview (non-blocking)
        simulate_stream(
            self.progress_reporter,
            response.model_dump_json(),
            source="agent_step_mt",
        )

        return response

    async def _generate_turn_user_message(
        self,
        turn_idx: int,
        scenario: str,
        context: str,
    ) -> ChatMessage:
        """Generate user message for a specific turn.

        Args:
            turn_idx: Turn index
            scenario: Overall scenario
            context: Previous conversation context

        Returns:
            User message for this turn
        """
        # Per-turn steering hints; turns beyond index 3 use the generic fallback
        turn_guidance = {
            0: "Start with the initial request or question",
            1: "Request a follow-up action or ask for more information",
            2: "Request another related action or verify results",
            3: "Final request or verification",
        }

        guidance = turn_guidance.get(turn_idx, "Continue the conversation naturally")

        prompt = f"""Scenario: {scenario}

Previous conversation:
{context if context else "(No previous conversation)"}

Generate the user's message for turn {turn_idx + 1}.
Guidance: {guidance}

The message should reference or build upon previous conversation if applicable.
Keep it concise and natural."""

        # Always use non-streaming for reliable structured output
        response = await self.llm.generate_async(
            prompt=prompt,
            schema=UserQuestion,
            max_tokens=self.config.max_tokens,
            temperature=self.config.temperature,
        )

        # Fire-and-forget: simulate streaming for TUI preview (non-blocking)
        simulate_stream(
            self.progress_reporter,
            response.model_dump_json(),
            source=f"turn_{turn_idx}_user",
            turn=turn_idx + 1,
        )

        return ChatMessage(role="user", content=response.content)

    async def _should_conclude_early(
        self, messages: list[ChatMessage], scenario: str, current_turn: int, total_tool_calls: int
    ) -> bool:
        """Determine if conversation should conclude before max_turns.

        Args:
            messages: All messages so far
            scenario: Original scenario
            current_turn: Current turn number
            total_tool_calls: Total number of tool calls made so far

        Returns:
            True if conversation should end
        """
        # Don't conclude early if we haven't met the minimum tool calls requirement
        if total_tool_calls < self.config.min_tool_calls:
            return False

        # Format conversation so far
        conversation_text = self._format_message_context(messages)

        prompt = f"""Scenario: {scenario}

Conversation so far (after {current_turn} turns):
{conversation_text}

Is the user's original task/goal from the scenario fully completed?
- True: Task is complete, conversation can end naturally
- False: Task incomplete, more turns needed"""

        # Deliberately short, low-temperature call: a yes/no structured decision
        response = await self.llm.generate_async(
            prompt=prompt,
            schema=ConclusionDecision,
            max_tokens=100,
            temperature=0.3,
        )

        return response.should_conclude

    def _format_message_context(self, messages: list[ChatMessage]) -> str:
        """Format messages as readable context.

        Args:
            messages: List of chat messages

        Returns:
            Formatted string of messages, one "role: content" line per message
        """
        if not messages:
            return ""

        lines = []
        for msg in messages:
            lines.append(f"{msg.role}: {msg.content}")

        return "\n".join(lines)

    def _build_multi_turn_conversation(
        self, turns: list[AgentTurnData], scenario: str, topic_prompt: str = ""
    ) -> Conversation:
        """Assemble multi-turn conversation from turn data.

        Preserves ReAct step-by-step structure: each step's tool calls become
        a separate assistant message followed by tool responses. This ensures
        training data shows the agent making decisions AFTER observing results.

        Args:
            turns: List of turn data
            scenario: Scenario description
            topic_prompt: Topic used to generate this conversation (for metadata)

        Returns:
            Complete Conversation object
        """
        messages = []

        # Don't add system message for agent mode - it interferes with tool calling
        # The system prompt teaches models to explain tool usage instead of executing tools
        # For tool calling, the tool definitions themselves serve as instructions

        # Collect all reasoning steps and tool executions
        all_reasoning: list[ReasoningStep] = []
        all_executions: list[ToolExecution] = []

        # Add messages from each turn in correct order:
        # For each turn: user -> [step: assistant(tool_calls) -> tool(responses)]* -> assistant(final)
        for turn in turns:
            # User message
            messages.append(turn.user_message)

            # Process each ReAct step separately to preserve step-by-step structure
            # This is critical: agent should see results from step N before deciding step N+1
            for step_with_results in turn.steps_with_results:
                step = step_with_results.step
                step_results = step_with_results.results

                # Skip steps with no tool calls (e.g., final "is_final=true" step)
                if not step.tool_calls:
                    continue

                # Build tool_calls for THIS step only
                # NOTE(review): tool_calls are reconstructed from step_results, so a
                # step whose tools all failed to produce results yields an assistant
                # message with an empty tool_calls list — confirm this is intended
                step_tool_calls: list[ToolCall] = []
                step_tool_call_ids: list[str] = []
                for result in step_results:
                    tool_call_id = generate_tool_call_id()
                    step_tool_call_ids.append(tool_call_id)
                    step_tool_calls.append(result.to_tool_call(tool_call_id))

                # Assistant message with tool_calls for this step
                messages.append(
                    ChatMessage(
                        role="assistant",
                        content="",
                        tool_calls=step_tool_calls,
                    )
                )

                # Tool response messages for this step
                for idx, result in enumerate(step_results):
                    messages.append(
                        ChatMessage(
                            role="tool",
                            content=result.result,
                            tool_call_id=step_tool_call_ids[idx],
                        )
                    )

                # Accumulate executions for this step
                all_executions.extend(step_results)

            # Final assistant response with the answer for this turn
            messages.append(turn.agent_response)

            # Accumulate reasoning across all turns
            # NOTE(review): assumes AgentTurnData exposes `reasoning_steps`
            # (not visible in this file section) — confirm on the model
            all_reasoning.extend(turn.reasoning_steps)

        # Build tool context (executions only - tools are in top-level 'tools' field)
        tool_context = ToolContext(
            executions=all_executions,
        )

        # Build reasoning trace
        reasoning_trace = ReasoningTrace(
            style=self.config.reasoning_style or "agent",  # type: ignore
            content=all_reasoning,
        )

        # Build agent context
        agent_context = AgentContext(
            mode="multi_turn",
            planning_trace=scenario,
            execution_summary=f"Completed {len(turns)}-turn conversation",
        )

        # Build metadata
        # NOTE(review): reasoning_trace is a freshly constructed object, so the
        # "basic" branch looks unreachable unless ReasoningTrace defines __bool__
        metadata = {
            "conversation_type": "chain_of_thought" if reasoning_trace else "basic",
            "topic": topic_prompt if topic_prompt else "general",
        }

        # Insert system message if configured
        self._insert_system_message_if_configured(messages)

        # Convert tools to OpenAI format
        tools_openai = [tool.to_openai() for tool in self.tool_registry.tools]

        return Conversation(
            messages=messages,
            reasoning=reasoning_trace,
            tool_context=tool_context,
            tools=tools_openai,
            agent_context=agent_context,
            metadata=metadata,
        )
|