ursa-ai 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
ursa/agents/base.py
ADDED
|
@@ -0,0 +1,728 @@
|
|
|
1
|
+
"""Base agent class providing telemetry, configuration, and execution abstractions.
|
|
2
|
+
|
|
3
|
+
This module defines the BaseAgent abstract class, which serves as the foundation for all
|
|
4
|
+
agent implementations in the Ursa framework. It provides:
|
|
5
|
+
|
|
6
|
+
- Standardized initialization with LLM configuration
|
|
7
|
+
- Telemetry and metrics collection
|
|
8
|
+
- Thread and checkpoint management
|
|
9
|
+
- Input normalization and validation
|
|
10
|
+
- Execution flow control with invoke/stream methods
|
|
11
|
+
- Graph integration utilities for LangGraph compatibility
|
|
12
|
+
- Runtime enforcement of the agent interface contract
|
|
13
|
+
|
|
14
|
+
Agents built on this base class benefit from consistent behavior, observability, and
|
|
15
|
+
integration capabilities while only needing to implement the core _invoke method.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
import re
|
|
19
|
+
from abc import ABC, abstractmethod
|
|
20
|
+
from typing import (
|
|
21
|
+
Any,
|
|
22
|
+
Callable,
|
|
23
|
+
Iterator,
|
|
24
|
+
Mapping,
|
|
25
|
+
Optional,
|
|
26
|
+
Sequence,
|
|
27
|
+
Union,
|
|
28
|
+
final,
|
|
29
|
+
)
|
|
30
|
+
from uuid import uuid4
|
|
31
|
+
|
|
32
|
+
from langchain.chat_models import BaseChatModel
|
|
33
|
+
from langchain_core.load import dumps
|
|
34
|
+
from langchain_core.messages import HumanMessage
|
|
35
|
+
from langchain_core.runnables import (
|
|
36
|
+
RunnableLambda,
|
|
37
|
+
)
|
|
38
|
+
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
39
|
+
from langgraph.graph import StateGraph
|
|
40
|
+
|
|
41
|
+
from ursa.observability.timing import (
|
|
42
|
+
Telemetry, # for timing / telemetry / metrics
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
# Public input forms accepted by BaseAgent.invoke/stream: a plain string (which
# BaseAgent._normalize_inputs wraps as {"messages": [HumanMessage(...)]}) or a
# mapping, which is passed through unchanged.
InputLike = Union[str, Mapping[str, Any]]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _to_snake(s: str) -> str:
|
|
49
|
+
"""Convert a string to snake_case format.
|
|
50
|
+
|
|
51
|
+
This function transforms various string formats (CamelCase, PascalCase, etc.) into
|
|
52
|
+
snake_case. It handles special cases like acronyms at the beginning of strings
|
|
53
|
+
(e.g., "RAGAgent" becomes "rag_agent") and replaces hyphens and spaces with
|
|
54
|
+
underscores.
|
|
55
|
+
|
|
56
|
+
Args:
|
|
57
|
+
s: The input string to convert to snake_case.
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
The snake_case version of the input string.
|
|
61
|
+
"""
|
|
62
|
+
s = re.sub(
|
|
63
|
+
r"^([A-Z]{2,})([A-Z][a-z])",
|
|
64
|
+
lambda m: m.group(1)[0] + m.group(1)[1:].lower() + m.group(2),
|
|
65
|
+
str(s),
|
|
66
|
+
) # RAGAgent -> RagAgent
|
|
67
|
+
s = re.sub(r"(?<!^)(?=[A-Z])", "_", s) # CamelCase -> snake_case
|
|
68
|
+
s = s.replace("-", "_").replace(" ", "_")
|
|
69
|
+
return s.lower()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class BaseAgent(ABC):
    """Abstract base class for all agent implementations in the Ursa framework.

    BaseAgent provides a standardized foundation for building LLM-powered agents with
    built-in telemetry, configuration management, and execution flow control. It handles
    common tasks like input normalization, thread management, metrics collection, and
    LangGraph integration.

    Subclasses only need to implement the _invoke method to define their core
    functionality, while inheriting standardized invocation patterns, telemetry, and
    graph integration capabilities. ``__init_subclass__`` raises ``TypeError`` when a
    subclass defines its own ``invoke()``.

    The agent supports synchronous invocation (``invoke``), asynchronous invocation
    (``ainvoke``), and streaming (``stream``). A telemetry run is started before the
    outermost call and rendered after it finishes. The class also provides utilities
    for integrating with LangGraph through node wrapping and configuration.

    Subclass Inheritance Guidelines:
        - Must Override: _invoke() - Define your agent's core functionality
        - Can Override:  _ainvoke() - Native async implementation (the default calls
                                      _invoke() synchronously)
                         _stream() - Enable streaming support
                         _normalize_inputs() - Customize input handling
                         Various helper methods (_default_node_tags, _as_runnable, etc.)
        - Never Override: invoke() - Final method with runtime enforcement
                          ainvoke() - Async counterpart of invoke()
                          stream() - Handles telemetry and delegates to _stream
                          __call__() - Delegates to invoke
                          Other public methods (build_config, write_state, add_node)

    To create a custom agent, inherit from this class and implement the _invoke method:

    ```python
    class MyAgent(BaseAgent):
        def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
            # Process inputs and return results
            ...
    ```
    """

    # Nesting depth of invoke/ainvoke/stream calls. It is always read and written
    # as ``BaseAgent._invoke_depth`` (never through ``self``), so the counter is
    # shared by every BaseAgent instance and subclass. Only the outermost call
    # begins and renders a telemetry run.
    _invoke_depth: int = 0

    # Keyword arguments that configure telemetry output rather than agent inputs.
    _TELEMETRY_KW = {
        "raw_debug",
        "save_json",
        "metrics_path",
        "save_raw_snapshot",
        "save_raw_records",
    }

    # Keyword arguments forwarded to _invoke()/_ainvoke() as control parameters
    # instead of being collected into the agent's inputs.
    _CONTROL_KW = {"config", "recursion_limit", "tags", "metadata", "callbacks"}

    def __init__(
        self,
        llm: BaseChatModel,
        checkpointer: Optional[BaseCheckpointSaver] = None,
        enable_metrics: bool = True,
        metrics_dir: str = "ursa_metrics",  # dir to save metrics, with a default
        autosave_metrics: bool = True,
        thread_id: Optional[str] = None,
    ):
        """Initializes the base agent with a language model and optional configurations.

        Args:
            llm: a BaseChatModel instance.
            checkpointer: Optional checkpoint saver for persisting agent state.
            enable_metrics: Whether to collect performance and usage metrics.
            metrics_dir: Directory path where metrics will be saved.
            autosave_metrics: Whether to automatically save metrics to disk.
            thread_id: Unique identifier for this agent instance. A random
                ``uuid4().hex`` is generated if not provided (or if empty).
        """
        self.llm = llm
        self.thread_id = thread_id or uuid4().hex
        self.checkpointer = checkpointer
        self.telemetry = Telemetry(
            enable=enable_metrics,
            output_dir=metrics_dir,
            save_json_default=autosave_metrics,
        )

    @property
    def name(self) -> str:
        """Agent name (the concrete class name)."""
        return self.__class__.__name__

    def add_node(
        self,
        graph: StateGraph,
        f: Callable[..., Mapping[str, Any]],
        node_name: Optional[str] = None,
        agent_name: Optional[str] = None,
    ) -> StateGraph:
        """Add a node to the state graph with telemetry-friendly configuration.

        The function is wrapped (via ``_wrap_node``) into a runnable carrying the
        agent's node tags and metadata, then registered on the graph. The node is
        identified by either the provided node_name or the function's name.

        Args:
            graph: The StateGraph to add the node to.
            f: The function to add as a node. Should return a mapping of string keys to
                any values.
            node_name: Optional name for the node. If not provided, the function's name
                will be used.
            agent_name: Optional agent name used as the node's ``ursa_ns`` metadata and
                as an extra tag. If not provided, the agent's name in snake_case will
                be used.

        Returns:
            The value returned by ``graph.add_node`` (annotated as the StateGraph).
        """
        _node_name = node_name or f.__name__
        _agent_name = agent_name or _to_snake(self.name)
        wrapped_node = self._wrap_node(f, _node_name, _agent_name)

        return graph.add_node(_node_name, wrapped_node)

    def write_state(self, filename: str, state: dict) -> None:
        """Writes agent state to a JSON file.

        Serializes the provided state dictionary with LangChain's ``dumps`` and writes
        it to the specified file. Non-ASCII characters are preserved, so the file is
        written as UTF-8 explicitly rather than in the platform default encoding
        (which may not be able to represent them).

        Args:
            filename: Path to the file where state will be written.
            state: Dictionary containing the agent state to be serialized.
        """
        json_state = dumps(state, ensure_ascii=False)
        with open(filename, "w", encoding="utf-8") as f:
            f.write(json_state)

    def build_config(self, **overrides) -> dict:
        """Constructs a config dictionary for agent operations with telemetry support.

        The base configuration carries the thread id (under ``configurable`` and
        ``metadata``), the telemetry run id, the agent name as a tag, and the
        telemetry callbacks. Overrides are applied as follows:

        - ``configurable`` / ``metadata`` given as dicts are merged into the base.
        - ``tags`` given as a list or tuple are appended to the base tags, skipping
          duplicates.
        - Every other key (including non-dict ``configurable``/``metadata``, non-sequence
          ``tags``, and ``callbacks``) replaces the base value outright.

        Args:
            **overrides: Optional configuration overrides that can include keys like
                'recursion_limit', 'configurable', 'metadata', 'tags', etc.

        Returns:
            dict: A complete configuration dictionary with all necessary parameters.
        """
        base: dict[str, Any] = {
            "configurable": {"thread_id": self.thread_id},
            "metadata": {
                "thread_id": self.thread_id,
                "telemetry_run_id": self.telemetry.context.get("run_id"),
            },
            "tags": [self.name],
            "callbacks": self.telemetry.callbacks,
        }

        # Resolve the model name: an explicit ``llm_model`` attribute wins. Otherwise
        # look at the chat model; some integrations expose the name as ``model``,
        # others as ``model_name``.
        llm = getattr(self, "llm", None)
        model_name = (
            getattr(self, "llm_model", None)
            or getattr(llm, "model", None)
            or getattr(llm, "model_name", None)
        )
        if model_name:
            base["metadata"]["model"] = model_name

        # Merge dict-valued configurable/metadata instead of replacing them.
        if isinstance(overrides.get("configurable"), dict):
            base["configurable"].update(overrides.pop("configurable"))
        if isinstance(overrides.get("metadata"), dict):
            base["metadata"].update(overrides.pop("metadata"))

        # Append caller tags in order, keeping the agent tag and avoiding duplicates.
        if isinstance(overrides.get("tags"), (list, tuple)):
            for tag in overrides.pop("tags"):
                if tag not in base["tags"]:
                    base["tags"].append(tag)

        # Anything left replaces the corresponding base entry.
        base.update(overrides)
        return base

    def _prepare_call(
        self,
        inputs: Optional[InputLike],
        kwargs: Mapping[str, Any],
    ) -> tuple[Mapping[str, Any], dict[str, Any]]:
        """Resolve call arguments into ``(normalized_inputs, control_kwargs)``.

        When ``inputs`` is None, keyword arguments that are not telemetry/control
        keywords become the inputs mapping. Otherwise every keyword must be a
        telemetry/control keyword.

        Raises:
            TypeError: If positional inputs are combined with non-control keywords,
                or if ``_normalize_inputs`` rejects the input type.
        """

        def is_reserved(key: str) -> bool:
            return key in self._TELEMETRY_KW or key in self._CONTROL_KW

        if inputs is None:
            inputs = {k: v for k, v in kwargs.items() if not is_reserved(k)}
            control_kwargs = {k: v for k, v in kwargs.items() if is_reserved(k)}
        else:
            for k in kwargs:
                if not is_reserved(k):
                    raise TypeError(
                        f"Unexpected keyword argument '{k}'. "
                        "Pass inputs as a single mapping or omit the positional "
                        "inputs and pass them as keyword arguments."
                    )
            control_kwargs = dict(kwargs)

        # Allow subclasses to normalize or transform the input format.
        return self._normalize_inputs(inputs), control_kwargs

    def _begin_top_level_run(self) -> None:
        """Begin a telemetry run if the depth counter marks this as the outermost call.

        Callers must already have incremented ``BaseAgent._invoke_depth``.
        """
        if BaseAgent._invoke_depth == 1:
            self.telemetry.begin_run(agent=self.name, thread_id=self.thread_id)

    def _finish_call(
        self,
        *,
        raw_debug: bool,
        save_json: Optional[bool],
        metrics_path: Optional[str],
        save_raw_snapshot: Optional[bool],
        save_raw_records: Optional[bool],
    ) -> None:
        """Decrement the depth counter and render telemetry after the outermost call."""
        BaseAgent._invoke_depth -= 1
        if BaseAgent._invoke_depth == 0:
            self.telemetry.render(
                raw=raw_debug,
                save_json=save_json,
                filepath=metrics_path,
                save_raw_snapshot=save_raw_snapshot,
                save_raw_records=save_raw_records,
            )

    def _invoke_engine(
        self,
        invoke_method,
        inputs: Optional[InputLike] = None,
        raw_debug: bool = False,
        save_json: Optional[bool] = None,
        metrics_path: Optional[str] = None,
        save_raw_snapshot: Optional[bool] = None,
        save_raw_records: Optional[bool] = None,
        config: Optional[dict] = None,
        **kwargs: Any,
    ):
        """Run a synchronous invoke method inside the depth/telemetry bookkeeping.

        Args:
            invoke_method: Callable receiving ``(normalized_inputs, config=..., **control)``.
            inputs: Positional inputs, or None to build inputs from ``kwargs``.
            raw_debug, save_json, metrics_path, save_raw_snapshot, save_raw_records:
                Passed to ``telemetry.render`` after the outermost call.
            config: Forwarded to ``invoke_method`` as ``config=``.
            **kwargs: Input keywords (when ``inputs`` is None) and/or control keywords.

        Returns:
            Whatever ``invoke_method`` returns.
        """
        BaseAgent._invoke_depth += 1
        try:
            self._begin_top_level_run()
            normalized, control_kwargs = self._prepare_call(inputs, kwargs)
            return invoke_method(normalized, config=config, **control_kwargs)
        finally:
            self._finish_call(
                raw_debug=raw_debug,
                save_json=save_json,
                metrics_path=metrics_path,
                save_raw_snapshot=save_raw_snapshot,
                save_raw_records=save_raw_records,
            )

    async def _ainvoke_engine(
        self,
        invoke_method,
        inputs: Optional[InputLike] = None,
        raw_debug: bool = False,
        save_json: Optional[bool] = None,
        metrics_path: Optional[str] = None,
        save_raw_snapshot: Optional[bool] = None,
        save_raw_records: Optional[bool] = None,
        config: Optional[dict] = None,
        **kwargs: Any,
    ):
        """Async counterpart of ``_invoke_engine``.

        The result of ``invoke_method`` is awaited when it is awaitable. The depth
        counter is decremented, and telemetry rendered, only after the awaited work
        has completed, not as soon as the coroutine object is created.
        """
        import inspect  # local: only needed on the async path

        BaseAgent._invoke_depth += 1
        try:
            self._begin_top_level_run()
            normalized, control_kwargs = self._prepare_call(inputs, kwargs)
            result = invoke_method(normalized, config=config, **control_kwargs)
            if inspect.isawaitable(result):
                result = await result
            return result
        finally:
            self._finish_call(
                raw_debug=raw_debug,
                save_json=save_json,
                metrics_path=metrics_path,
                save_raw_snapshot=save_raw_snapshot,
                save_raw_records=save_raw_records,
            )

    # NOTE: The `invoke` method uses the PEP 570 `/,*` notation to explicitly state which
    # arguments can and cannot be passed as positional or keyword arguments.
    @final
    def invoke(
        self,
        inputs: Optional[InputLike] = None,
        /,
        *,
        raw_debug: bool = False,
        save_json: Optional[bool] = None,
        metrics_path: Optional[str] = None,
        save_raw_snapshot: Optional[bool] = None,
        save_raw_records: Optional[bool] = None,
        config: Optional[dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Executes the agent with the provided inputs and configuration.

        This is the main entry point for agent execution. It handles input normalization,
        telemetry tracking, and proper execution context management. The method supports
        flexible input formats - either as a positional argument or as keyword arguments.

        Args:
            inputs: Optional positional input to the agent. If provided, all non-control
                keyword arguments will be rejected to avoid ambiguity.
            raw_debug: If True, displays raw telemetry data for debugging purposes.
            save_json: If True, saves telemetry data as JSON.
            metrics_path: Optional file path where telemetry metrics should be saved.
            save_raw_snapshot: If True, saves a raw snapshot of the telemetry data.
            save_raw_records: If True, saves raw telemetry records.
            config: Optional configuration dictionary to override default settings.
            **kwargs: Additional keyword arguments that can be either:
                - Input parameters (when no positional input is provided)
                - Control parameters recognized by the agent

        Returns:
            The result of the agent's execution.

        Raises:
            TypeError: If both positional inputs and non-control keyword arguments are
                provided simultaneously.
        """
        return self._invoke_engine(
            invoke_method=self._invoke,
            inputs=inputs,
            raw_debug=raw_debug,
            save_json=save_json,
            metrics_path=metrics_path,
            save_raw_snapshot=save_raw_snapshot,
            save_raw_records=save_raw_records,
            config=config,
            **kwargs,
        )

    # NOTE: The `ainvoke` method uses the PEP 570 `/,*` notation to explicitly state which
    # arguments can and cannot be passed as positional or keyword arguments.
    @final
    async def ainvoke(
        self,
        inputs: Optional[InputLike] = None,
        /,
        *,
        raw_debug: bool = False,
        save_json: Optional[bool] = None,
        metrics_path: Optional[str] = None,
        save_raw_snapshot: Optional[bool] = None,
        save_raw_records: Optional[bool] = None,
        config: Optional[dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Asynchronously executes the agent with the provided inputs and configuration.

        (Async version of `invoke`; must be awaited.)

        Input handling is identical to `invoke`. The call is delegated to `_ainvoke`,
        whose result is awaited if it is awaitable. Telemetry for the outermost call is
        rendered only after that work has finished.

        Args:
            inputs: Optional positional input to the agent. If provided, all non-control
                keyword arguments will be rejected to avoid ambiguity.
            raw_debug: If True, displays raw telemetry data for debugging purposes.
            save_json: If True, saves telemetry data as JSON.
            metrics_path: Optional file path where telemetry metrics should be saved.
            save_raw_snapshot: If True, saves a raw snapshot of the telemetry data.
            save_raw_records: If True, saves raw telemetry records.
            config: Optional configuration dictionary to override default settings.
            **kwargs: Additional keyword arguments that can be either:
                - Input parameters (when no positional input is provided)
                - Control parameters recognized by the agent

        Returns:
            The result of the agent's execution.

        Raises:
            TypeError: If both positional inputs and non-control keyword arguments are
                provided simultaneously.
        """
        return await self._ainvoke_engine(
            invoke_method=self._ainvoke,
            inputs=inputs,
            raw_debug=raw_debug,
            save_json=save_json,
            metrics_path=metrics_path,
            save_raw_snapshot=save_raw_snapshot,
            save_raw_records=save_raw_records,
            config=config,
            **kwargs,
        )

    def _normalize_inputs(self, inputs: InputLike) -> Mapping[str, Any]:
        """Normalizes various input formats into a standardized mapping.

        String inputs are wrapped as a single HumanMessage under the "messages" key,
        while mappings are passed through unchanged.

        Args:
            inputs: The input to normalize. Can be a string (which will be converted to a
                message) or a mapping (which will be returned as-is).

        Returns:
            A mapping containing the normalized inputs, with keys appropriate for agent
            processing.

        Raises:
            TypeError: If the input type is not supported (neither string nor mapping).
        """
        if isinstance(inputs, str):
            # Adjust to your message type
            return {"messages": [HumanMessage(content=inputs)]}
        if isinstance(inputs, Mapping):
            return inputs
        raise TypeError(f"Unsupported input type: {type(inputs)}")

    @abstractmethod
    def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
        """Subclasses implement the actual work against normalized inputs."""
        ...

    async def _ainvoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
        """Asynchronous counterpart of _invoke(); override for native async work.

        The default implementation calls the synchronous _invoke() directly, so it
        blocks the event loop for the duration of the call. Subclasses that can do
        non-blocking work should override this method. A synchronous override is
        also accepted, because ainvoke() only awaits results that are awaitable.
        """
        return self._invoke(inputs, **config)

    def __call__(self, inputs: InputLike, /, **kwargs: Any) -> Any:
        """Specify calling behavior for class instance (delegates to invoke)."""
        return self.invoke(inputs, **kwargs)

    # Runtime enforcement: forbid subclasses from overriding invoke
    def __init_subclass__(cls, **kwargs):
        """Ensure subclass does not override key method.

        Raises:
            TypeError: If the subclass body defines ``invoke``.
        """
        super().__init_subclass__(**kwargs)
        if "invoke" in cls.__dict__:
            err_msg = (
                f"{cls.__name__} must not override BaseAgent.invoke(); "
                "implement _invoke() only."
            )
            raise TypeError(err_msg)

    def stream(
        self,
        inputs: InputLike,
        config: Any | None = None,  # allow positional/keyword like LangGraph
        /,
        *,
        raw_debug: bool = False,
        save_json: bool | None = None,
        metrics_path: str | None = None,
        save_raw_snapshot: bool | None = None,
        save_raw_records: bool | None = None,
        **kwargs: Any,
    ) -> Iterator[Any]:
        """Streams agent responses with telemetry tracking.

        This method serves as the public streaming entry point for agent interactions.
        It wraps the actual streaming implementation with telemetry tracking to capture
        metrics and debugging information.

        Args:
            inputs: The input to process, which will be normalized internally.
            config: Optional configuration for the agent, compatible with LangGraph
                positional/keyword argument style.
            raw_debug: If True, renders raw debug information in telemetry output.
            save_json: If True, saves telemetry data as JSON.
            metrics_path: Optional file path where metrics should be saved.
            save_raw_snapshot: If True, saves raw snapshot data in telemetry.
            save_raw_records: If True, saves raw record data in telemetry.
            **kwargs: Additional keyword arguments passed to the streaming
                implementation.

        Returns:
            An iterator yielding the agent's responses.

        Note:
            This method tracks invocation depth to properly handle nested agent calls
            and ensure telemetry is only rendered once at the top level. Because it is
            a generator, that bookkeeping starts on the first iteration.
        """
        BaseAgent._invoke_depth += 1
        try:
            self._begin_top_level_run()
            normalized = self._normalize_inputs(inputs)
            yield from self._stream(normalized, config=config, **kwargs)
        finally:
            self._finish_call(
                raw_debug=raw_debug,
                save_json=save_json,
                metrics_path=metrics_path,
                save_raw_snapshot=save_raw_snapshot,
                save_raw_records=save_raw_records,
            )

    def _stream(
        self,
        inputs: Mapping[str, Any],
        *,
        config: Any | None = None,
        **kwargs: Any,
    ) -> Iterator[Any]:
        """Subclass method to be overwritten for streaming implementation.

        Raises:
            NotImplementedError: Always, unless a subclass overrides it.
        """
        raise NotImplementedError(
            f"{self.name} does not support streaming. "
            "Override _stream(...) in your agent to enable it."
        )

    def _default_node_tags(
        self, name: str, extra: Sequence[str] | None = None
    ) -> list[str]:
        """Generate default tags for a graph node.

        Args:
            name: The name of the node.
            extra: Optional sequence of additional tags to include.

        Returns:
            list[str]: A list of tags for the node, including the agent name, 'graph',
            the node name, and any extra tags provided.
        """
        tags = [self.name, "graph", name]
        if extra:
            tags.extend(extra)
        return tags

    def _as_runnable(self, fn: Any):
        """Convert a function to a runnable if it isn't already.

        Args:
            fn: The function or object to convert to a runnable.

        Returns:
            ``fn`` unchanged if it already has ``.with_config`` and ``.invoke``
            attributes; otherwise ``RunnableLambda(fn)``.
        """
        return (
            fn
            if hasattr(fn, "with_config") and hasattr(fn, "invoke")
            else RunnableLambda(fn)
        )

    def _node_ns(self, extra_tags: Sequence[str]) -> str:
        """Return the ``ursa_ns`` namespace for a node.

        The first extra tag is used when one is given; otherwise the agent name in
        snake_case.
        """
        return extra_tags[0] if extra_tags else _to_snake(self.name)

    def _node_cfg(self, name: str, *extra_tags: str) -> dict:
        """Build a consistent configuration for a node/runnable.

        Creates a configuration dict that can be reapplied after operations like
        .map(), subgraph compile, etc.

        Args:
            name: The name of the node.
            *extra_tags: Additional tags to include in the node configuration.

        Returns:
            dict: A configuration dictionary with run_name, tags, and metadata.
        """
        return dict(
            run_name="node",  # keep "node:" prefixing in the timer
            tags=[self.name, "graph", name, *extra_tags],
            metadata={
                "langgraph_node": name,
                "ursa_ns": self._node_ns(extra_tags),
                "ursa_agent": self.name,
            },
        )

    def ns(self, runnable_or_fn, name: str, *extra_tags: str):
        """Return a runnable with node configuration applied.

        Applies the agent's node configuration to a runnable or callable. This method
        should be called again after operations like .map() or subgraph .compile() as
        these operations may drop configuration.

        Args:
            runnable_or_fn: A runnable or callable to configure.
            name: The name to assign to this node.
            *extra_tags: Additional tags to apply to the node.

        Returns:
            A configured runnable with the agent's node configuration applied.
        """
        runnable = self._as_runnable(runnable_or_fn)
        return runnable.with_config(**self._node_cfg(name, *extra_tags))

    def _wrap_node(self, fn_or_runnable, name: str, *extra_tags: str):
        """Wrap a function or runnable as a node in the graph.

        This is a convenience wrapper around the ns() method.

        Args:
            fn_or_runnable: A function or runnable to wrap as a node.
            name: The name to assign to this node.
            *extra_tags: Additional tags to apply to the node.

        Returns:
            A configured runnable with the agent's node configuration applied.
        """
        return self.ns(fn_or_runnable, name, *extra_tags)

    def _wrap_cond(self, fn: Any, name: str, *extra_tags: str):
        """Wrap a conditional function as a routing node in the graph.

        Creates a runnable lambda whose tag and ``langgraph_node`` metadata use the
        ``route:<name>`` form.

        Args:
            fn: The conditional function to wrap.
            name: The name of the routing node.
            *extra_tags: Additional tags to apply to the node.

        Returns:
            A configured RunnableLambda with routing-specific metadata.
        """
        route_name = f"route:{name}"
        return RunnableLambda(fn).with_config(
            run_name="node",
            tags=[self.name, "graph", route_name, *extra_tags],
            metadata={
                "langgraph_node": route_name,
                "ursa_ns": self._node_ns(extra_tags),
                "ursa_agent": self.name,
            },
        )

    def _named(self, runnable: Any, name: str, *extra_tags: str):
        """Apply a specific name and configuration to a runnable.

        Unlike ``_node_cfg``, the given ``name`` is also used as the run name.

        Args:
            runnable: The runnable to configure.
            name: The name to assign to this runnable.
            *extra_tags: Additional tags to apply to the runnable.

        Returns:
            A configured runnable with the specified name and agent metadata.
        """
        return runnable.with_config(
            run_name=name,
            tags=[self.name, "graph", name, *extra_tags],
            metadata={
                "langgraph_node": name,
                "ursa_ns": self._node_ns(extra_tags),
                "ursa_agent": self.name,
            },
        )
|