ursa-ai 0.7.0rc1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ursa-ai might be problematic.

@@ -119,7 +119,7 @@ def remove_surrogates(text: str) -> str:
  return re.sub(r"[\ud800-\udfff]", "", text)


- class ArxivAgent(BaseAgent):
+ class ArxivAgentLegacy(BaseAgent):
  def __init__(
  self,
  llm: str | BaseChatModel = "openai/o3-mini",
ursa/agents/base.py CHANGED
@@ -1,3 +1,20 @@
+ """Base agent class providing telemetry, configuration, and execution abstractions.
+
+ This module defines the BaseAgent abstract class, which serves as the foundation for all
+ agent implementations in the Ursa framework. It provides:
+
+ - Standardized initialization with LLM configuration
+ - Telemetry and metrics collection
+ - Thread and checkpoint management
+ - Input normalization and validation
+ - Execution flow control with invoke/stream methods
+ - Graph integration utilities for LangGraph compatibility
+ - Runtime enforcement of the agent interface contract
+
+ Agents built on this base class benefit from consistent behavior, observability, and
+ integration capabilities while only needing to implement the core _invoke method.
+ """
+
  import re
  from abc import ABC, abstractmethod
  from contextvars import ContextVar
@@ -15,6 +32,7 @@ from uuid import uuid4

  from langchain_core.language_models.chat_models import BaseChatModel
  from langchain_core.load import dumps
+ from langchain_core.messages import HumanMessage
  from langchain_core.runnables import (
  RunnableLambda,
  )
@@ -31,6 +49,19 @@ _INVOKE_DEPTH = ContextVar("_INVOKE_DEPTH", default=0)


  def _to_snake(s: str) -> str:
+ """Convert a string to snake_case format.
+
+ This function transforms various string formats (CamelCase, PascalCase, etc.) into
+ snake_case. It handles special cases like acronyms at the beginning of strings
+ (e.g., "RAGAgent" becomes "rag_agent") and replaces hyphens and spaces with
+ underscores.
+
+ Args:
+ s: The input string to convert to snake_case.
+
+ Returns:
+ The snake_case version of the input string.
+ """
  s = re.sub(
  r"^([A-Z]{2,})([A-Z][a-z])",
  lambda m: m.group(1)[0] + m.group(1)[1:].lower() + m.group(2),
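
Only the first substitution of _to_snake is visible in this hunk. For orientation, a minimal self-contained sketch of a converter with the documented behavior (not necessarily the package's exact implementation):

```python
import re


def to_snake(s: str) -> str:
    # Collapse a leading acronym: "RAGAgent" -> "RagAgent" (same regex shown in the hunk).
    s = re.sub(
        r"^([A-Z]{2,})([A-Z][a-z])",
        lambda m: m.group(1)[0] + m.group(1)[1:].lower() + m.group(2),
        s,
    )
    # Split at lowercase/digit-to-uppercase boundaries, then normalize separators.
    s = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", s)
    return re.sub(r"[-\s]+", "_", s).lower()


assert to_snake("RAGAgent") == "rag_agent"
assert to_snake("PlanningAgent") == "planning_agent"
assert to_snake("my-agent name") == "my_agent_name"
```
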
@@ -42,8 +73,53 @@ def _to_snake(s: str) -> str:


  class BaseAgent(ABC):
- # llm: BaseChatModel
- # llm_with_tools: Runnable[LanguageModelInput, BaseMessage]
+ """Abstract base class for all agent implementations in the Ursa framework.
+
+ BaseAgent provides a standardized foundation for building LLM-powered agents with
+ built-in telemetry, configuration management, and execution flow control. It handles
+ common tasks like input normalization, thread management, metrics collection, and
+ LangGraph integration.
+
+ Subclasses only need to implement the _invoke method to define their core
+ functionality, while inheriting standardized invocation patterns, telemetry, and
+ graph integration capabilities. The class enforces a consistent interface through
+ runtime checks that prevent subclasses from overriding critical methods like
+ invoke().
+
+ The agent supports both direct invocation with inputs and streaming responses, with
+ automatic tracking of token usage, execution time, and other metrics. It also
+ provides utilities for integrating with LangGraph through node wrapping and
+ configuration.
+
+ Subclass Inheritance Guidelines:
+ - Must Override: _invoke() - Define your agent's core functionality
+ - Can Override: _stream() - Enable streaming support
+ _normalize_inputs() - Customize input handling
+ Various helper methods (_default_node_tags, _as_runnable, etc.)
+ - Never Override: invoke() - Final method with runtime enforcement
+ stream() - Handles telemetry and delegates to _stream
+ __call__() - Delegates to invoke
+ Other public methods (build_config, write_state, add_node)
+
+ To create a custom agent, inherit from this class and implement the _invoke method:
+
+ ```python
+ class MyAgent(BaseAgent):
+ def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
+ # Process inputs and return results
+ ...
+ ```
+ """
+
+ _TELEMETRY_KW = {
+ "raw_debug",
+ "save_json",
+ "metrics_path",
+ "save_raw_snapshot",
+ "save_raw_records",
+ }
+
+ _CONTROL_KW = {"config", "recursion_limit", "tags", "metadata", "callbacks"}

  def __init__(
  self,
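
A slightly fuller version of the docstring's MyAgent example, as a sketch (the EchoAgent name and its return shape are hypothetical; it assumes _invoke is the only member a subclass must supply, per the guidelines above):

```python
from typing import Any, Mapping

from ursa.agents.base import BaseAgent


class EchoAgent(BaseAgent):
    """Toy agent that echoes the last message back (illustrative only)."""

    def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
        # `inputs` arrives already normalized by BaseAgent.invoke(); a plain
        # string prompt becomes {"messages": [HumanMessage(...)]}.
        messages = list(inputs.get("messages", []))
        echo = messages[-1].content if messages else ""
        return {"messages": messages, "echo": echo}
```
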
@@ -55,6 +131,19 @@ class BaseAgent(ABC):
  thread_id: Optional[str] = None,
  **kwargs,
  ):
+ """Initializes the base agent with a language model and optional configurations.
+
+ Args:
+ llm: Either a string in the format "provider/model" or a BaseChatModel
+ instance.
+ checkpointer: Optional checkpoint saver for persisting agent state.
+ enable_metrics: Whether to collect performance and usage metrics.
+ metrics_dir: Directory path where metrics will be saved.
+ autosave_metrics: Whether to automatically save metrics to disk.
+ thread_id: Unique identifier for this agent instance. Generated if not
+ provided.
+ **kwargs: Additional keyword arguments passed to the LLM initialization.
+ """
  match llm:
  case BaseChatModel():
  self.llm = llm
@@ -70,7 +159,8 @@ class BaseAgent(ABC):

  case _:
  raise TypeError(
- "llm argument must be a string with the provider and model, or a BaseChatModel instance."
+ "llm argument must be a string with the provider and model, or a "
+ "BaseChatModel instance."
  )

  self.thread_id = thread_id or uuid4().hex
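
Per the new __init__ docstring and the match statement above, construction looks roughly like this (EchoAgent is the toy subclass sketched earlier; values are illustrative):

```python
# String form, "provider/model"; the relevant provider credentials must be configured.
agent = EchoAgent("openai/o3-mini", enable_metrics=True, thread_id="demo-thread")

# A pre-built BaseChatModel instance is accepted as-is (first case of the match).
# e.g. agent = EchoAgent(ChatOpenAI(model="o3-mini")) with langchain_openai installed.

# Anything else falls through to the TypeError in the final case.
try:
    EchoAgent(42)
except TypeError as err:
    print(err)
```
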
@@ -93,93 +183,105 @@ class BaseAgent(ABC):
  node_name: Optional[str] = None,
  agent_name: Optional[str] = None,
  ) -> StateGraph:
- """Add node to graph.
-
- This is used to track token usage and is simply the following.
-
- ```python
- _node_name = node_name or f.__name__
- return graph.add_node(
- _node_name, self._wrap_node(f, _node_name, self.name)
- )
- ```
+ """Add a node to the state graph with token usage tracking.
+
+ This method adds a function as a node to the state graph, wrapping it to track
+ token usage during execution. The node is identified by either the provided
+ node_name or the function's name.
+
+ Args:
+ graph: The StateGraph to add the node to.
+ f: The function to add as a node. Should return a mapping of string keys to
+ any values.
+ node_name: Optional name for the node. If not provided, the function's name
+ will be used.
+ agent_name: Optional agent name for tracking. If not provided, the agent's
+ name in snake_case will be used.
+
+ Returns:
+ The updated StateGraph with the new node added.
  """
  _node_name = node_name or f.__name__
  _agent_name = agent_name or _to_snake(self.name)
  wrapped_node = self._wrap_node(f, _node_name, _agent_name)
+
  return graph.add_node(_node_name, wrapped_node)

- def write_state(self, filename, state):
+ def write_state(self, filename: str, state: dict) -> None:
+ """Writes agent state to a JSON file.
+
+ Serializes the provided state dictionary to JSON format and writes it to the
+ specified file. The JSON is written with non-ASCII characters preserved.
+
+ Args:
+ filename: Path to the file where state will be written.
+ state: Dictionary containing the agent state to be serialized.
+ """
  json_state = dumps(state, ensure_ascii=False)
  with open(filename, "w") as f:
  f.write(json_state)

- # BaseAgent
  def build_config(self, **overrides) -> dict:
+ """Constructs a config dictionary for agent operations with telemetry support.
+
+ This method creates a standardized configuration dictionary that includes thread
+ identification, telemetry callbacks, and other metadata needed for agent
+ operations. The configuration can be customized through override parameters.
+
+ Args:
+ **overrides: Optional configuration overrides that can include keys like
+ 'recursion_limit', 'configurable', 'metadata', 'tags', etc.
+
+ Returns:
+ dict: A complete configuration dictionary with all necessary parameters.
  """
- Build a config dict that includes telemetry callbacks and the thread_id.
- You can pass overrides like recursion_limit=..., configurable={...}, etc.
- """
+ # Create the base configuration with essential fields.
  base = {
  "configurable": {"thread_id": self.thread_id},
  "metadata": {
  "thread_id": self.thread_id,
  "telemetry_run_id": self.telemetry.context.get("run_id"),
  },
- # "configurable": {
- # "thread_id": getattr(self, "thread_id", "default")
- # },
- # "metadata": {
- # "thread_id": getattr(self, "thread_id", "default"),
- # "telemetry_run_id": self.telemetry.context.get("run_id"),
- # },
  "tags": [self.name],
  "callbacks": self.telemetry.callbacks,
  }
- # include model name when we can
+
+ # Try to determine the model name from either direct or nested attributes
  model_name = getattr(self, "llm_model", None) or getattr(
  getattr(self, "llm", None), "model", None
  )
+
+ # Add model name to metadata if available
  if model_name:
  base["metadata"]["model"] = model_name

+ # Handle configurable dictionary overrides by merging with base configurable
  if "configurable" in overrides and isinstance(
  overrides["configurable"], dict
  ):
  base["configurable"].update(overrides.pop("configurable"))
+
+ # Handle metadata dictionary overrides by merging with base metadata
  if "metadata" in overrides and isinstance(overrides["metadata"], dict):
  base["metadata"].update(overrides.pop("metadata"))
- # merge tags if caller provides them
+
+ # Merge tags from caller-provided overrides, avoid duplicates
  if "tags" in overrides and isinstance(overrides["tags"], list):
  base["tags"] = base["tags"] + [
  t for t in overrides.pop("tags") if t not in base["tags"]
  ]
+
+ # Apply any remaining overrides directly to the base configuration
  base.update(overrides)
- return base

- # agents will invoke like this:
- # planning_output = planner.invoke(
- # {"messages": [HumanMessage(content=problem)]},
- # config={
- # "recursion_limit": 999_999,
- # "configurable": {"thread_id": planner.thread_id},
- # },
- # )
- # they can also, separately, override these defaults about metrics
- # keys that are NOT inputs; they should not be folded into the inputs mapping
- _TELEMETRY_KW = {
- "raw_debug",
- "save_json",
- "metrics_path",
- "save_raw_snapshot",
- "save_raw_records",
- }
- _CONTROL_KW = {"config", "recursion_limit", "tags", "metadata", "callbacks"}
+ return base

+ # NOTE: The `invoke` method uses the PEP 570 `/,*` notation to explicitly state which
+ # arguments can and cannot be passed as positional or keyword arguments.
  @final
  def invoke(
  self,
- inputs: Optional[InputLike] = None, # sentinel
+ inputs: Optional[InputLike] = None,
  /,
  *,
  raw_debug: bool = False,
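
The override-merging behavior of build_config can be read off the code above; a usage sketch, assuming `agent` is an instance of a BaseAgent subclass and the values are illustrative:

```python
cfg = agent.build_config(
    recursion_limit=250,                    # unknown keys pass straight through
    configurable={"user_id": "u-123"},      # merged into base["configurable"]
    metadata={"experiment": "ablation-1"},  # merged into base["metadata"]
    tags=["planning", agent.name],          # deduplicated against base["tags"]
)

# cfg now contains, roughly:
# {
#     "configurable": {"thread_id": agent.thread_id, "user_id": "u-123"},
#     "metadata": {"thread_id": ..., "telemetry_run_id": ..., "experiment": "ablation-1", ...},
#     "tags": [agent.name, "planning"],
#     "callbacks": agent.telemetry.callbacks,
#     "recursion_limit": 250,
# }
```
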
@@ -188,18 +290,47 @@ class BaseAgent(ABC):
  save_raw_snapshot: Optional[bool] = None,
  save_raw_records: Optional[bool] = None,
  config: Optional[dict] = None,
- **kwargs: Any, # may contain inputs (keyword-inputs) and/or control kw
+ **kwargs: Any,
  ) -> Any:
+ """Executes the agent with the provided inputs and configuration.
+
+ This is the main entry point for agent execution. It handles input normalization,
+ telemetry tracking, and proper execution context management. The method supports
+ flexible input formats - either as a positional argument or as keyword arguments.
+
+ Args:
+ inputs: Optional positional input to the agent. If provided, all non-control
+ keyword arguments will be rejected to avoid ambiguity.
+ raw_debug: If True, displays raw telemetry data for debugging purposes.
+ save_json: If True, saves telemetry data as JSON.
+ metrics_path: Optional file path where telemetry metrics should be saved.
+ save_raw_snapshot: If True, saves a raw snapshot of the telemetry data.
+ save_raw_records: If True, saves raw telemetry records.
+ config: Optional configuration dictionary to override default settings.
+ **kwargs: Additional keyword arguments that can be either:
+ - Input parameters (when no positional input is provided)
+ - Control parameters recognized by the agent
+
+ Returns:
+ The result of the agent's execution.
+
+ Raises:
+ TypeError: If both positional inputs and non-control keyword arguments are
+ provided simultaneously.
+ """
+ # Track invocation depth to manage nested agent calls
  depth = _INVOKE_DEPTH.get()
  _INVOKE_DEPTH.set(depth + 1)
  try:
+ # Start telemetry tracking for the top-level invocation
  if depth == 0:
  self.telemetry.begin_run(
  agent=self.name, thread_id=self.thread_id
  )

- # If no positional inputs were provided, split kwargs into inputs vs control
+ # Handle the case where inputs are provided as keyword arguments
  if inputs is None:
+ # Separate kwargs into input parameters and control parameters
  kw_inputs: dict[str, Any] = {}
  control_kwargs: dict[str, Any] = {}
  for k, v in kwargs.items():
@@ -208,11 +339,13 @@ class BaseAgent(ABC):
  else:
  kw_inputs[k] = v
  inputs = kw_inputs
- kwargs = control_kwargs # only control kwargs remain

- # If both positional inputs and extra unknown kwargs-as-inputs are given, forbid merging
+ # Only control kwargs remain for further processing
+ kwargs = control_kwargs
+
+ # Handle the case where inputs are provided as a positional argument
  else:
- # keep only control kwargs; anything else would be ambiguous
+ # Ensure no ambiguous keyword arguments are present
  for k in kwargs.keys():
  if not (k in self._TELEMETRY_KW or k in self._CONTROL_KW):
  raise TypeError(
@@ -221,15 +354,19 @@ class BaseAgent(ABC):
  "inputs and pass them as keyword arguments."
  )

- # subclasses may translate keys
+ # Allow subclasses to normalize or transform the input format
  normalized = self._normalize_inputs(inputs)

- # forward config + any control kwargs (e.g., recursion_limit) to the agent
+ # Delegate to the subclass implementation with the normalized inputs
+ # and any control parameters
  return self._invoke(normalized, config=config, **kwargs)

  finally:
+ # Clean up the invocation depth tracking
  new_depth = _INVOKE_DEPTH.get() - 1
  _INVOKE_DEPTH.set(new_depth)
+
+ # For the top-level invocation, finalize telemetry and generate outputs
  if new_depth == 0:
  self.telemetry.render(
  raw=raw_debug,
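
Taken together, the invoke() hunks above define two calling conventions; a sketch of both, again assuming `agent` is a concrete BaseAgent subclass:

```python
from langchain_core.messages import HumanMessage

# 1) Positional inputs, accompanied only by control/telemetry keywords.
result = agent.invoke(
    {"messages": [HumanMessage(content="Summarize the plan.")]},
    config={"recursion_limit": 1000},
    save_json=True,
)

# 2) Keyword inputs: non-control keywords are folded into the inputs mapping.
result = agent.invoke(messages=[HumanMessage(content="Summarize the plan.")])

# Mixing positional inputs with unrecognized keywords raises the TypeError above.
```
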
@@ -240,10 +377,25 @@ class BaseAgent(ABC):
  )

  def _normalize_inputs(self, inputs: InputLike) -> Mapping[str, Any]:
+ """Normalizes various input formats into a standardized mapping.
+
+ This method converts different input types into a consistent dictionary format
+ that can be processed by the agent. String inputs are wrapped as messages, while
+ mappings are passed through unchanged.
+
+ Args:
+ inputs: The input to normalize. Can be a string (which will be converted to a
+ message) or a mapping (which will be returned as-is).
+
+ Returns:
+ A mapping containing the normalized inputs, with keys appropriate for agent
+ processing.
+
+ Raises:
+ TypeError: If the input type is not supported (neither string nor mapping).
+ """
  if isinstance(inputs, str):
  # Adjust to your message type
- from langchain_core.messages import HumanMessage
-
  return {"messages": [HumanMessage(content=inputs)]}
  if isinstance(inputs, Mapping):
  return inputs
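
So, for example, these two calls should reach _invoke with equivalent normalized inputs:

```python
from langchain_core.messages import HumanMessage

agent.invoke("Summarize the plan.")
agent.invoke({"messages": [HumanMessage(content="Summarize the plan.")]})
```
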
@@ -255,15 +407,19 @@ class BaseAgent(ABC):
  ...

  def __call__(self, inputs: InputLike, /, **kwargs: Any) -> Any:
+ """Specify calling behavior for class instance."""
  return self.invoke(inputs, **kwargs)

  # Runtime enforcement: forbid subclasses from overriding invoke
  def __init_subclass__(cls, **kwargs):
+ """Ensure subclass does not override key method."""
  super().__init_subclass__(**kwargs)
  if "invoke" in cls.__dict__:
- raise TypeError(
- f"{cls.__name__} must not override BaseAgent.invoke(); implement _invoke() only."
+ err_msg = (
+ f"{cls.__name__} must not override BaseAgent.invoke(); "
+ "implement _invoke() only."
  )
+ raise TypeError(err_msg)

  def stream(
  self,
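
The enforcement added in __init_subclass__ fails at class-definition time, e.g.:

```python
try:

    class BadAgent(BaseAgent):  # hypothetical offender
        def invoke(self, inputs=None, /, **kwargs):  # forbidden override
            return {}

except TypeError as err:
    print(err)  # "BadAgent must not override BaseAgent.invoke(); implement _invoke() only."
```
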
@@ -278,19 +434,52 @@ class BaseAgent(ABC):
  save_raw_records: bool | None = None,
  **kwargs: Any,
  ) -> Iterator[Any]:
- """Public streaming entry point. Telemetry-wrapped."""
+ """Streams agent responses with telemetry tracking.
+
+ This method serves as the public streaming entry point for agent interactions.
+ It wraps the actual streaming implementation with telemetry tracking to capture
+ metrics and debugging information.
+
+ Args:
+ inputs: The input to process, which will be normalized internally.
+ config: Optional configuration for the agent, compatible with LangGraph
+ positional/keyword argument style.
+ raw_debug: If True, renders raw debug information in telemetry output.
+ save_json: If True, saves telemetry data as JSON.
+ metrics_path: Optional file path where metrics should be saved.
+ save_raw_snapshot: If True, saves raw snapshot data in telemetry.
+ save_raw_records: If True, saves raw record data in telemetry.
+ **kwargs: Additional keyword arguments passed to the streaming
+ implementation.
+
+ Returns:
+ An iterator yielding the agent's responses.
+
+ Note:
+ This method tracks invocation depth to properly handle nested agent calls
+ and ensure telemetry is only rendered once at the top level.
+ """
+ # Track invocation depth to handle nested agent calls
  depth = _INVOKE_DEPTH.get()
  _INVOKE_DEPTH.set(depth + 1)
+
  try:
+ # Start telemetry tracking for top-level invocations only
  if depth == 0:
  self.telemetry.begin_run(
  agent=self.name, thread_id=self.thread_id
  )
+
+ # Normalize inputs and delegate to the actual streaming implementation
  normalized = self._normalize_inputs(inputs)
  yield from self._stream(normalized, config=config, **kwargs)
+
  finally:
+ # Decrement invocation depth when exiting
  new_depth = _INVOKE_DEPTH.get() - 1
  _INVOKE_DEPTH.set(new_depth)
+
+ # Render telemetry data only for top-level invocations
  if new_depth == 0:
  self.telemetry.render(
  raw=raw_debug,
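
For agents that override _stream(), the public entry point would be used roughly like this (the chunk format is whatever the subclass yields):

```python
for chunk in agent.stream("Summarize the plan.", save_json=True):
    print(chunk)
```
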
@@ -307,48 +496,47 @@ class BaseAgent(ABC):
  config: Any | None = None,
  **kwargs: Any,
  ) -> Iterator[Any]:
+ """Subclass method to be overwritten for streaming implementation."""
  raise NotImplementedError(
  f"{self.name} does not support streaming. "
  "Override _stream(...) in your agent to enable it."
  )

- # def run(
- # self,
- # *args,
- # raw_debug: bool = False,
- # save_json: bool | None = None,
- # metrics_path: str | None = None,
- # save_raw_snapshot: bool | None = None,
- # save_raw_records: bool | None = None,
- # **kwargs
- # ):
- # try:
- # self.telemetry.begin_run(agent=self.name, thread_id=self.thread_id)
- # result = self._run_impl(*args, **kwargs)
- # return result
- # finally:
- # print(self.telemetry.render(
- # raw=raw_debug,
- # save_json=save_json,
- # filepath=metrics_path,
- # save_raw_snapshot=save_raw_snapshot,
- # save_raw_records=save_raw_records,
- # ))
-
- # @abstractmethod
- # def _run_impl(self, *args, **kwargs):
- # raise NotImplementedError("Agents must implement _run_impl")
-
  def _default_node_tags(
  self, name: str, extra: Sequence[str] | None = None
  ) -> list[str]:
+ """Generate default tags for a graph node.
+
+ Args:
+ name: The name of the node.
+ extra: Optional sequence of additional tags to include.
+
+ Returns:
+ list[str]: A list of tags for the node, including the agent name, 'graph',
+ the node name, and any extra tags provided.
+ """
+ # Start with standard tags: agent name, graph indicator, and node name
  tags = [self.name, "graph", name]
+
+ # Add any extra tags if provided
  if extra:
  tags.extend(extra)
+
  return tags

  def _as_runnable(self, fn: Any):
- # If it's already runnable (has .with_config/.invoke), return it; else wrap
+ """Convert a function to a runnable if it isn't already.
+
+ Args:
+ fn: The function or object to convert to a runnable.
+
+ Returns:
+ A runnable object that can be used in the graph. If the input is already
+ runnable (has .with_config and .invoke methods), it's returned as is.
+ Otherwise, it's wrapped in a RunnableLambda.
+ """
+ # Check if the function already has the required runnable interface
+ # If so, return it as is; otherwise wrap it in a RunnableLambda
  return (
  fn
  if hasattr(fn, "with_config") and hasattr(fn, "invoke")
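
The duck-typing check in _as_runnable distinguishes plain callables from Runnables; a small illustration using langchain_core directly:

```python
from langchain_core.runnables import RunnableLambda


def double(x: int) -> int:
    return x * 2


runnable = RunnableLambda(double)

print(hasattr(double, "invoke"), hasattr(double, "with_config"))      # False False
print(hasattr(runnable, "invoke"), hasattr(runnable, "with_config"))  # True True
print(runnable.invoke(21))  # 42
```
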
@@ -356,11 +544,28 @@ class BaseAgent(ABC):
  )

  def _node_cfg(self, name: str, *extra_tags: str) -> dict:
- """Build a consistent config for a node/runnable so we can reapply it after .map(), subgraph compile, etc."""
+ """Build a consistent configuration for a node/runnable.
+
+ Creates a configuration dict that can be reapplied after operations like
+ .map(), subgraph compile, etc.
+
+ Args:
+ name: The name of the node.
+ *extra_tags: Additional tags to include in the node configuration.
+
+ Returns:
+ dict: A configuration dictionary with run_name, tags, and metadata.
+ """
+ # Determine the namespace - use first extra tag if available, otherwise
+ # convert agent name to snake_case
  ns = extra_tags[0] if extra_tags else _to_snake(self.name)
+
+ # Combine all tags: agent name, graph indicator, node name, and any extra tags
  tags = [self.name, "graph", name, *extra_tags]
+
+ # Return the complete configuration dictionary
  return dict(
- run_name="node", # keep "node:" prefixing in the timer; don't fight Rich labels here
+ run_name="node", # keep "node:" prefixing in the timer
  tags=tags,
  metadata={
  "langgraph_node": name,
@@ -370,16 +575,57 @@ class BaseAgent(ABC):
  )

  def ns(self, runnable_or_fn, name: str, *extra_tags: str):
- """Return a runnable with our node config applied. Safe to call on callables or runnables.
- IMPORTANT: call this AGAIN after .map() / subgraph .compile() (they often drop config)."""
+ """Return a runnable with node configuration applied.
+
+ Applies the agent's node configuration to a runnable or callable. This method
+ should be called again after operations like .map() or subgraph .compile() as
+ these operations may drop configuration.
+
+ Args:
+ runnable_or_fn: A runnable or callable to configure.
+ name: The name to assign to this node.
+ *extra_tags: Additional tags to apply to the node.
+
+ Returns:
+ A configured runnable with the agent's node configuration applied.
+ """
+ # Convert input to a runnable if it's not already one
  r = self._as_runnable(runnable_or_fn)
+ # Apply node configuration and return the configured runnable
  return r.with_config(**self._node_cfg(name, *extra_tags))

  def _wrap_node(self, fn_or_runnable, name: str, *extra_tags: str):
+ """Wrap a function or runnable as a node in the graph.
+
+ This is a convenience wrapper around the ns() method.
+
+ Args:
+ fn_or_runnable: A function or runnable to wrap as a node.
+ name: The name to assign to this node.
+ *extra_tags: Additional tags to apply to the node.
+
+ Returns:
+ A configured runnable with the agent's node configuration applied.
+ """
  return self.ns(fn_or_runnable, name, *extra_tags)

  def _wrap_cond(self, fn: Any, name: str, *extra_tags: str):
+ """Wrap a conditional function as a routing node in the graph.
+
+ Creates a runnable lambda with routing-specific configuration.
+
+ Args:
+ fn: The conditional function to wrap.
+ name: The name of the routing node.
+ *extra_tags: Additional tags to apply to the node.
+
+ Returns:
+ A configured RunnableLambda with routing-specific metadata.
+ """
+ # Use the first extra tag as namespace, or fall back to agent name in snake_case
  ns = extra_tags[0] if extra_tags else _to_snake(self.name)
+
+ # Create and return a configured RunnableLambda for routing
  return RunnableLambda(fn).with_config(
  run_name="node",
  tags=[
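
Putting the node helpers together, a subclass building its LangGraph might use add_node along these lines (a sketch; PlanState, plan_step, and `agent` are hypothetical):

```python
from typing import TypedDict

from langgraph.graph import StateGraph


class PlanState(TypedDict):
    messages: list


def plan_step(state: PlanState) -> dict:
    # Hypothetical node body; a real node would call the agent's LLM or tools.
    return {"messages": state["messages"]}


builder = StateGraph(PlanState)

# add_node (shown earlier in this diff) wraps plan_step via _wrap_node/ns so its
# runs carry the agent's tags and "langgraph_node" metadata, then registers it.
agent.add_node(builder, plan_step, node_name="plan")
```
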
@@ -396,7 +642,22 @@ class BaseAgent(ABC):
  )

  def _named(self, runnable: Any, name: str, *extra_tags: str):
+ """Apply a specific name and configuration to a runnable.
+
+ Configures a runnable with a specific name and the agent's metadata.
+
+ Args:
+ runnable: The runnable to configure.
+ name: The name to assign to this runnable.
+ *extra_tags: Additional tags to apply to the runnable.
+
+ Returns:
+ A configured runnable with the specified name and agent metadata.
+ """
+ # Use the first extra tag as namespace, or fall back to agent name in snake_case
  ns = extra_tags[0] if extra_tags else _to_snake(self.name)
+
+ # Apply configuration and return the configured runnable
  return runnable.with_config(
  run_name=name,
  tags=[self.name, "graph", name, *extra_tags],