tactus 0.31.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tactus/__init__.py +49 -0
- tactus/adapters/__init__.py +9 -0
- tactus/adapters/broker_log.py +76 -0
- tactus/adapters/cli_hitl.py +189 -0
- tactus/adapters/cli_log.py +223 -0
- tactus/adapters/cost_collector_log.py +56 -0
- tactus/adapters/file_storage.py +367 -0
- tactus/adapters/http_callback_log.py +109 -0
- tactus/adapters/ide_log.py +71 -0
- tactus/adapters/lua_tools.py +336 -0
- tactus/adapters/mcp.py +289 -0
- tactus/adapters/mcp_manager.py +196 -0
- tactus/adapters/memory.py +53 -0
- tactus/adapters/plugins.py +419 -0
- tactus/backends/http_backend.py +58 -0
- tactus/backends/model_backend.py +35 -0
- tactus/backends/pytorch_backend.py +110 -0
- tactus/broker/__init__.py +12 -0
- tactus/broker/client.py +247 -0
- tactus/broker/protocol.py +183 -0
- tactus/broker/server.py +1123 -0
- tactus/broker/stdio.py +12 -0
- tactus/cli/__init__.py +7 -0
- tactus/cli/app.py +2245 -0
- tactus/cli/commands/__init__.py +0 -0
- tactus/core/__init__.py +32 -0
- tactus/core/config_manager.py +790 -0
- tactus/core/dependencies/__init__.py +14 -0
- tactus/core/dependencies/registry.py +180 -0
- tactus/core/dsl_stubs.py +2117 -0
- tactus/core/exceptions.py +66 -0
- tactus/core/execution_context.py +480 -0
- tactus/core/lua_sandbox.py +508 -0
- tactus/core/message_history_manager.py +236 -0
- tactus/core/mocking.py +286 -0
- tactus/core/output_validator.py +291 -0
- tactus/core/registry.py +499 -0
- tactus/core/runtime.py +2907 -0
- tactus/core/template_resolver.py +142 -0
- tactus/core/yaml_parser.py +301 -0
- tactus/docker/Dockerfile +61 -0
- tactus/docker/entrypoint.sh +69 -0
- tactus/dspy/__init__.py +39 -0
- tactus/dspy/agent.py +1144 -0
- tactus/dspy/broker_lm.py +181 -0
- tactus/dspy/config.py +212 -0
- tactus/dspy/history.py +196 -0
- tactus/dspy/module.py +405 -0
- tactus/dspy/prediction.py +318 -0
- tactus/dspy/signature.py +185 -0
- tactus/formatting/__init__.py +7 -0
- tactus/formatting/formatter.py +437 -0
- tactus/ide/__init__.py +9 -0
- tactus/ide/coding_assistant.py +343 -0
- tactus/ide/server.py +2223 -0
- tactus/primitives/__init__.py +49 -0
- tactus/primitives/control.py +168 -0
- tactus/primitives/file.py +229 -0
- tactus/primitives/handles.py +378 -0
- tactus/primitives/host.py +94 -0
- tactus/primitives/human.py +342 -0
- tactus/primitives/json.py +189 -0
- tactus/primitives/log.py +187 -0
- tactus/primitives/message_history.py +157 -0
- tactus/primitives/model.py +163 -0
- tactus/primitives/procedure.py +564 -0
- tactus/primitives/procedure_callable.py +318 -0
- tactus/primitives/retry.py +155 -0
- tactus/primitives/session.py +152 -0
- tactus/primitives/state.py +182 -0
- tactus/primitives/step.py +209 -0
- tactus/primitives/system.py +93 -0
- tactus/primitives/tool.py +375 -0
- tactus/primitives/tool_handle.py +279 -0
- tactus/primitives/toolset.py +229 -0
- tactus/protocols/__init__.py +38 -0
- tactus/protocols/chat_recorder.py +81 -0
- tactus/protocols/config.py +97 -0
- tactus/protocols/cost.py +31 -0
- tactus/protocols/hitl.py +71 -0
- tactus/protocols/log_handler.py +27 -0
- tactus/protocols/models.py +355 -0
- tactus/protocols/result.py +33 -0
- tactus/protocols/storage.py +90 -0
- tactus/providers/__init__.py +13 -0
- tactus/providers/base.py +92 -0
- tactus/providers/bedrock.py +117 -0
- tactus/providers/google.py +105 -0
- tactus/providers/openai.py +98 -0
- tactus/sandbox/__init__.py +63 -0
- tactus/sandbox/config.py +171 -0
- tactus/sandbox/container_runner.py +1099 -0
- tactus/sandbox/docker_manager.py +433 -0
- tactus/sandbox/entrypoint.py +227 -0
- tactus/sandbox/protocol.py +213 -0
- tactus/stdlib/__init__.py +10 -0
- tactus/stdlib/io/__init__.py +13 -0
- tactus/stdlib/io/csv.py +88 -0
- tactus/stdlib/io/excel.py +136 -0
- tactus/stdlib/io/file.py +90 -0
- tactus/stdlib/io/fs.py +154 -0
- tactus/stdlib/io/hdf5.py +121 -0
- tactus/stdlib/io/json.py +109 -0
- tactus/stdlib/io/parquet.py +83 -0
- tactus/stdlib/io/tsv.py +88 -0
- tactus/stdlib/loader.py +274 -0
- tactus/stdlib/tac/tactus/tools/done.tac +33 -0
- tactus/stdlib/tac/tactus/tools/log.tac +50 -0
- tactus/testing/README.md +273 -0
- tactus/testing/__init__.py +61 -0
- tactus/testing/behave_integration.py +380 -0
- tactus/testing/context.py +486 -0
- tactus/testing/eval_models.py +114 -0
- tactus/testing/evaluation_runner.py +222 -0
- tactus/testing/evaluators.py +634 -0
- tactus/testing/events.py +94 -0
- tactus/testing/gherkin_parser.py +134 -0
- tactus/testing/mock_agent.py +315 -0
- tactus/testing/mock_dependencies.py +234 -0
- tactus/testing/mock_hitl.py +171 -0
- tactus/testing/mock_registry.py +168 -0
- tactus/testing/mock_tools.py +133 -0
- tactus/testing/models.py +115 -0
- tactus/testing/pydantic_eval_runner.py +508 -0
- tactus/testing/steps/__init__.py +13 -0
- tactus/testing/steps/builtin.py +902 -0
- tactus/testing/steps/custom.py +69 -0
- tactus/testing/steps/registry.py +68 -0
- tactus/testing/test_runner.py +489 -0
- tactus/tracing/__init__.py +5 -0
- tactus/tracing/trace_manager.py +417 -0
- tactus/utils/__init__.py +1 -0
- tactus/utils/cost_calculator.py +72 -0
- tactus/utils/model_pricing.py +132 -0
- tactus/utils/safe_file_library.py +502 -0
- tactus/utils/safe_libraries.py +234 -0
- tactus/validation/LuaLexerBase.py +66 -0
- tactus/validation/LuaParserBase.py +23 -0
- tactus/validation/README.md +224 -0
- tactus/validation/__init__.py +7 -0
- tactus/validation/error_listener.py +21 -0
- tactus/validation/generated/LuaLexer.interp +231 -0
- tactus/validation/generated/LuaLexer.py +5548 -0
- tactus/validation/generated/LuaLexer.tokens +124 -0
- tactus/validation/generated/LuaLexerBase.py +66 -0
- tactus/validation/generated/LuaParser.interp +173 -0
- tactus/validation/generated/LuaParser.py +6439 -0
- tactus/validation/generated/LuaParser.tokens +124 -0
- tactus/validation/generated/LuaParserBase.py +23 -0
- tactus/validation/generated/LuaParserVisitor.py +118 -0
- tactus/validation/generated/__init__.py +7 -0
- tactus/validation/grammar/LuaLexer.g4 +123 -0
- tactus/validation/grammar/LuaParser.g4 +178 -0
- tactus/validation/semantic_visitor.py +817 -0
- tactus/validation/validator.py +157 -0
- tactus-0.31.2.dist-info/METADATA +1809 -0
- tactus-0.31.2.dist-info/RECORD +160 -0
- tactus-0.31.2.dist-info/WHEEL +4 -0
- tactus-0.31.2.dist-info/entry_points.txt +2 -0
- tactus-0.31.2.dist-info/licenses/LICENSE +21 -0
tactus/dspy/agent.py
ADDED
|
@@ -0,0 +1,1144 @@
|
|
|
1
|
+
"""
|
|
2
|
+
DSPy-based Agent implementation for Tactus.
|
|
3
|
+
|
|
4
|
+
This module provides an Agent implementation built on top of DSPy primitives
|
|
5
|
+
(Module, Signature, History, Prediction). It maintains the same external API
|
|
6
|
+
as the original pydantic_ai-based Agent while using DSPy for LLM interactions.
|
|
7
|
+
|
|
8
|
+
The Agent uses:
|
|
9
|
+
- Configurable DSPy module (default: Raw for minimal formatting and lowest token overhead; also Predict for simple pass-through, or ChainOfThought for reasoning)
|
|
10
|
+
- History for conversation management
|
|
11
|
+
- Tool handling similar to DSPy's ReAct pattern
|
|
12
|
+
- Unified mocking via Mocks {} primitive
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from typing import Any, Dict, List, Optional
|
|
17
|
+
|
|
18
|
+
from tactus.dspy.history import TactusHistory, create_history
|
|
19
|
+
from tactus.dspy.module import TactusModule, create_module
|
|
20
|
+
from tactus.dspy.prediction import TactusPrediction, wrap_prediction
|
|
21
|
+
from tactus.protocols.cost import CostStats, UsageStats
|
|
22
|
+
from tactus.protocols.result import TactusResult
|
|
23
|
+
|
|
24
|
+
# Module-level logger shared by every DSPyAgentHandle; messages are tagged
# with prefixes such as "[COST]" and "[STREAMING]".
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DSPyAgentHandle:
|
|
28
|
+
"""
|
|
29
|
+
A DSPy-based Agent handle that provides the callable interface.
|
|
30
|
+
|
|
31
|
+
This is a drop-in replacement for the pydantic_ai AgentHandle,
|
|
32
|
+
using DSPy primitives for LLM interactions.
|
|
33
|
+
|
|
34
|
+
Example usage in Lua:
|
|
35
|
+
worker = Agent {
|
|
36
|
+
system_prompt = "You are a helpful assistant",
|
|
37
|
+
tools = {search, calculator}
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
-- Call the agent directly
|
|
41
|
+
worker()
|
|
42
|
+
worker({message = input.query})
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
def __init__(
    self,
    name: str,
    system_prompt: str = "",
    model: Optional[str] = None,
    provider: Optional[str] = None,
    tools: Optional[List[Any]] = None,
    toolsets: Optional[List[str]] = None,
    input_schema: Optional[Dict[str, Any]] = None,
    output_schema: Optional[Dict[str, Any]] = None,
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    model_type: Optional[str] = None,
    module: str = "Raw",
    initial_message: Optional[str] = None,
    registry: Any = None,
    mock_manager: Any = None,
    log_handler: Any = None,
    disable_streaming: bool = False,
    **kwargs: Any,
):
    """
    Create a DSPy-backed agent handle.

    Args:
        name: Agent name, used when tracking and logging.
        system_prompt: System prompt given to the agent.
        model: Model identifier in LiteLLM form, e.g. "openai/gpt-4o".
        provider: Provider name (deprecated; encode it in ``model`` instead).
        tools: Tools the agent may call.
        toolsets: Names of toolsets to include.
        input_schema: Input schema; when falsy, defaults to an optional
            ``message`` string field.
        output_schema: Output schema; when falsy, defaults to an optional
            ``response`` string field.
        temperature: Sampling temperature (default 0.7).
        max_tokens: Maximum tokens in a response.
        model_type: DSPy model type, e.g. "chat" or "responses" (reasoning models).
        module: DSPy module name, case-insensitive (default "Raw"):
            - "Raw": minimal formatting, direct LM calls (lowest token overhead)
            - "Predict": simple pass-through prediction (no reasoning traces)
            - "ChainOfThought": step-by-step reasoning before the response
        initial_message: Message sent on the first turn when nothing is injected.
        registry: Optional Registry instance for looking up mocks.
        mock_manager: Optional MockManager instance for checking mocks.
        log_handler: Optional log handler that receives streaming/cost events.
        disable_streaming: When True, never stream even if a log_handler exists.
        **kwargs: Any additional configuration, stored on ``self.kwargs``.

    Raises:
        ValueError: If ``module`` is not a recognised module name (raised while
            building the internal module).
    """
    default_input_schema = {"message": {"type": "string", "required": False}}
    default_output_schema = {"response": {"type": "string", "required": False}}

    # Identity and model configuration
    self.name = name
    self.system_prompt = system_prompt
    self.model = model
    self.provider = provider
    self.temperature = temperature
    self.max_tokens = max_tokens
    self.model_type = model_type
    self.module = module
    self.initial_message = initial_message

    # Capabilities and I/O contract (falsy values fall back to defaults)
    self.tools = tools or []
    self.toolsets = toolsets or []
    self.input_schema = input_schema or default_input_schema
    self.output_schema = output_schema or default_output_schema

    # Collaborators
    self.registry = registry
    self.mock_manager = mock_manager
    self.log_handler = log_handler
    self.disable_streaming = disable_streaming
    self.kwargs = kwargs

    # Per-conversation state
    self._history = create_history()
    self._turn_count = 0

    # Running totals across turns
    self._cumulative_usage = UsageStats()
    self._cumulative_cost = CostStats()

    # Built last: _build_module() reads self.tools, self.toolsets, self.name
    # and self.module, all assigned above.
    self._module = self._build_module()
|
|
125
|
+
|
|
126
|
+
@property
def usage(self) -> "UsageStats":
    """Token usage accumulated over every turn this agent has run.

    Read-only property; the returned object is the live running total that
    later turns keep adding to, not a snapshot copy.
    """
    totals = self._cumulative_usage
    return totals
|
|
130
|
+
|
|
131
|
+
def cost(self) -> "CostStats":
    """Cumulative cost this agent has incurred across all turns.

    NOTE(review): unlike ``usage`` this is a plain method rather than a
    ``@property``, so callers must invoke ``cost()``. Confirm whether the
    asymmetry is intentional; adding ``@property`` would break existing
    ``cost()`` call sites.
    """
    running_cost = self._cumulative_cost
    return running_cost
|
|
134
|
+
|
|
135
|
+
def _add_usage_and_cost(self, usage_stats: "UsageStats", cost_stats: "CostStats") -> None:
    """Fold one call's usage and cost into this agent's running totals.

    Token counts and cost amounts are added field by field. The recorded
    model/provider are overwritten only when the incoming stats carry a
    truthy value, so the totals remember the latest known ones.
    """
    usage_totals = self._cumulative_usage
    for field_name in ("prompt_tokens", "completion_tokens", "total_tokens"):
        current = getattr(usage_totals, field_name)
        setattr(usage_totals, field_name, current + getattr(usage_stats, field_name))

    cost_totals = self._cumulative_cost
    for field_name in ("total_cost", "prompt_cost", "completion_cost"):
        current = getattr(cost_totals, field_name)
        setattr(cost_totals, field_name, current + getattr(cost_stats, field_name))

    # Keep the "latest known" model/provider for introspection.
    for label in ("model", "provider"):
        latest = getattr(cost_stats, label)
        if latest:
            setattr(cost_totals, label, latest)
|
|
150
|
+
|
|
151
|
+
def _extract_last_call_stats(self) -> "tuple[UsageStats, CostStats]":
    """
    Extract usage and cost for the most recent call from DSPy's LM history.

    Reads ``dspy.settings.lm.history[-1]``. Cost is taken, in order of
    preference, from the history entry's ``cost``, from the response's
    ``_hidden_params["response_cost"]``, or from LiteLLM's ``cost_per_token``
    computed from the token counts. When ``cost_per_token`` is used, its exact
    prompt/completion split is kept. Otherwise the split is approximated by
    token ratio.

    Returns:
        A ``(UsageStats, CostStats)`` tuple. Both are zeroed when there is no
        LM or no LM history (e.g. mocked calls).
    """
    import dspy

    lm = dspy.settings.lm
    if lm is None or not hasattr(lm, "history") or not lm.history:
        # Nothing to read (e.g. mocks or no LM configured): zeroed stats.
        return UsageStats(), CostStats()

    last_call = lm.history[-1]

    # Usage: providers may omit the usage block or report fields as None.
    usage = last_call.get("usage") or {}
    prompt_tokens = int(usage.get("prompt_tokens") or 0)
    completion_tokens = int(usage.get("completion_tokens") or 0)
    total_tokens = int(usage.get("total_tokens") or 0)
    if not total_tokens:
        # Some providers report only the components; derive the total so the
        # cost calculation and ratio split below still run.
        total_tokens = prompt_tokens + completion_tokens
    usage_stats = UsageStats(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
    )

    # Model/provider: use the configured model whenever the history entry has
    # no model (missing key, None or empty), not only when the key is absent.
    model = last_call.get("model") or self.model
    provider = str(model).split("/")[0] if model and "/" in str(model) else None

    # Exact per-part costs; only known when LiteLLM computes them per token.
    prompt_cost = None
    completion_cost = None

    # Cost: prefer the history value, then hidden params, then a LiteLLM calc.
    total_cost = last_call.get("cost")
    if total_cost is None:
        response = last_call.get("response")
        if response and hasattr(response, "_hidden_params"):
            total_cost = (response._hidden_params or {}).get("response_cost")

    if total_cost is None and total_tokens > 0:
        try:
            # We already have token counts, so compute cost from tokens to avoid
            # relying on provider-specific response object shapes.
            from litellm.cost_calculator import cost_per_token

            raw_prompt_cost, raw_completion_cost = cost_per_token(
                model=str(model) if model is not None else "",
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                call_type="completion",
            )
            prompt_cost = float(raw_prompt_cost)
            completion_cost = float(raw_completion_cost)
            total_cost = prompt_cost + completion_cost
        except Exception as e:
            logger.warning(f"[COST] Agent '{self.name}': failed to calculate cost: {e}")
            prompt_cost = completion_cost = None
            total_cost = 0.0

    total_cost = float(total_cost or 0.0)

    if prompt_cost is None or completion_cost is None:
        # Only a total is known: approximate the prompt/completion split by
        # token ratio (exact prices per side are unknown here).
        if total_tokens > 0 and total_cost > 0:
            prompt_cost = total_cost * (prompt_tokens / total_tokens)
            completion_cost = total_cost * (completion_tokens / total_tokens)
        else:
            prompt_cost = 0.0
            completion_cost = 0.0

    cost_stats = CostStats(
        total_cost=total_cost,
        prompt_cost=prompt_cost,
        completion_cost=completion_cost,
        model=str(model) if model is not None else None,
        provider=provider,
    )

    return usage_stats, cost_stats
|
|
231
|
+
|
|
232
|
+
def _prediction_to_value(self, prediction: "TactusPrediction") -> Any:
    """
    Convert a Prediction into a stable ``result.output``.

    Rules, in order:
    - The internal ``tool_calls`` field is ignored.
    - The candidate text is the ``response`` field when it is a string and
      the only remaining field; otherwise ``prediction.message``.
    - If an output schema is configured and the text is JSON encoding an
      object or array, the parsed dict/list is returned. JSON scalars such as
      ``42``, ``true`` or ``null`` are NOT decoded, so plain-text replies
      stay strings.
    - If several non-internal fields exist, they are returned as a dict.
    - If exactly one field exists, its value is returned.
    - Otherwise the candidate text is returned.

    Args:
        prediction: Prediction produced by the agent's module.

    Returns:
        The output value (str, dict, list, or a single field's value).
    """
    import json

    try:
        data = prediction.data() or {}
    except Exception:
        # Best effort: if the fields cannot be read, fall back to the message.
        data = {}

    filtered = {k: v for k, v in data.items() if k != "tool_calls"}

    response = filtered.get("response")
    if isinstance(response, str) and len(filtered) <= 1:
        text = response
    else:
        text = prediction.message

    # output_schema is always populated by default, so this branch runs for
    # every agent. Only accept structured JSON (object/array) so that a reply
    # such as "42" or "true" is not silently turned into an int/bool.
    if self.output_schema and isinstance(text, str) and text.strip():
        try:
            parsed = json.loads(text)
        except (ValueError, RecursionError):
            parsed = None
        if isinstance(parsed, (dict, list)):
            return parsed

    # If multiple non-internal output fields exist, return a structured dict.
    if len(filtered) > 1:
        return filtered

    if len(filtered) == 1:
        return next(iter(filtered.values()))

    return text
|
|
272
|
+
|
|
273
|
+
def _wrap_as_result(
    self, prediction: TactusPrediction, usage_stats: UsageStats, cost_stats: CostStats
) -> TactusResult:
    """Package a Prediction and its per-call usage/cost as a TactusResult.

    The result's ``output`` is whatever ``_prediction_to_value`` derives from
    the prediction.
    """
    output_value = self._prediction_to_value(prediction)
    return TactusResult(output=output_value, usage=usage_stats, cost_stats=cost_stats)
|
|
282
|
+
|
|
283
|
+
def _module_to_strategy(self, module: str) -> str:
    """
    Translate a user-facing DSPy module name into a create_module() strategy.

    Matching is case-insensitive: "ChainOfThought", "chainofthought" and
    "CHAINOFTHOUGHT" all map to "chain_of_thought".

    Args:
        module: DSPy module name, e.g. "Raw", "Predict" or "ChainOfThought".

    Returns:
        The internal strategy name ("raw", "predict" or "chain_of_thought").

    Raises:
        ValueError: If the name matches no supported module.
    """
    # Insertion order matters: it determines the "Supported:" list in the
    # error message. Candidates for future support include "react" and
    # "programofthought" -> "program_of_thought".
    strategies = {
        "predict": "predict",
        "chainofthought": "chain_of_thought",
        "raw": "raw",
    }
    key = module.lower()
    if key not in strategies:
        raise ValueError(f"Unknown module '{module}'. Supported: {list(strategies.keys())}")
    return strategies[key]
|
|
308
|
+
|
|
309
|
+
def _build_module(self) -> TactusModule:
    """
    Construct the DSPy module that drives this agent's turns.

    Every turn receives the system prompt, history and user message. When the
    agent has any tools or toolsets, the signature also takes
    ``available_tools`` and produces ``tool_calls`` alongside ``response``.

    Raises:
        ValueError: If ``self.module`` is not a supported module name.
    """
    has_tools = bool(self.tools or self.toolsets)

    inputs = "system_prompt, history, user_message"
    outputs = "response"
    if has_tools:
        inputs += ", available_tools"
        outputs += ", tool_calls"
    signature = f"{inputs} -> {outputs}"

    config = {
        "signature": signature,
        "strategy": self._module_to_strategy(self.module),
    }
    return create_module(f"{self.name}_module", config)
|
|
329
|
+
|
|
330
|
+
def _should_stream(self) -> bool:
    """
    Decide whether this agent's turn should stream output.

    Streaming requires all of:
    - a ``log_handler`` to receive events;
    - that handler not opting out via ``supports_streaming = False`` (a missing
      attribute counts as opting in, e.g. cost-only collectors set it False);
    - ``disable_streaming`` being False.

    An output schema deliberately does not prevent streaming: raw text is
    streamed to the UI during generation and validated after completion.

    Returns:
        True when streaming should be enabled.
    """
    handler = self.log_handler

    if handler is None:
        veto = f"[STREAMING] Agent '{self.name}': no log_handler, streaming disabled"
    elif not getattr(handler, "supports_streaming", True):
        veto = f"[STREAMING] Agent '{self.name}': log_handler supports_streaming=False, streaming disabled"
    elif self.disable_streaming:
        veto = f"[STREAMING] Agent '{self.name}': disable_streaming=True, streaming disabled"
    else:
        veto = None

    if veto is not None:
        logger.debug(veto)
        return False

    logger.info(f"[STREAMING] Agent '{self.name}': streaming ENABLED")
    return True
|
|
369
|
+
|
|
370
|
+
def _emit_cost_event(self) -> None:
    """
    Emit a CostEvent for the most recent LLM call in DSPy's LM history.

    Does nothing when there is no ``log_handler``, no configured LM, or no LM
    history. Cost is taken from the history entry's ``cost``, then from the
    response's ``_hidden_params["response_cost"]``, then from
    ``litellm.completion_cost``. If none is available, it is reported as 0.0.
    The prompt/completion split is approximated by token ratio.

    Missing or None usage values are treated as zero, and a history entry
    without a model falls back to the configured model (or "unknown").
    """
    if self.log_handler is None:
        return

    import dspy
    from tactus.protocols.models import CostEvent

    # Get the current LM
    lm = dspy.settings.lm
    if lm is None or not hasattr(lm, "history") or not lm.history:
        logger.debug(f"[COST] Agent '{self.name}': no LM history available")
        return

    # Get the most recent call
    last_call = lm.history[-1]
    response = last_call.get("response")
    hidden_params = None
    if response and hasattr(response, "_hidden_params"):
        # _hidden_params may be None on some response objects.
        hidden_params = response._hidden_params or {}

    # Usage: tolerate a missing/None usage block and None token counts, which
    # would otherwise break the arithmetic and comparisons below.
    usage = last_call.get("usage") or {}
    prompt_tokens = int(usage.get("prompt_tokens") or 0)
    completion_tokens = int(usage.get("completion_tokens") or 0)
    total_tokens = int(usage.get("total_tokens") or 0)
    if not total_tokens:
        # Some providers report only the components; derive the total.
        total_tokens = prompt_tokens + completion_tokens

    # Extract cost information
    total_cost = last_call.get("cost")
    logger.debug(f"[COST] Agent '{self.name}': raw cost from history = {total_cost}")

    # If cost is None (happens with streamify()), calculate it using LiteLLM
    if total_cost is None:
        if hidden_params is not None:
            total_cost = hidden_params.get("response_cost")
            logger.debug(f"[COST] Agent '{self.name}': cost from _hidden_params = {total_cost}")

        # If still None, calculate manually using litellm.completion_cost
        if total_cost is None and response:
            try:
                import litellm

                total_cost = litellm.completion_cost(completion_response=response)
                logger.debug(f"[COST] Agent '{self.name}': calculated cost = {total_cost}")
            except Exception as e:
                logger.warning(f"[COST] Agent '{self.name}': failed to calculate cost: {e}")
                total_cost = 0.0
        elif total_cost is None:
            total_cost = 0.0
            logger.warning(f"[COST] Agent '{self.name}': no cost information available")

    # Normalise to float so the comparison and ":.6f" formatting below are safe.
    total_cost = float(total_cost or 0.0)

    # LiteLLM provides only a total here; approximate the prompt/completion
    # split by token ratio.
    if total_tokens > 0 and total_cost > 0:
        prompt_cost = total_cost * (prompt_tokens / total_tokens)
        completion_cost = total_cost * (completion_tokens / total_tokens)
    else:
        prompt_cost = 0.0
        completion_cost = 0.0

    # Extract duration from response metadata
    duration_ms = hidden_params.get("_response_ms") if hidden_params is not None else None

    # Model info: fall back when the history entry's model is missing or None.
    model = last_call.get("model") or self.model or "unknown"

    # Parse provider from model string (e.g., "openai/gpt-4o" -> "openai")
    provider = str(model).split("/")[0] if "/" in str(model) else "unknown"

    # Create and emit cost event
    cost_event = CostEvent(
        agent_name=self.name,
        model=model,
        provider=provider,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        prompt_cost=prompt_cost,
        completion_cost=completion_cost,
        total_cost=total_cost,
        duration_ms=duration_ms,
    )

    self.log_handler.log(cost_event)
    logger.info(f"[COST] Agent '{self.name}': ${total_cost:.6f} ({total_tokens} tokens)")
|
|
463
|
+
|
|
464
|
+
def _turn_with_streaming(
|
|
465
|
+
self,
|
|
466
|
+
opts: Dict[str, Any],
|
|
467
|
+
prompt_context: Dict[str, Any],
|
|
468
|
+
) -> TactusResult:
|
|
469
|
+
"""
|
|
470
|
+
Execute an agent turn with streaming enabled.
|
|
471
|
+
|
|
472
|
+
Uses DSPy's streamify() to wrap the module for streaming output.
|
|
473
|
+
Runs in a separate thread to avoid event loop conflicts.
|
|
474
|
+
Chunks are emitted as AgentStreamChunkEvent for real-time display in the UI.
|
|
475
|
+
|
|
476
|
+
Args:
|
|
477
|
+
opts: Turn options
|
|
478
|
+
prompt_context: Prepared prompt context for the module
|
|
479
|
+
|
|
480
|
+
Returns:
|
|
481
|
+
TactusResult with value, usage, and cost_stats
|
|
482
|
+
"""
|
|
483
|
+
import asyncio
|
|
484
|
+
import threading
|
|
485
|
+
import queue
|
|
486
|
+
from tactus.protocols.models import AgentTurnEvent, AgentStreamChunkEvent
|
|
487
|
+
|
|
488
|
+
logger.info(f"[STREAMING] Agent '{self.name}' starting streaming turn")
|
|
489
|
+
|
|
490
|
+
# Emit turn started event so the UI shows a loading indicator
|
|
491
|
+
self.log_handler.log(
|
|
492
|
+
AgentTurnEvent(
|
|
493
|
+
agent_name=self.name,
|
|
494
|
+
stage="started",
|
|
495
|
+
)
|
|
496
|
+
)
|
|
497
|
+
logger.info(f"[STREAMING] Agent '{self.name}' emitted AgentTurnEvent(started)")
|
|
498
|
+
|
|
499
|
+
# Queue for passing chunks from streaming thread to main thread
|
|
500
|
+
chunk_queue = queue.Queue()
|
|
501
|
+
result_holder = {"result": None, "error": None}
|
|
502
|
+
|
|
503
|
+
def run_streaming_in_thread():
|
|
504
|
+
"""Run DSPy streaming in a separate thread with its own event loop."""
|
|
505
|
+
import dspy as dspy_thread # Import in thread context
|
|
506
|
+
|
|
507
|
+
async def async_streaming():
|
|
508
|
+
"""Async function that runs the streaming module."""
|
|
509
|
+
try:
|
|
510
|
+
# Create a streaming version of the module using DSPy's streamify
|
|
511
|
+
# NOTE: streamify() automatically enables streaming on the LM
|
|
512
|
+
# We do NOT need to use settings.context(stream=True) - that actually breaks it!
|
|
513
|
+
streaming_module = dspy_thread.streamify(self._module.module)
|
|
514
|
+
logger.info(f"[STREAMING] Agent '{self.name}' created streaming module")
|
|
515
|
+
|
|
516
|
+
# Call the streaming module - it returns an async generator
|
|
517
|
+
stream = streaming_module(**prompt_context)
|
|
518
|
+
|
|
519
|
+
chunk_count = 0
|
|
520
|
+
async for value in stream:
|
|
521
|
+
chunk_count += 1
|
|
522
|
+
value_type = type(value).__name__
|
|
523
|
+
|
|
524
|
+
# Check for final Prediction first
|
|
525
|
+
if isinstance(value, dspy_thread.Prediction):
|
|
526
|
+
# Final prediction - this is the result
|
|
527
|
+
logger.info(
|
|
528
|
+
f"[STREAMING] Agent '{self.name}' received final Prediction"
|
|
529
|
+
)
|
|
530
|
+
result_holder["result"] = value
|
|
531
|
+
# Check for ModelResponseStream (the actual streaming chunks!)
|
|
532
|
+
elif hasattr(value, "choices") and value.choices:
|
|
533
|
+
delta = value.choices[0].delta
|
|
534
|
+
if hasattr(delta, "content") and delta.content:
|
|
535
|
+
logger.info(
|
|
536
|
+
f"[STREAMING] Agent '{self.name}' chunk #{chunk_count}: '{delta.content}'"
|
|
537
|
+
)
|
|
538
|
+
chunk_queue.put(("chunk", delta.content))
|
|
539
|
+
# String chunks (shouldn't happen with DSPy but handle it anyway)
|
|
540
|
+
elif isinstance(value, str):
|
|
541
|
+
logger.info(
|
|
542
|
+
f"[STREAMING] Agent '{self.name}' got STRING chunk, len={len(value)}"
|
|
543
|
+
)
|
|
544
|
+
if value:
|
|
545
|
+
chunk_queue.put(("chunk", value))
|
|
546
|
+
else:
|
|
547
|
+
logger.warning(
|
|
548
|
+
f"[STREAMING] Agent '{self.name}' got unexpected type: {value_type}"
|
|
549
|
+
)
|
|
550
|
+
|
|
551
|
+
logger.info(
|
|
552
|
+
f"[STREAMING] Agent '{self.name}' stream finished, processed {chunk_count} values"
|
|
553
|
+
)
|
|
554
|
+
|
|
555
|
+
except Exception as e:
|
|
556
|
+
logger.error(f"[STREAMING] Agent '{self.name}' error: {e}", exc_info=True)
|
|
557
|
+
result_holder["error"] = e
|
|
558
|
+
finally:
|
|
559
|
+
# Signal end of stream
|
|
560
|
+
chunk_queue.put(("done", None))
|
|
561
|
+
|
|
562
|
+
# Run the async function in this thread's new event loop
|
|
563
|
+
asyncio.run(async_streaming())
|
|
564
|
+
|
|
565
|
+
# Start streaming in a separate thread
|
|
566
|
+
streaming_thread = threading.Thread(target=run_streaming_in_thread, daemon=True)
|
|
567
|
+
streaming_thread.start()
|
|
568
|
+
|
|
569
|
+
# Consume chunks from the queue and emit events in the main thread
|
|
570
|
+
accumulated_text = ""
|
|
571
|
+
emitted_count = 0
|
|
572
|
+
logger.info(f"[STREAMING] Agent '{self.name}' consuming chunks from queue")
|
|
573
|
+
|
|
574
|
+
while True:
|
|
575
|
+
try:
|
|
576
|
+
msg_type, msg_data = chunk_queue.get(timeout=120.0) # 2 minute timeout
|
|
577
|
+
if msg_type == "done":
|
|
578
|
+
break
|
|
579
|
+
elif msg_type == "chunk" and msg_data:
|
|
580
|
+
accumulated_text += msg_data
|
|
581
|
+
emitted_count += 1
|
|
582
|
+
event = AgentStreamChunkEvent(
|
|
583
|
+
agent_name=self.name,
|
|
584
|
+
chunk_text=msg_data,
|
|
585
|
+
accumulated_text=accumulated_text,
|
|
586
|
+
)
|
|
587
|
+
logger.info(
|
|
588
|
+
f"[STREAMING] Agent '{self.name}' emitting chunk {emitted_count}, len={len(msg_data)}"
|
|
589
|
+
)
|
|
590
|
+
self.log_handler.log(event)
|
|
591
|
+
except queue.Empty:
|
|
592
|
+
logger.warning(f"[STREAMING] Agent '{self.name}' timeout waiting for chunks")
|
|
593
|
+
break
|
|
594
|
+
|
|
595
|
+
# Wait for thread to complete
|
|
596
|
+
streaming_thread.join(timeout=5.0)
|
|
597
|
+
|
|
598
|
+
logger.info(f"[STREAMING] Agent '{self.name}' finished, emitted {emitted_count} events")
|
|
599
|
+
|
|
600
|
+
# Check for errors
|
|
601
|
+
if result_holder["error"] is not None:
|
|
602
|
+
raise result_holder["error"]
|
|
603
|
+
|
|
604
|
+
# If streaming failed to produce a result, fall back to non-streaming
|
|
605
|
+
if result_holder["result"] is None:
|
|
606
|
+
logger.warning(f"Streaming produced no result for agent '{self.name}', falling back")
|
|
607
|
+
return self._turn_without_streaming(opts, prompt_context)
|
|
608
|
+
|
|
609
|
+
# Track new messages for this turn
|
|
610
|
+
new_messages = []
|
|
611
|
+
|
|
612
|
+
# Determine user message
|
|
613
|
+
user_message = opts.get("message")
|
|
614
|
+
if self._turn_count == 1 and not user_message and self.initial_message:
|
|
615
|
+
user_message = self.initial_message
|
|
616
|
+
|
|
617
|
+
# Add user message to new_messages if present
|
|
618
|
+
if user_message:
|
|
619
|
+
user_msg = {"role": "user", "content": user_message}
|
|
620
|
+
new_messages.append(user_msg)
|
|
621
|
+
self._history.add(user_msg)
|
|
622
|
+
|
|
623
|
+
# Add assistant response to new_messages
|
|
624
|
+
if hasattr(result_holder["result"], "response"):
|
|
625
|
+
assistant_msg = {"role": "assistant", "content": result_holder["result"].response}
|
|
626
|
+
new_messages.append(assistant_msg)
|
|
627
|
+
self._history.add(assistant_msg)
|
|
628
|
+
|
|
629
|
+
# Wrap the result with message tracking
|
|
630
|
+
wrapped_result = wrap_prediction(
|
|
631
|
+
result_holder["result"],
|
|
632
|
+
new_messages=new_messages,
|
|
633
|
+
all_messages=self._history.get(),
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
# Handle tool calls if present
|
|
637
|
+
if hasattr(wrapped_result, "tool_calls") and wrapped_result.tool_calls:
|
|
638
|
+
tool_primitive = getattr(self, "_tool_primitive", None)
|
|
639
|
+
if tool_primitive and "done" in str(wrapped_result.tool_calls).lower():
|
|
640
|
+
reason = (
|
|
641
|
+
wrapped_result.response
|
|
642
|
+
if hasattr(wrapped_result, "response")
|
|
643
|
+
else "Task completed"
|
|
644
|
+
)
|
|
645
|
+
logger.info(f"Recording done tool call with reason: {reason}")
|
|
646
|
+
tool_primitive.record_call(
|
|
647
|
+
"done",
|
|
648
|
+
{"reason": reason},
|
|
649
|
+
{"status": "completed", "reason": reason, "tool": "done"},
|
|
650
|
+
agent_name=self.name,
|
|
651
|
+
)
|
|
652
|
+
|
|
653
|
+
# Emit turn completed event
|
|
654
|
+
self.log_handler.log(
|
|
655
|
+
AgentTurnEvent(
|
|
656
|
+
agent_name=self.name,
|
|
657
|
+
stage="completed",
|
|
658
|
+
)
|
|
659
|
+
)
|
|
660
|
+
logger.info(f"[STREAMING] Agent '{self.name}' emitted AgentTurnEvent(completed)")
|
|
661
|
+
|
|
662
|
+
# Extract usage and cost stats
|
|
663
|
+
usage_stats, cost_stats = self._extract_last_call_stats()
|
|
664
|
+
|
|
665
|
+
# Emit cost event with usage and cost information
|
|
666
|
+
self._emit_cost_event()
|
|
667
|
+
|
|
668
|
+
# Wrap as TactusResult with value, usage, and cost
|
|
669
|
+
return self._wrap_as_result(wrapped_result, usage_stats, cost_stats)
|
|
670
|
+
|
|
671
|
+
def _turn_without_streaming(
    self,
    opts: Dict[str, Any],
    prompt_context: Dict[str, Any],
) -> TactusResult:
    """
    Run one agent turn synchronously, waiting for the complete module output.

    The DSPy module is called once with ``prompt_context``. The user message
    (if any) and the assistant reply (if the prediction has ``response``) are
    then appended to the agent history. A "done" tool call is recorded when
    detected, and usage/cost stats are emitted before the result is wrapped.

    Args:
        opts: Turn options; ``opts["message"]`` is the injected user message.
        prompt_context: Keyword arguments passed to the DSPy module call.

    Returns:
        TactusResult with value, usage, and cost_stats
    """
    prediction = self._module.module(**prompt_context)

    # On the very first turn, fall back to the configured initial message.
    incoming = opts.get("message")
    if self._turn_count == 1 and not incoming and self.initial_message:
        incoming = self.initial_message

    turn_messages = []

    def _record(role, content):
        # Track the message for this turn and persist it in the agent history.
        entry = {"role": role, "content": content}
        turn_messages.append(entry)
        self._history.add(entry)

    if incoming:
        _record("user", incoming)
    if hasattr(prediction, "response"):
        _record("assistant", prediction.response)

    wrapped = wrap_prediction(
        prediction,
        new_messages=turn_messages,
        all_messages=self._history.get(),
    )

    tool_calls = getattr(wrapped, "tool_calls", None)
    tool_primitive = getattr(self, "_tool_primitive", None)
    # NOTE(review): "done" is detected by a case-insensitive substring match on the
    # stringified tool calls, so any tool call whose repr contains "done" matches.
    if tool_calls and tool_primitive and "done" in str(tool_calls).lower():
        reason = wrapped.response if hasattr(wrapped, "response") else "Task completed"
        logger.info(f"Recording done tool call with reason: {reason}")
        tool_primitive.record_call(
            "done",
            {"reason": reason},
            {"status": "completed", "reason": reason, "tool": "done"},
            agent_name=self.name,
        )

    usage_stats, cost_stats = self._extract_last_call_stats()
    self._emit_cost_event()
    return self._wrap_as_result(wrapped, usage_stats, cost_stats)
|
|
743
|
+
|
|
744
|
+
def __call__(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
    """
    Run one agent turn through the callable interface.

    Lets Lua procedures invoke the agent directly::

        result = worker({message = "Hello"})
        result = worker("Hello")  -- shorthand for {message = "Hello"}

    Args:
        inputs: Input table/dict. The ``message`` field is the user message.
            ``tools`` (tool/toolset references and expressions), ``temperature``
            and ``max_tokens`` are forwarded as per-turn overrides; every other
            field is passed along as ``context``. A plain string is treated as
            ``{"message": inputs}``.

    Returns:
        The mock result when a mock is configured for this agent, otherwise the
        result of the streaming or non-streaming turn.

    Example (Lua):
        result = worker({message = "Process this task"})
        print(result.response)
    """
    logger.debug(f"Agent '{self.name}' invoked via __call__()")

    # Shorthand: worker("Hello") == worker({message = "Hello"})
    if isinstance(inputs, str):
        inputs = {"message": inputs}
    inputs = inputs or {}

    # Lua tables expose .items(); normalise them to a real dict when possible.
    if hasattr(inputs, "items"):
        try:
            inputs = dict(inputs.items())
        except (AttributeError, TypeError):
            pass

    override_keys = {"tools", "temperature", "max_tokens"}
    reserved_keys = {"message"} | override_keys

    opts = {}
    message = inputs.get("message")
    if message:
        opts["message"] = message
    opts.update({key: inputs[key] for key in override_keys if key in inputs})
    extra_context = {k: v for k, v in inputs.items() if k not in reserved_keys}
    if extra_context:
        opts["context"] = extra_context

    self._turn_count += 1
    logger.debug(f"Agent '{self.name}' turn {self._turn_count}")

    # Mocks short-circuit before any LM configuration or LLM call.
    if self.mock_manager and self.registry:
        mock_response = self._get_mock_response(opts)
        if mock_response is not None:
            logger.debug(f"Agent '{self.name}' returning mock response")
            return mock_response

    from tactus.dspy.config import get_current_lm, configure_lm

    if get_current_lm() is None and self.model:
        # LiteLLM expects "provider/model". Only the FIRST colon is the provider
        # separator; later colons (e.g. Bedrock's "...-v1:0" suffix) are kept.
        model_for_litellm = self.model
        if ":" in model_for_litellm:
            model_for_litellm = model_for_litellm.replace(":", "/", 1)
        logger.info(f"Auto-configuring DSPy LM with model: {model_for_litellm}")

        lm_settings = {
            key: value
            for key, value in (
                ("temperature", self.temperature),
                ("max_tokens", self.max_tokens),
                ("model_type", self.model_type),
            )
            if value is not None
        }
        configure_lm(model_for_litellm, **lm_settings)

    user_message = opts.get("message")
    if self._turn_count == 1 and not user_message and self.initial_message:
        user_message = self.initial_message

    prompt_context = {
        "system_prompt": self.system_prompt,
        "history": self._history.to_dspy(),
        "user_message": user_message or "",
    }

    if self.tools or self.toolsets:
        tool_lines = []
        if self.toolsets:
            names = [ts if isinstance(ts, str) else str(ts) for ts in self.toolsets]
            tool_lines.append(f"Available toolsets: {', '.join(names)}")
        tool_lines.append("Use the 'done' tool with a 'reason' parameter to complete the task.")
        prompt_context["available_tools"] = (
            "\n".join(tool_lines) if tool_lines else "No tools available"
        )

    # user_message is already part of prompt_context; only extra context is added.
    turn_context = opts.get("context")
    if turn_context:
        prompt_context["context"] = turn_context

    if self._should_stream():
        logger.debug(f"Agent '{self.name}' using streaming mode")
        return self._turn_with_streaming(opts, prompt_context)

    logger.debug(f"Agent '{self.name}' using non-streaming mode")
    try:
        return self._turn_without_streaming(opts, prompt_context)
    except Exception as exc:
        logger.error(f"Agent '{self.name}' turn failed: {exc}")
        raise
|
|
881
|
+
|
|
882
|
+
def _get_mock_response(self, opts: Dict[str, Any]) -> Optional[TactusPrediction]:
|
|
883
|
+
"""
|
|
884
|
+
Check if this agent has a mock configured and return mock response.
|
|
885
|
+
|
|
886
|
+
Agent mocks are stored in registry.agent_mocks (not registry.mocks which is for tools).
|
|
887
|
+
Agent mock configs specify tool_calls, message, data, and usage.
|
|
888
|
+
|
|
889
|
+
Args:
|
|
890
|
+
opts: The turn options
|
|
891
|
+
|
|
892
|
+
Returns:
|
|
893
|
+
TactusPrediction if mocked, None otherwise
|
|
894
|
+
"""
|
|
895
|
+
agent_name = self.name
|
|
896
|
+
|
|
897
|
+
# Check if agent has a mock in the registry (agent_mocks, not mocks)
|
|
898
|
+
if not self.registry or agent_name not in self.registry.agent_mocks:
|
|
899
|
+
return None
|
|
900
|
+
|
|
901
|
+
# Get agent mock config from registry.agent_mocks
|
|
902
|
+
mock_config = self.registry.agent_mocks[agent_name]
|
|
903
|
+
|
|
904
|
+
temporal_turns = getattr(mock_config, "temporal", None) or []
|
|
905
|
+
if temporal_turns:
|
|
906
|
+
injected = opts.get("message")
|
|
907
|
+
|
|
908
|
+
selected_turn = None
|
|
909
|
+
if injected is not None:
|
|
910
|
+
for turn in temporal_turns:
|
|
911
|
+
if isinstance(turn, dict) and turn.get("when_message") == injected:
|
|
912
|
+
selected_turn = turn
|
|
913
|
+
break
|
|
914
|
+
|
|
915
|
+
if selected_turn is None:
|
|
916
|
+
idx = self._turn_count - 1 # 1-indexed turns
|
|
917
|
+
if idx < 0:
|
|
918
|
+
idx = 0
|
|
919
|
+
if idx >= len(temporal_turns):
|
|
920
|
+
idx = len(temporal_turns) - 1
|
|
921
|
+
selected_turn = temporal_turns[idx]
|
|
922
|
+
|
|
923
|
+
turn = selected_turn
|
|
924
|
+
if isinstance(turn, dict):
|
|
925
|
+
message = turn.get("message", mock_config.message)
|
|
926
|
+
tool_calls = turn.get("tool_calls", mock_config.tool_calls)
|
|
927
|
+
data = turn.get("data", mock_config.data)
|
|
928
|
+
else:
|
|
929
|
+
message = mock_config.message
|
|
930
|
+
tool_calls = mock_config.tool_calls
|
|
931
|
+
data = mock_config.data
|
|
932
|
+
else:
|
|
933
|
+
message = mock_config.message
|
|
934
|
+
tool_calls = mock_config.tool_calls
|
|
935
|
+
data = mock_config.data
|
|
936
|
+
|
|
937
|
+
# Convert AgentMockConfig to format expected by _wrap_mock_response.
|
|
938
|
+
# Important: we do NOT embed `data`/`usage` inside the prediction output by default.
|
|
939
|
+
# The canonical agent payload is `result.output`:
|
|
940
|
+
# - If the agent has an explicit output schema, we allow structured output via `data`.
|
|
941
|
+
# - Otherwise, `result.output` is the plain response string.
|
|
942
|
+
mock_data = {
|
|
943
|
+
"response": message,
|
|
944
|
+
"tool_calls": tool_calls,
|
|
945
|
+
}
|
|
946
|
+
|
|
947
|
+
if self.output_schema and data:
|
|
948
|
+
mock_data["data"] = data
|
|
949
|
+
|
|
950
|
+
try:
|
|
951
|
+
return self._wrap_mock_response(mock_data, opts)
|
|
952
|
+
except Exception:
|
|
953
|
+
# If wrapping throws an error, let it propagate
|
|
954
|
+
raise
|
|
955
|
+
|
|
956
|
+
def _wrap_mock_response(self, mock_data: Dict[str, Any], opts: Dict[str, Any]) -> TactusResult:
    """
    Turn mock data into a TactusResult and replay the mock's side effects.

    The mocked user/assistant exchange is appended to the agent history. When a
    tool primitive is attached, mocked tool calls are recorded on it so that
    ``Tool.called(...)`` checks behave as they would in a real run.

    Args:
        mock_data: Mock payload. The reply text is taken from ``response`` when
            it is a string, else from ``message`` when that is a string, else "".
            Optional keys are ``tool_calls`` (a list whose dict items carry
            ``tool`` and optional ``args``) and ``data`` (a dict of structured
            output fields, used only when the agent has an output schema).
        opts: The turn options; ``opts["message"]`` is the injected user message.

    Returns:
        TactusResult wrapping the mock prediction, with freshly constructed
        UsageStats()/CostStats() because no LLM call is made.
    """
    from tactus.dspy.prediction import create_prediction

    # Prefer "response"; fall back to "message". Non-string values are ignored.
    response_text = ""
    for text_key in ("response", "message"):
        candidate = mock_data.get(text_key)
        if isinstance(candidate, str):
            response_text = candidate
            break

    # On the very first turn, fall back to the configured initial message.
    user_message = opts.get("message")
    if self._turn_count == 1 and not user_message and self.initial_message:
        user_message = self.initial_message

    exchange = []
    if user_message:
        exchange.append({"role": "user", "content": user_message})
    if response_text:
        exchange.append({"role": "assistant", "content": response_text})
    for entry in exchange:
        self._history.add(entry)
    turn_messages = exchange

    prediction_fields = {}
    tool_calls_list = mock_data.get("tool_calls", [])
    if tool_calls_list:
        prediction_fields["tool_calls"] = tool_calls_list

    # Structured `data` replaces the plain response only for agents with an output schema.
    structured = mock_data.get("data")
    if self.output_schema and isinstance(structured, dict) and structured:
        prediction_fields.update(structured)
    else:
        prediction_fields["response"] = response_text

    prediction_fields["__new_messages__"] = turn_messages
    prediction_fields["__all_messages__"] = self._history.get()
    result = create_prediction(**prediction_fields)

    # Replay mocked tool calls so Tool.called(...) sees them. _tool_primitive is
    # set externally by the runtime, hence getattr with a default.
    tool_primitive = getattr(self, "_tool_primitive", None)
    if tool_calls_list and tool_primitive and isinstance(tool_calls_list, list):
        for tool_call in tool_calls_list:
            if not isinstance(tool_call, dict) or "tool" not in tool_call:
                continue
            tool_name = tool_call["tool"]
            tool_args = tool_call.get("args", {})
            if tool_name == "done":
                reason = tool_args.get("reason", response_text or "Task completed (mocked)")
                tool_result = {"status": "completed", "reason": reason, "tool": "done"}
            else:
                tool_result = {"tool": tool_name, "args": tool_args}
            logger.debug(f"Mock recording {tool_name} tool call")
            tool_primitive.record_call(tool_name, tool_args, tool_result, agent_name=self.name)

    return self._wrap_as_result(result, UsageStats(), CostStats())
|
|
1055
|
+
|
|
1056
|
+
def clear_history(self) -> None:
    """Drop every recorded conversation message and reset the turn counter to 0."""
    history = self._history
    history.clear()
    self._turn_count = 0
|
|
1060
|
+
|
|
1061
|
+
def get_history(self) -> List[Dict[str, Any]]:
    """Return the recorded conversation messages, as reported by the history object."""
    history = self._history
    return history.get()
|
|
1064
|
+
|
|
1065
|
+
@property
def history(self) -> TactusHistory:
    """The TactusHistory instance backing this agent (the live object, not a copy)."""
    backing_history = self._history
    return backing_history
|
|
1069
|
+
|
|
1070
|
+
|
|
1071
|
+
def create_dspy_agent(
    name: str,
    config: Dict[str, Any],
    registry: Any = None,
    mock_manager: Any = None,
) -> DSPyAgentHandle:
    """
    Create a DSPy-based Agent from configuration.

    This is the main entry point for creating DSPy agents.

    Args:
        name: Agent name. Always taken from this argument; a ``"name"`` key in
            ``config`` is ignored.
        config: Configuration dict with:
            - system_prompt: System prompt (default: "")
            - model: Model name (LiteLLM format)
            - provider: Model provider
            - tools: List of tools (default: [])
            - toolsets: List of toolset names (default: [])
            - output_schema / output: Output schema; ``output_schema`` is used
              when truthy, otherwise ``output``
            - temperature: Sampling temperature (default: 0.7)
            - max_tokens: Maximum tokens per response
            - model_type: Optional model type passed through to LM configuration
            - module: DSPy module type (default: "Raw"), e.g. "Predict" or
              "ChainOfThought"
            - initial_message: User message used on the first turn when none is given
            - log_handler: Handler that receives agent events
            - disable_streaming: Turn off streaming (default: False)
            - Any other key is forwarded to DSPyAgentHandle as a keyword argument
              (``registry`` and ``mock_manager`` keys are ignored; use the
              parameters below instead)
        registry: Optional Registry instance for accessing mocks
        mock_manager: Optional MockManager instance for checking mocks

    Returns:
        A DSPyAgentHandle instance

    Raises:
        ValueError: If no LM is configured (either via config or globally)
    """
    # Check if LM is configured either in config or globally
    from tactus.dspy.config import get_current_lm

    if not config.get("model") and not get_current_lm():
        raise ValueError("LM not configured. Please configure an LM before creating an agent.")

    # Keys consumed explicitly below. "name", "registry" and "mock_manager" are
    # included because they are passed as explicit keyword arguments: forwarding a
    # same-named config key as well would raise
    # "TypeError: ... got multiple values for keyword argument ...".
    explicit_keys = frozenset(
        {
            "name",
            "registry",
            "mock_manager",
            "system_prompt",
            "model",
            "provider",
            "tools",
            "toolsets",
            "output_schema",
            "output",
            "temperature",
            "max_tokens",
            "model_type",
            "module",
            "initial_message",
            "log_handler",
            "disable_streaming",
        }
    )
    extra_kwargs = {k: v for k, v in config.items() if k not in explicit_keys}

    return DSPyAgentHandle(
        name=name,
        system_prompt=config.get("system_prompt", ""),
        model=config.get("model"),
        provider=config.get("provider"),
        tools=config.get("tools", []),
        toolsets=config.get("toolsets", []),
        output_schema=config.get("output_schema") or config.get("output"),
        temperature=config.get("temperature", 0.7),
        max_tokens=config.get("max_tokens"),
        model_type=config.get("model_type"),
        module=config.get("module", "Raw"),
        initial_message=config.get("initial_message"),
        registry=registry,
        mock_manager=mock_manager,
        log_handler=config.get("log_handler"),
        disable_streaming=config.get("disable_streaming", False),
        **extra_kwargs,
    )
|