ursa-ai 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


ursa/__init__.py ADDED
File without changes
ursa/agents/__init__.py CHANGED
@@ -14,6 +14,8 @@ from .lammps_agent import LammpsState as LammpsState
  from .mp_agent import MaterialsProjectAgent as MaterialsProjectAgent
  from .planning_agent import PlanningAgent as PlanningAgent
  from .planning_agent import PlanningState as PlanningState
+ from .rag_agent import RAGAgent as RAGAgent
+ from .rag_agent import RAGState as RAGState
  from .recall_agent import RecallAgent as RecallAgent
  from .websearch_agent import WebSearchAgent as WebSearchAgent
  from .websearch_agent import WebSearchState as WebSearchState
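The two new re-exports make the RAG workflow importable directly from `ursa.agents`. A minimal usage sketch, assuming the constructor keywords shown in the `_rag_node` call later in this diff; the embedding class is illustrative, not part of this package:

```python
# Hypothetical sketch: mirrors the RAGAgent(...) call in arxiv_agent.py below.
from langchain_openai import OpenAIEmbeddings  # illustrative embedding choice

from ursa.agents import RAGAgent  # new re-export in 0.6.0

rag = RAGAgent(
    llm="openai/o3-mini",          # provider/model string or a BaseChatModel
    embedding=OpenAIEmbeddings(),  # assumption: any LangChain embedding object
    database_path="arxiv_papers",  # directory the agent indexes
)
print(rag.invoke(context="neutron star radius constraints")["summary"])
```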
ursa/agents/arxiv_agent.py CHANGED
@@ -1,17 +1,16 @@
  import base64
  import os
  import re
- import statistics
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from io import BytesIO
+ from typing import Any, Mapping
  from urllib.parse import quote

  import feedparser
  import pymupdf
  import requests
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_chroma import Chroma
  from langchain_community.document_loaders import PyPDFLoader
+ from langchain_core.language_models import BaseChatModel
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.prompts import ChatPromptTemplate
  from langgraph.graph import StateGraph
@@ -19,16 +18,14 @@ from PIL import Image
  from tqdm import tqdm
  from typing_extensions import List, TypedDict

- from .base import BaseAgent
+ from ursa.agents.base import BaseAgent
+ from ursa.agents.rag_agent import RAGAgent

  try:
      from openai import OpenAI
  except Exception:
      pass

- # embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
- # embeddings = OpenAIEmbeddings()
-

  class PaperMetadata(TypedDict):
      arxiv_id: str
@@ -125,7 +122,7 @@ def remove_surrogates(text: str) -> str:
  class ArxivAgent(BaseAgent):
      def __init__(
          self,
-         llm="openai/o3-mini",
+         llm: str | BaseChatModel = "openai/o3-mini",
          summarize: bool = True,
          process_images=True,
          max_results: int = 3,
@@ -146,7 +143,7 @@ class ArxivAgent(BaseAgent):
          self.download_papers = download_papers
          self.rag_embedding = rag_embedding

-         self.graph = self._build_graph()
+         self._action = self._build_graph()

          os.makedirs(self.database_path, exist_ok=True)

@@ -242,27 +239,6 @@ class ArxivAgent(BaseAgent):
          papers = self._fetch_papers(state["query"])
          return {**state, "papers": papers}

-     def _get_or_build_vectorstore(self, paper_text: str, arxiv_id: str):
-         os.makedirs(self.vectorstore_path, exist_ok=True)
-
-         persist_directory = os.path.join(self.vectorstore_path, arxiv_id)
-
-         if os.path.exists(persist_directory):
-             vectorstore = Chroma(
-                 persist_directory=persist_directory,
-                 embedding_function=self.rag_embedding,
-             )
-         else:
-             splitter = RecursiveCharacterTextSplitter(
-                 chunk_size=1000, chunk_overlap=200
-             )
-             docs = splitter.create_documents([paper_text])
-             vectorstore = Chroma.from_documents(
-                 docs, self.rag_embedding, persist_directory=persist_directory
-             )
-
-         return vectorstore.as_retriever(search_kwargs={"k": 5})
-
      def _summarize_node(self, state: PaperState) -> PaperState:
          prompt = ChatPromptTemplate.from_template("""
  You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
@@ -285,35 +261,13 @@ class ArxivAgent(BaseAgent):

          try:
              cleaned_text = remove_surrogates(paper["full_text"])
-             if self.rag_embedding:
-                 retriever = self._get_or_build_vectorstore(
-                     cleaned_text, arxiv_id
-                 )
-
-                 relevant_docs_with_scores = (
-                     retriever.vectorstore.similarity_search_with_score(
-                         state["context"], k=5
-                     )
-                 )
-
-                 if relevant_docs_with_scores:
-                     score = sum([
-                         s for _, s in relevant_docs_with_scores
-                     ]) / len(relevant_docs_with_scores)
-                     relevancy_scores[i] = abs(1.0 - score)
-                 else:
-                     relevancy_scores[i] = 0.0
-
-                 retrieved_content = "\n\n".join([
-                     doc.page_content for doc, _ in relevant_docs_with_scores
-                 ])
-             else:
-                 retrieved_content = cleaned_text
-
-             summary = chain.invoke({
-                 "retrieved_content": retrieved_content,
-                 "context": state["context"],
-             })
+             summary = chain.invoke(
+                 {
+                     "retrieved_content": cleaned_text,
+                     "context": state["context"],
+                 },
+                 config=self.build_config(tags=["arxiv", "summarize_each"]),
+             )

          except Exception as e:
              summary = f"Error summarizing paper: {e}"
@@ -346,15 +300,20 @@ class ArxivAgent(BaseAgent):
              i, result = future.result()
              summaries[i] = result

-         if self.rag_embedding:
-             print(f"\nMax Relevancy Score: {max(relevancy_scores)}")
-             print(f"Min Relevancy Score: {min(relevancy_scores)}")
-             print(
-                 f"Median Relevancy Score: {statistics.median(relevancy_scores)}\n"
-             )
-
          return {**state, "summaries": summaries}

+     def _rag_node(self, state: PaperState) -> PaperState:
+         new_state = state.copy()
+         rag_agent = RAGAgent(
+             llm=self.llm,
+             embedding=self.rag_embedding,
+             database_path=self.database_path,
+         )
+         new_state["final_summary"] = rag_agent.invoke(context=state["context"])[
+             "summary"
+         ]
+         return new_state
+
      def _aggregate_node(self, state: PaperState) -> PaperState:
          summaries = state["summaries"]
          papers = state["papers"]
@@ -389,10 +348,13 @@ class ArxivAgent(BaseAgent):

          chain = prompt | self.llm | StrOutputParser()

-         final_summary = chain.invoke({
-             "Summaries": combined,
-             "context": state["context"],
-         })
+         final_summary = chain.invoke(
+             {
+                 "Summaries": combined,
+                 "context": state["context"],
+             },
+             config=self.build_config(tags=["arxiv", "aggregate"]),
+         )

          with open(self.summaries_path + "/final_summary.txt", "w") as f:
              f.write(final_summary)
@@ -400,42 +362,69 @@ class ArxivAgent(BaseAgent):
          return {**state, "final_summary": final_summary}

      def _build_graph(self):
-         builder = StateGraph(PaperState)
-         builder.add_node("fetch_papers", self._fetch_node)
+         graph = StateGraph(PaperState)

+         self.add_node(graph, self._fetch_node)
          if self.summarize:
-             builder.add_node("summarize_each", self._summarize_node)
-             builder.add_node("aggregate", self._aggregate_node)
+             if self.rag_embedding:
+                 self.add_node(graph, self._rag_node)
+                 graph.set_entry_point("_fetch_node")
+                 graph.add_edge("_fetch_node", "_rag_node")
+                 graph.set_finish_point("_rag_node")
+             else:
+                 self.add_node(graph, self._summarize_node)
+                 self.add_node(graph, self._aggregate_node)
+
+                 graph.set_entry_point("_fetch_node")
+                 graph.add_edge("_fetch_node", "_summarize_node")
+                 graph.add_edge("_summarize_node", "_aggregate_node")
+                 graph.set_finish_point("_aggregate_node")
+         else:
+             graph.set_entry_point("_fetch_node")
+             graph.set_finish_point("_fetch_node")

-             builder.set_entry_point("fetch_papers")
-             builder.add_edge("fetch_papers", "summarize_each")
-             builder.add_edge("summarize_each", "aggregate")
-             builder.set_finish_point("aggregate")
+         return graph.compile(checkpointer=self.checkpointer)

-         else:
-             builder.set_entry_point("fetch_papers")
-             builder.set_finish_point("fetch_papers")
+     def _invoke(
+         self,
+         inputs: Mapping[str, Any],
+         *,
+         summarize: bool | None = None,
+         recursion_limit: int = 1000,
+         **_,
+     ) -> str:
+         config = self.build_config(
+             recursion_limit=recursion_limit, tags=["graph"]
+         )

-         graph = builder.compile()
-         return graph
+         # this seems dumb, but it's b/c sometimes we had referred to the value as
+         # 'query' other times as 'arxiv_search_query' so trying to keep it compatible
+         # aliasing: accept arxiv_search_query -> query
+         if "query" not in inputs:
+             if "arxiv_search_query" in inputs:
+                 # make a shallow copy and rename the key
+                 inputs = dict(inputs)
+                 inputs["query"] = inputs.pop("arxiv_search_query")
+             else:
+                 raise KeyError(
+                     "Missing 'query' in inputs (alias 'arxiv_search_query' also accepted)."
+                 )

-     def run(self, arxiv_search_query: str, context: str) -> str:
-         result = self.graph.invoke({
-             "query": arxiv_search_query,
-             "context": context,
-         })
+         result = self._action.invoke(inputs, config)

-         if self.summarize:
-             return result.get("final_summary", "No summary generated.")
-         else:
-             return "\n\nFinished Fetching papers!"
+         use_summary = self.summarize if summarize is None else summarize

+         return (
+             result.get("final_summary", "No summary generated.")
+             if use_summary
+             else "\n\nFinished Fetching papers!"
+         )

- if __name__ == "__main__":
-     agent = ArxivAgent()
-     result = agent.run(
-         arxiv_search_query="Experimental Constraints on neutron star radius",
-         context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?",
-     )

-     print(result)
+ # NOTE: Run test in `tests/agents/test_arxiv_agent/test_arxiv_agent.py` via:
+ #
+ # pytest -s tests/agents/test_arxiv_agent
+ #
+ # OR
+ #
+ # uv run pytest -s tests/agents/test_arxiv_agent
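With this change the old public `run()` entry point is gone: callers go through the `@final` `BaseAgent.invoke()` defined in base.py below, and `_invoke` accepts either `query` or the legacy `arxiv_search_query` key. A minimal before/after sketch; the query and context strings are illustrative:

```python
from ursa.agents import ArxivAgent

agent = ArxivAgent(llm="openai/o3-mini", summarize=True)

# 0.4.x style (removed in this release):
# result = agent.run(arxiv_search_query="...", context="...")

# 0.6.0 style: inputs are a single mapping passed to BaseAgent.invoke();
# the legacy 'arxiv_search_query' key is aliased to 'query' by _invoke.
result = agent.invoke({
    "arxiv_search_query": "Experimental Constraints on neutron star radius",
    "context": "What are the constraints on the neutron star radius?",
})
print(result)  # final summary, or "Finished Fetching papers!" if summarize=False
```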
ursa/agents/base.py CHANGED
@@ -1,10 +1,47 @@
+ import re
+ from abc import ABC, abstractmethod
+ from contextvars import ContextVar
+ from typing import (
+     Any,
+     Callable,
+     Iterator,
+     Mapping,
+     Optional,
+     Sequence,
+     Union,
+     final,
+ )
+ from uuid import uuid4
+
  from langchain_core.language_models.chat_models import BaseChatModel
  from langchain_core.load import dumps
+ from langchain_core.runnables import (
+     RunnableLambda,
+ )
  from langchain_litellm import ChatLiteLLM
  from langgraph.checkpoint.base import BaseCheckpointSaver
+ from langgraph.graph import StateGraph
+
+ from ursa.observability.timing import (
+     Telemetry,  # for timing / telemetry / metrics
+ )
+
+ InputLike = Union[str, Mapping[str, Any]]
+ _INVOKE_DEPTH = ContextVar("_INVOKE_DEPTH", default=0)


- class BaseAgent:
+ def _to_snake(s: str) -> str:
+     s = re.sub(
+         r"^([A-Z]{2,})([A-Z][a-z])",
+         lambda m: m.group(1)[0] + m.group(1)[1:].lower() + m.group(2),
+         str(s),
+     )  # RAGAgent -> RagAgent
+     s = re.sub(r"(?<!^)(?=[A-Z])", "_", s)  # CamelCase -> snake_case
+     s = s.replace("-", "_").replace(" ", "_")
+     return s.lower()
+
+
+ class BaseAgent(ABC):
      # llm: BaseChatModel
      # llm_with_tools: Runnable[LanguageModelInput, BaseMessage]

@@ -12,6 +49,10 @@ class BaseAgent:
          self,
          llm: str | BaseChatModel,
          checkpointer: BaseCheckpointSaver = None,
+         enable_metrics: bool = False,  # default to enabling metrics
+         metrics_dir: str = ".ursa_metrics",  # dir to save metrics, with a default
+         autosave_metrics: bool = True,
+         thread_id: Optional[str] = None,
          **kwargs,
      ):
          match llm:
@@ -32,10 +73,336 @@ class BaseAgent:
                  "llm argument must be a string with the provider and model, or a BaseChatModel instance."
              )

+         self.thread_id = thread_id or uuid4().hex
          self.checkpointer = checkpointer
-         self.thread_id = self.__class__.__name__
+         self.telemetry = Telemetry(
+             enable=enable_metrics,
+             output_dir=metrics_dir,
+             save_json_default=autosave_metrics,
+         )
+
+     @property
+     def name(self) -> str:
+         """Agent name."""
+         return self.__class__.__name__
+
+     def add_node(
+         self,
+         graph: StateGraph,
+         f: Callable[..., Mapping[str, Any]],
+         node_name: Optional[str] = None,
+         agent_name: Optional[str] = None,
+     ) -> StateGraph:
+         """Add node to graph.
+
+         This is used to track token usage and is simply the following.
+
+         ```python
+         _node_name = node_name or f.__name__
+         return graph.add_node(
+             _node_name, self._wrap_node(f, _node_name, self.name)
+         )
+         ```
+         """
+         _node_name = node_name or f.__name__
+         _agent_name = agent_name or _to_snake(self.name)
+         wrapped_node = self._wrap_node(f, _node_name, _agent_name)
+         return graph.add_node(_node_name, wrapped_node)

      def write_state(self, filename, state):
          json_state = dumps(state, ensure_ascii=False)
          with open(filename, "w") as f:
              f.write(json_state)
+
+     # BaseAgent
+     def build_config(self, **overrides) -> dict:
+         """
+         Build a config dict that includes telemetry callbacks and the thread_id.
+         You can pass overrides like recursion_limit=..., configurable={...}, etc.
+         """
+         base = {
+             "configurable": {"thread_id": self.thread_id},
+             "metadata": {
+                 "thread_id": self.thread_id,
+                 "telemetry_run_id": self.telemetry.context.get("run_id"),
+             },
+             # "configurable": {
+             #     "thread_id": getattr(self, "thread_id", "default")
+             # },
+             # "metadata": {
+             #     "thread_id": getattr(self, "thread_id", "default"),
+             #     "telemetry_run_id": self.telemetry.context.get("run_id"),
+             # },
+             "tags": [self.name],
+             "callbacks": self.telemetry.callbacks,
+         }
+         # include model name when we can
+         model_name = getattr(self, "llm_model", None) or getattr(
+             getattr(self, "llm", None), "model", None
+         )
+         if model_name:
+             base["metadata"]["model"] = model_name
+
+         if "configurable" in overrides and isinstance(
+             overrides["configurable"], dict
+         ):
+             base["configurable"].update(overrides.pop("configurable"))
+         if "metadata" in overrides and isinstance(overrides["metadata"], dict):
+             base["metadata"].update(overrides.pop("metadata"))
+         # merge tags if caller provides them
+         if "tags" in overrides and isinstance(overrides["tags"], list):
+             base["tags"] = base["tags"] + [
+                 t for t in overrides.pop("tags") if t not in base["tags"]
+             ]
+         base.update(overrides)
+         return base
+
+     # agents will invoke like this:
+     # planning_output = planner.invoke(
+     #     {"messages": [HumanMessage(content=problem)]},
+     #     config={
+     #         "recursion_limit": 999_999,
+     #         "configurable": {"thread_id": planner.thread_id},
+     #     },
+     # )
+     # they can also, separately, override these defaults about metrics
+     # keys that are NOT inputs; they should not be folded into the inputs mapping
+     _TELEMETRY_KW = {
+         "raw_debug",
+         "save_json",
+         "metrics_path",
+         "save_raw_snapshot",
+         "save_raw_records",
+     }
+     _CONTROL_KW = {"config", "recursion_limit", "tags", "metadata", "callbacks"}
+
+     @final
+     def invoke(
+         self,
+         inputs: Optional[InputLike] = None,  # sentinel
+         /,
+         *,
+         raw_debug: bool = False,
+         save_json: Optional[bool] = None,
+         metrics_path: Optional[str] = None,
+         save_raw_snapshot: Optional[bool] = None,
+         save_raw_records: Optional[bool] = None,
+         config: Optional[dict] = None,
+         **kwargs: Any,  # may contain inputs (keyword-inputs) and/or control kw
+     ) -> Any:
+         depth = _INVOKE_DEPTH.get()
+         _INVOKE_DEPTH.set(depth + 1)
+         try:
+             if depth == 0:
+                 self.telemetry.begin_run(
+                     agent=self.name, thread_id=self.thread_id
+                 )
+
+             # If no positional inputs were provided, split kwargs into inputs vs control
+             if inputs is None:
+                 kw_inputs: dict[str, Any] = {}
+                 control_kwargs: dict[str, Any] = {}
+                 for k, v in kwargs.items():
+                     if k in self._TELEMETRY_KW or k in self._CONTROL_KW:
+                         control_kwargs[k] = v
+                     else:
+                         kw_inputs[k] = v
+                 inputs = kw_inputs
+                 kwargs = control_kwargs  # only control kwargs remain
+
+             # If both positional inputs and extra unknown kwargs-as-inputs are given, forbid merging
+             else:
+                 # keep only control kwargs; anything else would be ambiguous
+                 for k in kwargs.keys():
+                     if not (k in self._TELEMETRY_KW or k in self._CONTROL_KW):
+                         raise TypeError(
+                             f"Unexpected keyword argument '{k}'. "
+                             "Pass inputs as a single mapping or omit the positional "
+                             "inputs and pass them as keyword arguments."
+                         )
+
+             # subclasses may translate keys
+             normalized = self._normalize_inputs(inputs)
+
+             # forward config + any control kwargs (e.g., recursion_limit) to the agent
+             return self._invoke(normalized, config=config, **kwargs)
+
+         finally:
+             new_depth = _INVOKE_DEPTH.get() - 1
+             _INVOKE_DEPTH.set(new_depth)
+             if new_depth == 0:
+                 self.telemetry.render(
+                     raw=raw_debug,
+                     save_json=save_json,
+                     filepath=metrics_path,
+                     save_raw_snapshot=save_raw_snapshot,
+                     save_raw_records=save_raw_records,
+                 )
+
+     def _normalize_inputs(self, inputs: InputLike) -> Mapping[str, Any]:
+         if isinstance(inputs, str):
+             # Adjust to your message type
+             from langchain_core.messages import HumanMessage
+
+             return {"messages": [HumanMessage(content=inputs)]}
+         if isinstance(inputs, Mapping):
+             return inputs
+         raise TypeError(f"Unsupported input type: {type(inputs)}")
+
+     @abstractmethod
+     def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
+         """Subclasses implement the actual work against normalized inputs."""
+         ...
+
+     def __call__(self, inputs: InputLike, /, **kwargs: Any) -> Any:
+         return self.invoke(inputs, **kwargs)
+
+     # Runtime enforcement: forbid subclasses from overriding invoke
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+         if "invoke" in cls.__dict__:
+             raise TypeError(
+                 f"{cls.__name__} must not override BaseAgent.invoke(); implement _invoke() only."
+             )
+
+     def stream(
+         self,
+         inputs: InputLike,
+         config: Any | None = None,  # allow positional/keyword like LangGraph
+         /,
+         *,
+         raw_debug: bool = False,
+         save_json: bool | None = None,
+         metrics_path: str | None = None,
+         save_raw_snapshot: bool | None = None,
+         save_raw_records: bool | None = None,
+         **kwargs: Any,
+     ) -> Iterator[Any]:
+         """Public streaming entry point. Telemetry-wrapped."""
+         depth = _INVOKE_DEPTH.get()
+         _INVOKE_DEPTH.set(depth + 1)
+         try:
+             if depth == 0:
+                 self.telemetry.begin_run(
+                     agent=self.name, thread_id=self.thread_id
+                 )
+             normalized = self._normalize_inputs(inputs)
+             yield from self._stream(normalized, config=config, **kwargs)
+         finally:
+             new_depth = _INVOKE_DEPTH.get() - 1
+             _INVOKE_DEPTH.set(new_depth)
+             if new_depth == 0:
+                 self.telemetry.render(
+                     raw=raw_debug,
+                     save_json=save_json,
+                     filepath=metrics_path,
+                     save_raw_snapshot=save_raw_snapshot,
+                     save_raw_records=save_raw_records,
+                 )
+
+     def _stream(
+         self,
+         inputs: Mapping[str, Any],
+         *,
+         config: Any | None = None,
+         **kwargs: Any,
+     ) -> Iterator[Any]:
+         raise NotImplementedError(
+             f"{self.name} does not support streaming. "
+             "Override _stream(...) in your agent to enable it."
+         )
+
+     # def run(
+     #     self,
+     #     *args,
+     #     raw_debug: bool = False,
+     #     save_json: bool | None = None,
+     #     metrics_path: str | None = None,
+     #     save_raw_snapshot: bool | None = None,
+     #     save_raw_records: bool | None = None,
+     #     **kwargs
+     # ):
+     #     try:
+     #         self.telemetry.begin_run(agent=self.name, thread_id=self.thread_id)
+     #         result = self._run_impl(*args, **kwargs)
+     #         return result
+     #     finally:
+     #         print(self.telemetry.render(
+     #             raw=raw_debug,
+     #             save_json=save_json,
+     #             filepath=metrics_path,
+     #             save_raw_snapshot=save_raw_snapshot,
+     #             save_raw_records=save_raw_records,
+     #         ))
+
+     # @abstractmethod
+     # def _run_impl(self, *args, **kwargs):
+     #     raise NotImplementedError("Agents must implement _run_impl")
+
+     def _default_node_tags(
+         self, name: str, extra: Sequence[str] | None = None
+     ) -> list[str]:
+         tags = [self.name, "graph", name]
+         if extra:
+             tags.extend(extra)
+         return tags
+
+     def _as_runnable(self, fn: Any):
+         # If it's already runnable (has .with_config/.invoke), return it; else wrap
+         return (
+             fn
+             if hasattr(fn, "with_config") and hasattr(fn, "invoke")
+             else RunnableLambda(fn)
+         )
+
+     def _node_cfg(self, name: str, *extra_tags: str) -> dict:
+         """Build a consistent config for a node/runnable so we can reapply it after .map(), subgraph compile, etc."""
+         ns = extra_tags[0] if extra_tags else _to_snake(self.name)
+         tags = [self.name, "graph", name, *extra_tags]
+         return dict(
+             run_name="node",  # keep "node:" prefixing in the timer; don't fight Rich labels here
+             tags=tags,
+             metadata={
+                 "langgraph_node": name,
+                 "ursa_ns": ns,
+                 "ursa_agent": self.name,
+             },
+         )
+
+     def ns(self, runnable_or_fn, name: str, *extra_tags: str):
+         """Return a runnable with our node config applied. Safe to call on callables or runnables.
+         IMPORTANT: call this AGAIN after .map() / subgraph .compile() (they often drop config)."""
+         r = self._as_runnable(runnable_or_fn)
+         return r.with_config(**self._node_cfg(name, *extra_tags))
+
+     def _wrap_node(self, fn_or_runnable, name: str, *extra_tags: str):
+         return self.ns(fn_or_runnable, name, *extra_tags)
+
+     def _wrap_cond(self, fn: Any, name: str, *extra_tags: str):
+         ns = extra_tags[0] if extra_tags else _to_snake(self.name)
+         return RunnableLambda(fn).with_config(
+             run_name="node",
+             tags=[
+                 self.name,
+                 "graph",
+                 f"route:{name}",
+                 *extra_tags,
+             ],
+             metadata={
+                 "langgraph_node": f"route:{name}",
+                 "ursa_ns": ns,
+                 "ursa_agent": self.name,
+             },
+         )
+
+     def _named(self, runnable: Any, name: str, *extra_tags: str):
+         ns = extra_tags[0] if extra_tags else _to_snake(self.name)
+         return runnable.with_config(
+             run_name=name,
+             tags=[self.name, "graph", name, *extra_tags],
+             metadata={
+                 "langgraph_node": name,
+                 "ursa_ns": ns,
+                 "ursa_agent": self.name,
+             },
+         )
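`BaseAgent` is now an ABC: `invoke()` is `@final`, wraps telemetry and input normalization, and `__init_subclass__` rejects any subclass that tries to override it. A minimal conforming subclass, sketched from only what this diff shows (the agent name and inputs are illustrative):

```python
from typing import Any, Mapping

from ursa.agents.base import BaseAgent


class EchoAgent(BaseAgent):
    """Toy agent: the @final BaseAgent.invoke() handles telemetry and
    normalization, then delegates to _invoke()."""

    def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
        # A bare string arrives as {"messages": [HumanMessage(...)]}
        # courtesy of _normalize_inputs().
        return inputs


agent = EchoAgent(llm="openai/o3-mini", enable_metrics=False)
print(agent("hello"))  # __call__ forwards to the @final invoke()

# build_config() merges caller overrides into the telemetry-aware defaults:
cfg = agent.build_config(recursion_limit=50, tags=["demo"])
assert "demo" in cfg["tags"] and "EchoAgent" in cfg["tags"]

# Overriding invoke() fails at class-creation time:
# class Bad(BaseAgent):
#     def invoke(self, *a, **kw): ...  # TypeError from __init_subclass__
```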
@@ -270,7 +270,9 @@ def read_file(filename: str, state: Annotated[dict, InjectedState]):


  @tool
- def write_file(code: str, filename: str, state: Annotated[dict, InjectedState]):
+ def write_file(
+     code: str, filename: str, state: Annotated[dict, InjectedState]
+ ) -> str:
      """
      Writes text to a file in the given workspace as requested.
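
This final hunk only reflows the `write_file` signature and adds an explicit `-> str` return annotation. For reference, the injected-state tool pattern it follows looks roughly like this; the body is a hypothetical reconstruction, since only the signature and docstring appear in the diff:

```python
import os
from typing import Annotated

from langchain_core.tools import tool
from langgraph.prebuilt import InjectedState


@tool
def write_file(
    code: str, filename: str, state: Annotated[dict, InjectedState]
) -> str:
    """
    Writes text to a file in the given workspace as requested.
    """
    # Hypothetical body: the diff shows only the signature and docstring.
    path = os.path.join(state.get("workspace", "."), filename)
    with open(path, "w") as f:
        f.write(code)
    return f"Wrote {len(code)} characters to {path}"
```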