ursa-ai 0.7.0rc2__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai might be problematic. Click here for more details.
- ursa/agents/__init__.py +13 -2
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +1 -1
- ursa/agents/base.py +352 -91
- ursa/agents/chat_agent.py +58 -0
- ursa/agents/execution_agent.py +506 -260
- ursa/agents/lammps_agent.py +81 -31
- ursa/agents/planning_agent.py +27 -2
- ursa/agents/websearch_agent.py +2 -2
- ursa/cli/__init__.py +5 -1
- ursa/cli/hitl.py +46 -34
- ursa/observability/pricing.json +85 -0
- ursa/observability/pricing.py +20 -18
- ursa/util/parse.py +316 -0
- {ursa_ai-0.7.0rc2.dist-info → ursa_ai-0.7.1.dist-info}/METADATA +5 -1
- {ursa_ai-0.7.0rc2.dist-info → ursa_ai-0.7.1.dist-info}/RECORD +20 -17
- {ursa_ai-0.7.0rc2.dist-info → ursa_ai-0.7.1.dist-info}/WHEEL +0 -0
- {ursa_ai-0.7.0rc2.dist-info → ursa_ai-0.7.1.dist-info}/entry_points.txt +0 -0
- {ursa_ai-0.7.0rc2.dist-info → ursa_ai-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.7.0rc2.dist-info → ursa_ai-0.7.1.dist-info}/top_level.txt +0 -0
ursa/agents/arxiv_agent.py
CHANGED
ursa/agents/base.py
CHANGED
|
@@ -1,3 +1,20 @@
|
|
|
1
|
+
"""Base agent class providing telemetry, configuration, and execution abstractions.
|
|
2
|
+
|
|
3
|
+
This module defines the BaseAgent abstract class, which serves as the foundation for all
|
|
4
|
+
agent implementations in the Ursa framework. It provides:
|
|
5
|
+
|
|
6
|
+
- Standardized initialization with LLM configuration
|
|
7
|
+
- Telemetry and metrics collection
|
|
8
|
+
- Thread and checkpoint management
|
|
9
|
+
- Input normalization and validation
|
|
10
|
+
- Execution flow control with invoke/stream methods
|
|
11
|
+
- Graph integration utilities for LangGraph compatibility
|
|
12
|
+
- Runtime enforcement of the agent interface contract
|
|
13
|
+
|
|
14
|
+
Agents built on this base class benefit from consistent behavior, observability, and
|
|
15
|
+
integration capabilities while only needing to implement the core _invoke method.
|
|
16
|
+
"""
|
|
17
|
+
|
|
1
18
|
import re
|
|
2
19
|
from abc import ABC, abstractmethod
|
|
3
20
|
from contextvars import ContextVar
|
|
@@ -15,6 +32,7 @@ from uuid import uuid4
|
|
|
15
32
|
|
|
16
33
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
17
34
|
from langchain_core.load import dumps
|
|
35
|
+
from langchain_core.messages import HumanMessage
|
|
18
36
|
from langchain_core.runnables import (
|
|
19
37
|
RunnableLambda,
|
|
20
38
|
)
|
|
@@ -31,6 +49,19 @@ _INVOKE_DEPTH = ContextVar("_INVOKE_DEPTH", default=0)
|
|
|
31
49
|
|
|
32
50
|
|
|
33
51
|
def _to_snake(s: str) -> str:
|
|
52
|
+
"""Convert a string to snake_case format.
|
|
53
|
+
|
|
54
|
+
This function transforms various string formats (CamelCase, PascalCase, etc.) into
|
|
55
|
+
snake_case. It handles special cases like acronyms at the beginning of strings
|
|
56
|
+
(e.g., "RAGAgent" becomes "rag_agent") and replaces hyphens and spaces with
|
|
57
|
+
underscores.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
s: The input string to convert to snake_case.
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
The snake_case version of the input string.
|
|
64
|
+
"""
|
|
34
65
|
s = re.sub(
|
|
35
66
|
r"^([A-Z]{2,})([A-Z][a-z])",
|
|
36
67
|
lambda m: m.group(1)[0] + m.group(1)[1:].lower() + m.group(2),
|
|
@@ -42,8 +73,53 @@ def _to_snake(s: str) -> str:
|
|
|
42
73
|
|
|
43
74
|
|
|
44
75
|
class BaseAgent(ABC):
|
|
45
|
-
|
|
46
|
-
|
|
76
|
+
"""Abstract base class for all agent implementations in the Ursa framework.
|
|
77
|
+
|
|
78
|
+
BaseAgent provides a standardized foundation for building LLM-powered agents with
|
|
79
|
+
built-in telemetry, configuration management, and execution flow control. It handles
|
|
80
|
+
common tasks like input normalization, thread management, metrics collection, and
|
|
81
|
+
LangGraph integration.
|
|
82
|
+
|
|
83
|
+
Subclasses only need to implement the _invoke method to define their core
|
|
84
|
+
functionality, while inheriting standardized invocation patterns, telemetry, and
|
|
85
|
+
graph integration capabilities. The class enforces a consistent interface through
|
|
86
|
+
runtime checks that prevent subclasses from overriding critical methods like
|
|
87
|
+
invoke().
|
|
88
|
+
|
|
89
|
+
The agent supports both direct invocation with inputs and streaming responses, with
|
|
90
|
+
automatic tracking of token usage, execution time, and other metrics. It also
|
|
91
|
+
provides utilities for integrating with LangGraph through node wrapping and
|
|
92
|
+
configuration.
|
|
93
|
+
|
|
94
|
+
Subclass Inheritance Guidelines:
|
|
95
|
+
- Must Override: _invoke() - Define your agent's core functionality
|
|
96
|
+
- Can Override: _stream() - Enable streaming support
|
|
97
|
+
_normalize_inputs() - Customize input handling
|
|
98
|
+
Various helper methods (_default_node_tags, _as_runnable, etc.)
|
|
99
|
+
- Never Override: invoke() - Final method with runtime enforcement
|
|
100
|
+
stream() - Handles telemetry and delegates to _stream
|
|
101
|
+
__call__() - Delegates to invoke
|
|
102
|
+
Other public methods (build_config, write_state, add_node)
|
|
103
|
+
|
|
104
|
+
To create a custom agent, inherit from this class and implement the _invoke method:
|
|
105
|
+
|
|
106
|
+
```python
|
|
107
|
+
class MyAgent(BaseAgent):
|
|
108
|
+
def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
|
|
109
|
+
# Process inputs and return results
|
|
110
|
+
...
|
|
111
|
+
```
|
|
112
|
+
"""
|
|
113
|
+
|
|
114
|
+
_TELEMETRY_KW = {
|
|
115
|
+
"raw_debug",
|
|
116
|
+
"save_json",
|
|
117
|
+
"metrics_path",
|
|
118
|
+
"save_raw_snapshot",
|
|
119
|
+
"save_raw_records",
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
_CONTROL_KW = {"config", "recursion_limit", "tags", "metadata", "callbacks"}
|
|
47
123
|
|
|
48
124
|
def __init__(
|
|
49
125
|
self,
|
|
@@ -55,6 +131,19 @@ class BaseAgent(ABC):
|
|
|
55
131
|
thread_id: Optional[str] = None,
|
|
56
132
|
**kwargs,
|
|
57
133
|
):
|
|
134
|
+
"""Initializes the base agent with a language model and optional configurations.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
llm: Either a string in the format "provider/model" or a BaseChatModel
|
|
138
|
+
instance.
|
|
139
|
+
checkpointer: Optional checkpoint saver for persisting agent state.
|
|
140
|
+
enable_metrics: Whether to collect performance and usage metrics.
|
|
141
|
+
metrics_dir: Directory path where metrics will be saved.
|
|
142
|
+
autosave_metrics: Whether to automatically save metrics to disk.
|
|
143
|
+
thread_id: Unique identifier for this agent instance. Generated if not
|
|
144
|
+
provided.
|
|
145
|
+
**kwargs: Additional keyword arguments passed to the LLM initialization.
|
|
146
|
+
"""
|
|
58
147
|
match llm:
|
|
59
148
|
case BaseChatModel():
|
|
60
149
|
self.llm = llm
|
|
@@ -70,7 +159,8 @@ class BaseAgent(ABC):
|
|
|
70
159
|
|
|
71
160
|
case _:
|
|
72
161
|
raise TypeError(
|
|
73
|
-
"llm argument must be a string with the provider and model, or a
|
|
162
|
+
"llm argument must be a string with the provider and model, or a "
|
|
163
|
+
"BaseChatModel instance."
|
|
74
164
|
)
|
|
75
165
|
|
|
76
166
|
self.thread_id = thread_id or uuid4().hex
|
|
@@ -93,93 +183,105 @@ class BaseAgent(ABC):
|
|
|
93
183
|
node_name: Optional[str] = None,
|
|
94
184
|
agent_name: Optional[str] = None,
|
|
95
185
|
) -> StateGraph:
|
|
96
|
-
"""Add node to graph.
|
|
97
|
-
|
|
98
|
-
This
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
186
|
+
"""Add a node to the state graph with token usage tracking.
|
|
187
|
+
|
|
188
|
+
This method adds a function as a node to the state graph, wrapping it to track
|
|
189
|
+
token usage during execution. The node is identified by either the provided
|
|
190
|
+
node_name or the function's name.
|
|
191
|
+
|
|
192
|
+
Args:
|
|
193
|
+
graph: The StateGraph to add the node to.
|
|
194
|
+
f: The function to add as a node. Should return a mapping of string keys to
|
|
195
|
+
any values.
|
|
196
|
+
node_name: Optional name for the node. If not provided, the function's name
|
|
197
|
+
will be used.
|
|
198
|
+
agent_name: Optional agent name for tracking. If not provided, the agent's
|
|
199
|
+
name in snake_case will be used.
|
|
200
|
+
|
|
201
|
+
Returns:
|
|
202
|
+
The updated StateGraph with the new node added.
|
|
106
203
|
"""
|
|
107
204
|
_node_name = node_name or f.__name__
|
|
108
205
|
_agent_name = agent_name or _to_snake(self.name)
|
|
109
206
|
wrapped_node = self._wrap_node(f, _node_name, _agent_name)
|
|
207
|
+
|
|
110
208
|
return graph.add_node(_node_name, wrapped_node)
|
|
111
209
|
|
|
112
|
-
def write_state(self, filename, state):
|
|
210
|
+
def write_state(self, filename: str, state: dict) -> None:
|
|
211
|
+
"""Writes agent state to a JSON file.
|
|
212
|
+
|
|
213
|
+
Serializes the provided state dictionary to JSON format and writes it to the
|
|
214
|
+
specified file. The JSON is written with non-ASCII characters preserved.
|
|
215
|
+
|
|
216
|
+
Args:
|
|
217
|
+
filename: Path to the file where state will be written.
|
|
218
|
+
state: Dictionary containing the agent state to be serialized.
|
|
219
|
+
"""
|
|
113
220
|
json_state = dumps(state, ensure_ascii=False)
|
|
114
221
|
with open(filename, "w") as f:
|
|
115
222
|
f.write(json_state)
|
|
116
223
|
|
|
117
|
-
# BaseAgent
|
|
118
224
|
def build_config(self, **overrides) -> dict:
|
|
225
|
+
"""Constructs a config dictionary for agent operations with telemetry support.
|
|
226
|
+
|
|
227
|
+
This method creates a standardized configuration dictionary that includes thread
|
|
228
|
+
identification, telemetry callbacks, and other metadata needed for agent
|
|
229
|
+
operations. The configuration can be customized through override parameters.
|
|
230
|
+
|
|
231
|
+
Args:
|
|
232
|
+
**overrides: Optional configuration overrides that can include keys like
|
|
233
|
+
'recursion_limit', 'configurable', 'metadata', 'tags', etc.
|
|
234
|
+
|
|
235
|
+
Returns:
|
|
236
|
+
dict: A complete configuration dictionary with all necessary parameters.
|
|
119
237
|
"""
|
|
120
|
-
|
|
121
|
-
You can pass overrides like recursion_limit=..., configurable={...}, etc.
|
|
122
|
-
"""
|
|
238
|
+
# Create the base configuration with essential fields.
|
|
123
239
|
base = {
|
|
124
240
|
"configurable": {"thread_id": self.thread_id},
|
|
125
241
|
"metadata": {
|
|
126
242
|
"thread_id": self.thread_id,
|
|
127
243
|
"telemetry_run_id": self.telemetry.context.get("run_id"),
|
|
128
244
|
},
|
|
129
|
-
# "configurable": {
|
|
130
|
-
# "thread_id": getattr(self, "thread_id", "default")
|
|
131
|
-
# },
|
|
132
|
-
# "metadata": {
|
|
133
|
-
# "thread_id": getattr(self, "thread_id", "default"),
|
|
134
|
-
# "telemetry_run_id": self.telemetry.context.get("run_id"),
|
|
135
|
-
# },
|
|
136
245
|
"tags": [self.name],
|
|
137
246
|
"callbacks": self.telemetry.callbacks,
|
|
138
247
|
}
|
|
139
|
-
|
|
248
|
+
|
|
249
|
+
# Try to determine the model name from either direct or nested attributes
|
|
140
250
|
model_name = getattr(self, "llm_model", None) or getattr(
|
|
141
251
|
getattr(self, "llm", None), "model", None
|
|
142
252
|
)
|
|
253
|
+
|
|
254
|
+
# Add model name to metadata if available
|
|
143
255
|
if model_name:
|
|
144
256
|
base["metadata"]["model"] = model_name
|
|
145
257
|
|
|
258
|
+
# Handle configurable dictionary overrides by merging with base configurable
|
|
146
259
|
if "configurable" in overrides and isinstance(
|
|
147
260
|
overrides["configurable"], dict
|
|
148
261
|
):
|
|
149
262
|
base["configurable"].update(overrides.pop("configurable"))
|
|
263
|
+
|
|
264
|
+
# Handle metadata dictionary overrides by merging with base metadata
|
|
150
265
|
if "metadata" in overrides and isinstance(overrides["metadata"], dict):
|
|
151
266
|
base["metadata"].update(overrides.pop("metadata"))
|
|
152
|
-
|
|
267
|
+
|
|
268
|
+
# Merge tags from caller-provided overrides, avoid duplicates
|
|
153
269
|
if "tags" in overrides and isinstance(overrides["tags"], list):
|
|
154
270
|
base["tags"] = base["tags"] + [
|
|
155
271
|
t for t in overrides.pop("tags") if t not in base["tags"]
|
|
156
272
|
]
|
|
273
|
+
|
|
274
|
+
# Apply any remaining overrides directly to the base configuration
|
|
157
275
|
base.update(overrides)
|
|
158
|
-
return base
|
|
159
276
|
|
|
160
|
-
|
|
161
|
-
# planning_output = planner.invoke(
|
|
162
|
-
# {"messages": [HumanMessage(content=problem)]},
|
|
163
|
-
# config={
|
|
164
|
-
# "recursion_limit": 999_999,
|
|
165
|
-
# "configurable": {"thread_id": planner.thread_id},
|
|
166
|
-
# },
|
|
167
|
-
# )
|
|
168
|
-
# they can also, separately, override these defaults about metrics
|
|
169
|
-
# keys that are NOT inputs; they should not be folded into the inputs mapping
|
|
170
|
-
_TELEMETRY_KW = {
|
|
171
|
-
"raw_debug",
|
|
172
|
-
"save_json",
|
|
173
|
-
"metrics_path",
|
|
174
|
-
"save_raw_snapshot",
|
|
175
|
-
"save_raw_records",
|
|
176
|
-
}
|
|
177
|
-
_CONTROL_KW = {"config", "recursion_limit", "tags", "metadata", "callbacks"}
|
|
277
|
+
return base
|
|
178
278
|
|
|
279
|
+
# NOTE: The `invoke` method uses the PEP 570 `/,*` notation to explicitly state which
|
|
280
|
+
# arguments can and cannot be passed as positional or keyword arguments.
|
|
179
281
|
@final
|
|
180
282
|
def invoke(
|
|
181
283
|
self,
|
|
182
|
-
inputs: Optional[InputLike] = None,
|
|
284
|
+
inputs: Optional[InputLike] = None,
|
|
183
285
|
/,
|
|
184
286
|
*,
|
|
185
287
|
raw_debug: bool = False,
|
|
@@ -188,18 +290,47 @@ class BaseAgent(ABC):
|
|
|
188
290
|
save_raw_snapshot: Optional[bool] = None,
|
|
189
291
|
save_raw_records: Optional[bool] = None,
|
|
190
292
|
config: Optional[dict] = None,
|
|
191
|
-
**kwargs: Any,
|
|
293
|
+
**kwargs: Any,
|
|
192
294
|
) -> Any:
|
|
295
|
+
"""Executes the agent with the provided inputs and configuration.
|
|
296
|
+
|
|
297
|
+
This is the main entry point for agent execution. It handles input normalization,
|
|
298
|
+
telemetry tracking, and proper execution context management. The method supports
|
|
299
|
+
flexible input formats - either as a positional argument or as keyword arguments.
|
|
300
|
+
|
|
301
|
+
Args:
|
|
302
|
+
inputs: Optional positional input to the agent. If provided, all non-control
|
|
303
|
+
keyword arguments will be rejected to avoid ambiguity.
|
|
304
|
+
raw_debug: If True, displays raw telemetry data for debugging purposes.
|
|
305
|
+
save_json: If True, saves telemetry data as JSON.
|
|
306
|
+
metrics_path: Optional file path where telemetry metrics should be saved.
|
|
307
|
+
save_raw_snapshot: If True, saves a raw snapshot of the telemetry data.
|
|
308
|
+
save_raw_records: If True, saves raw telemetry records.
|
|
309
|
+
config: Optional configuration dictionary to override default settings.
|
|
310
|
+
**kwargs: Additional keyword arguments that can be either:
|
|
311
|
+
- Input parameters (when no positional input is provided)
|
|
312
|
+
- Control parameters recognized by the agent
|
|
313
|
+
|
|
314
|
+
Returns:
|
|
315
|
+
The result of the agent's execution.
|
|
316
|
+
|
|
317
|
+
Raises:
|
|
318
|
+
TypeError: If both positional inputs and non-control keyword arguments are
|
|
319
|
+
provided simultaneously.
|
|
320
|
+
"""
|
|
321
|
+
# Track invocation depth to manage nested agent calls
|
|
193
322
|
depth = _INVOKE_DEPTH.get()
|
|
194
323
|
_INVOKE_DEPTH.set(depth + 1)
|
|
195
324
|
try:
|
|
325
|
+
# Start telemetry tracking for the top-level invocation
|
|
196
326
|
if depth == 0:
|
|
197
327
|
self.telemetry.begin_run(
|
|
198
328
|
agent=self.name, thread_id=self.thread_id
|
|
199
329
|
)
|
|
200
330
|
|
|
201
|
-
#
|
|
331
|
+
# Handle the case where inputs are provided as keyword arguments
|
|
202
332
|
if inputs is None:
|
|
333
|
+
# Separate kwargs into input parameters and control parameters
|
|
203
334
|
kw_inputs: dict[str, Any] = {}
|
|
204
335
|
control_kwargs: dict[str, Any] = {}
|
|
205
336
|
for k, v in kwargs.items():
|
|
@@ -208,11 +339,13 @@ class BaseAgent(ABC):
|
|
|
208
339
|
else:
|
|
209
340
|
kw_inputs[k] = v
|
|
210
341
|
inputs = kw_inputs
|
|
211
|
-
kwargs = control_kwargs # only control kwargs remain
|
|
212
342
|
|
|
213
|
-
|
|
343
|
+
# Only control kwargs remain for further processing
|
|
344
|
+
kwargs = control_kwargs
|
|
345
|
+
|
|
346
|
+
# Handle the case where inputs are provided as a positional argument
|
|
214
347
|
else:
|
|
215
|
-
#
|
|
348
|
+
# Ensure no ambiguous keyword arguments are present
|
|
216
349
|
for k in kwargs.keys():
|
|
217
350
|
if not (k in self._TELEMETRY_KW or k in self._CONTROL_KW):
|
|
218
351
|
raise TypeError(
|
|
@@ -221,15 +354,19 @@ class BaseAgent(ABC):
|
|
|
221
354
|
"inputs and pass them as keyword arguments."
|
|
222
355
|
)
|
|
223
356
|
|
|
224
|
-
# subclasses
|
|
357
|
+
# Allow subclasses to normalize or transform the input format
|
|
225
358
|
normalized = self._normalize_inputs(inputs)
|
|
226
359
|
|
|
227
|
-
#
|
|
360
|
+
# Delegate to the subclass implementation with the normalized inputs
|
|
361
|
+
# and any control parameters
|
|
228
362
|
return self._invoke(normalized, config=config, **kwargs)
|
|
229
363
|
|
|
230
364
|
finally:
|
|
365
|
+
# Clean up the invocation depth tracking
|
|
231
366
|
new_depth = _INVOKE_DEPTH.get() - 1
|
|
232
367
|
_INVOKE_DEPTH.set(new_depth)
|
|
368
|
+
|
|
369
|
+
# For the top-level invocation, finalize telemetry and generate outputs
|
|
233
370
|
if new_depth == 0:
|
|
234
371
|
self.telemetry.render(
|
|
235
372
|
raw=raw_debug,
|
|
@@ -240,10 +377,25 @@ class BaseAgent(ABC):
|
|
|
240
377
|
)
|
|
241
378
|
|
|
242
379
|
def _normalize_inputs(self, inputs: InputLike) -> Mapping[str, Any]:
|
|
380
|
+
"""Normalizes various input formats into a standardized mapping.
|
|
381
|
+
|
|
382
|
+
This method converts different input types into a consistent dictionary format
|
|
383
|
+
that can be processed by the agent. String inputs are wrapped as messages, while
|
|
384
|
+
mappings are passed through unchanged.
|
|
385
|
+
|
|
386
|
+
Args:
|
|
387
|
+
inputs: The input to normalize. Can be a string (which will be converted to a
|
|
388
|
+
message) or a mapping (which will be returned as-is).
|
|
389
|
+
|
|
390
|
+
Returns:
|
|
391
|
+
A mapping containing the normalized inputs, with keys appropriate for agent
|
|
392
|
+
processing.
|
|
393
|
+
|
|
394
|
+
Raises:
|
|
395
|
+
TypeError: If the input type is not supported (neither string nor mapping).
|
|
396
|
+
"""
|
|
243
397
|
if isinstance(inputs, str):
|
|
244
398
|
# Adjust to your message type
|
|
245
|
-
from langchain_core.messages import HumanMessage
|
|
246
|
-
|
|
247
399
|
return {"messages": [HumanMessage(content=inputs)]}
|
|
248
400
|
if isinstance(inputs, Mapping):
|
|
249
401
|
return inputs
|
|
@@ -255,15 +407,19 @@ class BaseAgent(ABC):
|
|
|
255
407
|
...
|
|
256
408
|
|
|
257
409
|
def __call__(self, inputs: InputLike, /, **kwargs: Any) -> Any:
|
|
410
|
+
"""Specify calling behavior for class instance."""
|
|
258
411
|
return self.invoke(inputs, **kwargs)
|
|
259
412
|
|
|
260
413
|
# Runtime enforcement: forbid subclasses from overriding invoke
|
|
261
414
|
def __init_subclass__(cls, **kwargs):
|
|
415
|
+
"""Ensure subclass does not override key method."""
|
|
262
416
|
super().__init_subclass__(**kwargs)
|
|
263
417
|
if "invoke" in cls.__dict__:
|
|
264
|
-
|
|
265
|
-
f"{cls.__name__} must not override BaseAgent.invoke();
|
|
418
|
+
err_msg = (
|
|
419
|
+
f"{cls.__name__} must not override BaseAgent.invoke(); "
|
|
420
|
+
"implement _invoke() only."
|
|
266
421
|
)
|
|
422
|
+
raise TypeError(err_msg)
|
|
267
423
|
|
|
268
424
|
def stream(
|
|
269
425
|
self,
|
|
@@ -278,19 +434,52 @@ class BaseAgent(ABC):
|
|
|
278
434
|
save_raw_records: bool | None = None,
|
|
279
435
|
**kwargs: Any,
|
|
280
436
|
) -> Iterator[Any]:
|
|
281
|
-
"""
|
|
437
|
+
"""Streams agent responses with telemetry tracking.
|
|
438
|
+
|
|
439
|
+
This method serves as the public streaming entry point for agent interactions.
|
|
440
|
+
It wraps the actual streaming implementation with telemetry tracking to capture
|
|
441
|
+
metrics and debugging information.
|
|
442
|
+
|
|
443
|
+
Args:
|
|
444
|
+
inputs: The input to process, which will be normalized internally.
|
|
445
|
+
config: Optional configuration for the agent, compatible with LangGraph
|
|
446
|
+
positional/keyword argument style.
|
|
447
|
+
raw_debug: If True, renders raw debug information in telemetry output.
|
|
448
|
+
save_json: If True, saves telemetry data as JSON.
|
|
449
|
+
metrics_path: Optional file path where metrics should be saved.
|
|
450
|
+
save_raw_snapshot: If True, saves raw snapshot data in telemetry.
|
|
451
|
+
save_raw_records: If True, saves raw record data in telemetry.
|
|
452
|
+
**kwargs: Additional keyword arguments passed to the streaming
|
|
453
|
+
implementation.
|
|
454
|
+
|
|
455
|
+
Returns:
|
|
456
|
+
An iterator yielding the agent's responses.
|
|
457
|
+
|
|
458
|
+
Note:
|
|
459
|
+
This method tracks invocation depth to properly handle nested agent calls
|
|
460
|
+
and ensure telemetry is only rendered once at the top level.
|
|
461
|
+
"""
|
|
462
|
+
# Track invocation depth to handle nested agent calls
|
|
282
463
|
depth = _INVOKE_DEPTH.get()
|
|
283
464
|
_INVOKE_DEPTH.set(depth + 1)
|
|
465
|
+
|
|
284
466
|
try:
|
|
467
|
+
# Start telemetry tracking for top-level invocations only
|
|
285
468
|
if depth == 0:
|
|
286
469
|
self.telemetry.begin_run(
|
|
287
470
|
agent=self.name, thread_id=self.thread_id
|
|
288
471
|
)
|
|
472
|
+
|
|
473
|
+
# Normalize inputs and delegate to the actual streaming implementation
|
|
289
474
|
normalized = self._normalize_inputs(inputs)
|
|
290
475
|
yield from self._stream(normalized, config=config, **kwargs)
|
|
476
|
+
|
|
291
477
|
finally:
|
|
478
|
+
# Decrement invocation depth when exiting
|
|
292
479
|
new_depth = _INVOKE_DEPTH.get() - 1
|
|
293
480
|
_INVOKE_DEPTH.set(new_depth)
|
|
481
|
+
|
|
482
|
+
# Render telemetry data only for top-level invocations
|
|
294
483
|
if new_depth == 0:
|
|
295
484
|
self.telemetry.render(
|
|
296
485
|
raw=raw_debug,
|
|
@@ -307,48 +496,47 @@ class BaseAgent(ABC):
|
|
|
307
496
|
config: Any | None = None,
|
|
308
497
|
**kwargs: Any,
|
|
309
498
|
) -> Iterator[Any]:
|
|
499
|
+
"""Subclass method to be overwritten for streaming implementation."""
|
|
310
500
|
raise NotImplementedError(
|
|
311
501
|
f"{self.name} does not support streaming. "
|
|
312
502
|
"Override _stream(...) in your agent to enable it."
|
|
313
503
|
)
|
|
314
504
|
|
|
315
|
-
# def run(
|
|
316
|
-
# self,
|
|
317
|
-
# *args,
|
|
318
|
-
# raw_debug: bool = False,
|
|
319
|
-
# save_json: bool | None = None,
|
|
320
|
-
# metrics_path: str | None = None,
|
|
321
|
-
# save_raw_snapshot: bool | None = None,
|
|
322
|
-
# save_raw_records: bool | None = None,
|
|
323
|
-
# **kwargs
|
|
324
|
-
# ):
|
|
325
|
-
# try:
|
|
326
|
-
# self.telemetry.begin_run(agent=self.name, thread_id=self.thread_id)
|
|
327
|
-
# result = self._run_impl(*args, **kwargs)
|
|
328
|
-
# return result
|
|
329
|
-
# finally:
|
|
330
|
-
# print(self.telemetry.render(
|
|
331
|
-
# raw=raw_debug,
|
|
332
|
-
# save_json=save_json,
|
|
333
|
-
# filepath=metrics_path,
|
|
334
|
-
# save_raw_snapshot=save_raw_snapshot,
|
|
335
|
-
# save_raw_records=save_raw_records,
|
|
336
|
-
# ))
|
|
337
|
-
|
|
338
|
-
# @abstractmethod
|
|
339
|
-
# def _run_impl(self, *args, **kwargs):
|
|
340
|
-
# raise NotImplementedError("Agents must implement _run_impl")
|
|
341
|
-
|
|
342
505
|
def _default_node_tags(
|
|
343
506
|
self, name: str, extra: Sequence[str] | None = None
|
|
344
507
|
) -> list[str]:
|
|
508
|
+
"""Generate default tags for a graph node.
|
|
509
|
+
|
|
510
|
+
Args:
|
|
511
|
+
name: The name of the node.
|
|
512
|
+
extra: Optional sequence of additional tags to include.
|
|
513
|
+
|
|
514
|
+
Returns:
|
|
515
|
+
list[str]: A list of tags for the node, including the agent name, 'graph',
|
|
516
|
+
the node name, and any extra tags provided.
|
|
517
|
+
"""
|
|
518
|
+
# Start with standard tags: agent name, graph indicator, and node name
|
|
345
519
|
tags = [self.name, "graph", name]
|
|
520
|
+
|
|
521
|
+
# Add any extra tags if provided
|
|
346
522
|
if extra:
|
|
347
523
|
tags.extend(extra)
|
|
524
|
+
|
|
348
525
|
return tags
|
|
349
526
|
|
|
350
527
|
def _as_runnable(self, fn: Any):
|
|
351
|
-
|
|
528
|
+
"""Convert a function to a runnable if it isn't already.
|
|
529
|
+
|
|
530
|
+
Args:
|
|
531
|
+
fn: The function or object to convert to a runnable.
|
|
532
|
+
|
|
533
|
+
Returns:
|
|
534
|
+
A runnable object that can be used in the graph. If the input is already
|
|
535
|
+
runnable (has .with_config and .invoke methods), it's returned as is.
|
|
536
|
+
Otherwise, it's wrapped in a RunnableLambda.
|
|
537
|
+
"""
|
|
538
|
+
# Check if the function already has the required runnable interface
|
|
539
|
+
# If so, return it as is; otherwise wrap it in a RunnableLambda
|
|
352
540
|
return (
|
|
353
541
|
fn
|
|
354
542
|
if hasattr(fn, "with_config") and hasattr(fn, "invoke")
|
|
@@ -356,11 +544,28 @@ class BaseAgent(ABC):
|
|
|
356
544
|
)
|
|
357
545
|
|
|
358
546
|
def _node_cfg(self, name: str, *extra_tags: str) -> dict:
|
|
359
|
-
"""Build a consistent
|
|
547
|
+
"""Build a consistent configuration for a node/runnable.
|
|
548
|
+
|
|
549
|
+
Creates a configuration dict that can be reapplied after operations like
|
|
550
|
+
.map(), subgraph compile, etc.
|
|
551
|
+
|
|
552
|
+
Args:
|
|
553
|
+
name: The name of the node.
|
|
554
|
+
*extra_tags: Additional tags to include in the node configuration.
|
|
555
|
+
|
|
556
|
+
Returns:
|
|
557
|
+
dict: A configuration dictionary with run_name, tags, and metadata.
|
|
558
|
+
"""
|
|
559
|
+
# Determine the namespace - use first extra tag if available, otherwise
|
|
560
|
+
# convert agent name to snake_case
|
|
360
561
|
ns = extra_tags[0] if extra_tags else _to_snake(self.name)
|
|
562
|
+
|
|
563
|
+
# Combine all tags: agent name, graph indicator, node name, and any extra tags
|
|
361
564
|
tags = [self.name, "graph", name, *extra_tags]
|
|
565
|
+
|
|
566
|
+
# Return the complete configuration dictionary
|
|
362
567
|
return dict(
|
|
363
|
-
run_name="node", # keep "node:" prefixing in the timer
|
|
568
|
+
run_name="node", # keep "node:" prefixing in the timer
|
|
364
569
|
tags=tags,
|
|
365
570
|
metadata={
|
|
366
571
|
"langgraph_node": name,
|
|
@@ -370,16 +575,57 @@ class BaseAgent(ABC):
|
|
|
370
575
|
)
|
|
371
576
|
|
|
372
577
|
def ns(self, runnable_or_fn, name: str, *extra_tags: str):
|
|
373
|
-
"""Return a runnable with
|
|
374
|
-
|
|
578
|
+
"""Return a runnable with node configuration applied.
|
|
579
|
+
|
|
580
|
+
Applies the agent's node configuration to a runnable or callable. This method
|
|
581
|
+
should be called again after operations like .map() or subgraph .compile() as
|
|
582
|
+
these operations may drop configuration.
|
|
583
|
+
|
|
584
|
+
Args:
|
|
585
|
+
runnable_or_fn: A runnable or callable to configure.
|
|
586
|
+
name: The name to assign to this node.
|
|
587
|
+
*extra_tags: Additional tags to apply to the node.
|
|
588
|
+
|
|
589
|
+
Returns:
|
|
590
|
+
A configured runnable with the agent's node configuration applied.
|
|
591
|
+
"""
|
|
592
|
+
# Convert input to a runnable if it's not already one
|
|
375
593
|
r = self._as_runnable(runnable_or_fn)
|
|
594
|
+
# Apply node configuration and return the configured runnable
|
|
376
595
|
return r.with_config(**self._node_cfg(name, *extra_tags))
|
|
377
596
|
|
|
378
597
|
def _wrap_node(self, fn_or_runnable, name: str, *extra_tags: str):
|
|
598
|
+
"""Wrap a function or runnable as a node in the graph.
|
|
599
|
+
|
|
600
|
+
This is a convenience wrapper around the ns() method.
|
|
601
|
+
|
|
602
|
+
Args:
|
|
603
|
+
fn_or_runnable: A function or runnable to wrap as a node.
|
|
604
|
+
name: The name to assign to this node.
|
|
605
|
+
*extra_tags: Additional tags to apply to the node.
|
|
606
|
+
|
|
607
|
+
Returns:
|
|
608
|
+
A configured runnable with the agent's node configuration applied.
|
|
609
|
+
"""
|
|
379
610
|
return self.ns(fn_or_runnable, name, *extra_tags)
|
|
380
611
|
|
|
381
612
|
def _wrap_cond(self, fn: Any, name: str, *extra_tags: str):
|
|
613
|
+
"""Wrap a conditional function as a routing node in the graph.
|
|
614
|
+
|
|
615
|
+
Creates a runnable lambda with routing-specific configuration.
|
|
616
|
+
|
|
617
|
+
Args:
|
|
618
|
+
fn: The conditional function to wrap.
|
|
619
|
+
name: The name of the routing node.
|
|
620
|
+
*extra_tags: Additional tags to apply to the node.
|
|
621
|
+
|
|
622
|
+
Returns:
|
|
623
|
+
A configured RunnableLambda with routing-specific metadata.
|
|
624
|
+
"""
|
|
625
|
+
# Use the first extra tag as namespace, or fall back to agent name in snake_case
|
|
382
626
|
ns = extra_tags[0] if extra_tags else _to_snake(self.name)
|
|
627
|
+
|
|
628
|
+
# Create and return a configured RunnableLambda for routing
|
|
383
629
|
return RunnableLambda(fn).with_config(
|
|
384
630
|
run_name="node",
|
|
385
631
|
tags=[
|
|
@@ -396,7 +642,22 @@ class BaseAgent(ABC):
|
|
|
396
642
|
)
|
|
397
643
|
|
|
398
644
|
def _named(self, runnable: Any, name: str, *extra_tags: str):
|
|
645
|
+
"""Apply a specific name and configuration to a runnable.
|
|
646
|
+
|
|
647
|
+
Configures a runnable with a specific name and the agent's metadata.
|
|
648
|
+
|
|
649
|
+
Args:
|
|
650
|
+
runnable: The runnable to configure.
|
|
651
|
+
name: The name to assign to this runnable.
|
|
652
|
+
*extra_tags: Additional tags to apply to the runnable.
|
|
653
|
+
|
|
654
|
+
Returns:
|
|
655
|
+
A configured runnable with the specified name and agent metadata.
|
|
656
|
+
"""
|
|
657
|
+
# Use the first extra tag as namespace, or fall back to agent name in snake_case
|
|
399
658
|
ns = extra_tags[0] if extra_tags else _to_snake(self.name)
|
|
659
|
+
|
|
660
|
+
# Apply configuration and return the configured runnable
|
|
400
661
|
return runnable.with_config(
|
|
401
662
|
run_name=name,
|
|
402
663
|
tags=[self.name, "graph", name, *extra_tags],
|