ursa-ai 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai has been flagged as potentially problematic; see the registry's advisory for this release for details.
- ursa/__init__.py +0 -0
- ursa/agents/arxiv_agent.py +77 -47
- ursa/agents/base.py +369 -2
- ursa/agents/code_review_agent.py +3 -1
- ursa/agents/execution_agent.py +92 -48
- ursa/agents/hypothesizer_agent.py +39 -42
- ursa/agents/lammps_agent.py +51 -29
- ursa/agents/mp_agent.py +45 -20
- ursa/agents/optimization_agent.py +405 -0
- ursa/agents/planning_agent.py +63 -28
- ursa/agents/rag_agent.py +75 -44
- ursa/agents/recall_agent.py +35 -5
- ursa/agents/websearch_agent.py +44 -54
- ursa/cli/__init__.py +127 -0
- ursa/cli/hitl.py +426 -0
- ursa/observability/pricing.py +319 -0
- ursa/observability/timing.py +1441 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/execution_prompts.py +7 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/write_code.py +1 -1
- ursa/util/__init__.py +0 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +1 -1
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0.dist-info}/METADATA +123 -4
- ursa_ai-0.6.0.dist-info/RECORD +43 -0
- ursa_ai-0.6.0.dist-info/entry_points.txt +2 -0
- ursa_ai-0.5.0.dist-info/RECORD +0 -28
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0.dist-info}/WHEEL +0 -0
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0.dist-info}/top_level.txt +0 -0
ursa/__init__.py
ADDED
|
File without changes
|
ursa/agents/arxiv_agent.py
CHANGED
|
@@ -3,12 +3,14 @@ import os
|
|
|
3
3
|
import re
|
|
4
4
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
5
5
|
from io import BytesIO
|
|
6
|
+
from typing import Any, Mapping
|
|
6
7
|
from urllib.parse import quote
|
|
7
8
|
|
|
8
9
|
import feedparser
|
|
9
10
|
import pymupdf
|
|
10
11
|
import requests
|
|
11
12
|
from langchain_community.document_loaders import PyPDFLoader
|
|
13
|
+
from langchain_core.language_models import BaseChatModel
|
|
12
14
|
from langchain_core.output_parsers import StrOutputParser
|
|
13
15
|
from langchain_core.prompts import ChatPromptTemplate
|
|
14
16
|
from langgraph.graph import StateGraph
|
|
@@ -16,8 +18,8 @@ from PIL import Image
|
|
|
16
18
|
from tqdm import tqdm
|
|
17
19
|
from typing_extensions import List, TypedDict
|
|
18
20
|
|
|
19
|
-
from .base import BaseAgent
|
|
20
|
-
from .rag_agent import RAGAgent
|
|
21
|
+
from ursa.agents.base import BaseAgent
|
|
22
|
+
from ursa.agents.rag_agent import RAGAgent
|
|
21
23
|
|
|
22
24
|
try:
|
|
23
25
|
from openai import OpenAI
|
|
@@ -120,7 +122,7 @@ def remove_surrogates(text: str) -> str:
|
|
|
120
122
|
class ArxivAgent(BaseAgent):
|
|
121
123
|
def __init__(
|
|
122
124
|
self,
|
|
123
|
-
llm="openai/o3-mini",
|
|
125
|
+
llm: str | BaseChatModel = "openai/o3-mini",
|
|
124
126
|
summarize: bool = True,
|
|
125
127
|
process_images=True,
|
|
126
128
|
max_results: int = 3,
|
|
@@ -141,7 +143,7 @@ class ArxivAgent(BaseAgent):
|
|
|
141
143
|
self.download_papers = download_papers
|
|
142
144
|
self.rag_embedding = rag_embedding
|
|
143
145
|
|
|
144
|
-
self.
|
|
146
|
+
self._action = self._build_graph()
|
|
145
147
|
|
|
146
148
|
os.makedirs(self.database_path, exist_ok=True)
|
|
147
149
|
|
|
@@ -259,10 +261,13 @@ class ArxivAgent(BaseAgent):
|
|
|
259
261
|
|
|
260
262
|
try:
|
|
261
263
|
cleaned_text = remove_surrogates(paper["full_text"])
|
|
262
|
-
summary = chain.invoke(
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
264
|
+
summary = chain.invoke(
|
|
265
|
+
{
|
|
266
|
+
"retrieved_content": cleaned_text,
|
|
267
|
+
"context": state["context"],
|
|
268
|
+
},
|
|
269
|
+
config=self.build_config(tags=["arxiv", "summarize_each"]),
|
|
270
|
+
)
|
|
266
271
|
|
|
267
272
|
except Exception as e:
|
|
268
273
|
summary = f"Error summarizing paper: {e}"
|
|
@@ -304,7 +309,9 @@ class ArxivAgent(BaseAgent):
|
|
|
304
309
|
embedding=self.rag_embedding,
|
|
305
310
|
database_path=self.database_path,
|
|
306
311
|
)
|
|
307
|
-
new_state["final_summary"] = rag_agent.
|
|
312
|
+
new_state["final_summary"] = rag_agent.invoke(context=state["context"])[
|
|
313
|
+
"summary"
|
|
314
|
+
]
|
|
308
315
|
return new_state
|
|
309
316
|
|
|
310
317
|
def _aggregate_node(self, state: PaperState) -> PaperState:
|
|
@@ -341,10 +348,13 @@ class ArxivAgent(BaseAgent):
|
|
|
341
348
|
|
|
342
349
|
chain = prompt | self.llm | StrOutputParser()
|
|
343
350
|
|
|
344
|
-
final_summary = chain.invoke(
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
351
|
+
final_summary = chain.invoke(
|
|
352
|
+
{
|
|
353
|
+
"Summaries": combined,
|
|
354
|
+
"context": state["context"],
|
|
355
|
+
},
|
|
356
|
+
config=self.build_config(tags=["arxiv", "aggregate"]),
|
|
357
|
+
)
|
|
348
358
|
|
|
349
359
|
with open(self.summaries_path + "/final_summary.txt", "w") as f:
|
|
350
360
|
f.write(final_summary)
|
|
@@ -352,49 +362,69 @@ class ArxivAgent(BaseAgent):
|
|
|
352
362
|
return {**state, "final_summary": final_summary}
|
|
353
363
|
|
|
354
364
|
def _build_graph(self):
|
|
355
|
-
|
|
356
|
-
builder.add_node("fetch_papers", self._fetch_node)
|
|
365
|
+
graph = StateGraph(PaperState)
|
|
357
366
|
|
|
367
|
+
self.add_node(graph, self._fetch_node)
|
|
358
368
|
if self.summarize:
|
|
359
369
|
if self.rag_embedding:
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
builder.set_finish_point("rag_summarize")
|
|
370
|
+
self.add_node(graph, self._rag_node)
|
|
371
|
+
graph.set_entry_point("_fetch_node")
|
|
372
|
+
graph.add_edge("_fetch_node", "_rag_node")
|
|
373
|
+
graph.set_finish_point("_rag_node")
|
|
365
374
|
else:
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
builder.set_entry_point("fetch_papers")
|
|
370
|
-
builder.add_edge("fetch_papers", "summarize_each")
|
|
371
|
-
builder.add_edge("summarize_each", "aggregate")
|
|
372
|
-
builder.set_finish_point("aggregate")
|
|
375
|
+
self.add_node(graph, self._summarize_node)
|
|
376
|
+
self.add_node(graph, self._aggregate_node)
|
|
373
377
|
|
|
378
|
+
graph.set_entry_point("_fetch_node")
|
|
379
|
+
graph.add_edge("_fetch_node", "_summarize_node")
|
|
380
|
+
graph.add_edge("_summarize_node", "_aggregate_node")
|
|
381
|
+
graph.set_finish_point("_aggregate_node")
|
|
374
382
|
else:
|
|
375
|
-
|
|
376
|
-
|
|
383
|
+
graph.set_entry_point("_fetch_node")
|
|
384
|
+
graph.set_finish_point("_fetch_node")
|
|
377
385
|
|
|
378
|
-
graph
|
|
379
|
-
return graph
|
|
386
|
+
return graph.compile(checkpointer=self.checkpointer)
|
|
380
387
|
|
|
381
|
-
def
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
388
|
+
def _invoke(
|
|
389
|
+
self,
|
|
390
|
+
inputs: Mapping[str, Any],
|
|
391
|
+
*,
|
|
392
|
+
summarize: bool | None = None,
|
|
393
|
+
recursion_limit: int = 1000,
|
|
394
|
+
**_,
|
|
395
|
+
) -> str:
|
|
396
|
+
config = self.build_config(
|
|
397
|
+
recursion_limit=recursion_limit, tags=["graph"]
|
|
398
|
+
)
|
|
386
399
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
400
|
+
# this seems dumb, but it's b/c sometimes we had referred to the value as
|
|
401
|
+
# 'query' other times as 'arxiv_search_query' so trying to keep it compatible
|
|
402
|
+
# aliasing: accept arxiv_search_query -> query
|
|
403
|
+
if "query" not in inputs:
|
|
404
|
+
if "arxiv_search_query" in inputs:
|
|
405
|
+
# make a shallow copy and rename the key
|
|
406
|
+
inputs = dict(inputs)
|
|
407
|
+
inputs["query"] = inputs.pop("arxiv_search_query")
|
|
408
|
+
else:
|
|
409
|
+
raise KeyError(
|
|
410
|
+
"Missing 'query' in inputs (alias 'arxiv_search_query' also accepted)."
|
|
411
|
+
)
|
|
391
412
|
|
|
413
|
+
result = self._action.invoke(inputs, config)
|
|
414
|
+
|
|
415
|
+
use_summary = self.summarize if summarize is None else summarize
|
|
416
|
+
|
|
417
|
+
return (
|
|
418
|
+
result.get("final_summary", "No summary generated.")
|
|
419
|
+
if use_summary
|
|
420
|
+
else "\n\nFinished Fetching papers!"
|
|
421
|
+
)
|
|
392
422
|
|
|
393
|
-
if __name__ == "__main__":
|
|
394
|
-
agent = ArxivAgent()
|
|
395
|
-
result = agent.run(
|
|
396
|
-
arxiv_search_query="Experimental Constraints on neutron star radius",
|
|
397
|
-
context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?",
|
|
398
|
-
)
|
|
399
423
|
|
|
400
|
-
|
|
424
|
+
# NOTE: Run test in `tests/agents/test_arxiv_agent/test_arxiv_agent.py` via:
|
|
425
|
+
#
|
|
426
|
+
# pytest -s tests/agents/test_arxiv_agent
|
|
427
|
+
#
|
|
428
|
+
# OR
|
|
429
|
+
#
|
|
430
|
+
# uv run pytest -s tests/agents/test_arxiv_agent
|
ursa/agents/base.py
CHANGED
|
@@ -1,10 +1,47 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from contextvars import ContextVar
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
Callable,
|
|
7
|
+
Iterator,
|
|
8
|
+
Mapping,
|
|
9
|
+
Optional,
|
|
10
|
+
Sequence,
|
|
11
|
+
Union,
|
|
12
|
+
final,
|
|
13
|
+
)
|
|
14
|
+
from uuid import uuid4
|
|
15
|
+
|
|
1
16
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
2
17
|
from langchain_core.load import dumps
|
|
18
|
+
from langchain_core.runnables import (
|
|
19
|
+
RunnableLambda,
|
|
20
|
+
)
|
|
3
21
|
from langchain_litellm import ChatLiteLLM
|
|
4
22
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
23
|
+
from langgraph.graph import StateGraph
|
|
24
|
+
|
|
25
|
+
from ursa.observability.timing import (
|
|
26
|
+
Telemetry, # for timing / telemetry / metrics
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
InputLike = Union[str, Mapping[str, Any]]
|
|
30
|
+
_INVOKE_DEPTH = ContextVar("_INVOKE_DEPTH", default=0)
|
|
5
31
|
|
|
6
32
|
|
|
7
|
-
|
|
33
|
+
def _to_snake(s: str) -> str:
|
|
34
|
+
s = re.sub(
|
|
35
|
+
r"^([A-Z]{2,})([A-Z][a-z])",
|
|
36
|
+
lambda m: m.group(1)[0] + m.group(1)[1:].lower() + m.group(2),
|
|
37
|
+
str(s),
|
|
38
|
+
) # RAGAgent -> RagAgent
|
|
39
|
+
s = re.sub(r"(?<!^)(?=[A-Z])", "_", s) # CamelCase -> snake_case
|
|
40
|
+
s = s.replace("-", "_").replace(" ", "_")
|
|
41
|
+
return s.lower()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class BaseAgent(ABC):
|
|
8
45
|
# llm: BaseChatModel
|
|
9
46
|
# llm_with_tools: Runnable[LanguageModelInput, BaseMessage]
|
|
10
47
|
|
|
@@ -12,6 +49,10 @@ class BaseAgent:
|
|
|
12
49
|
self,
|
|
13
50
|
llm: str | BaseChatModel,
|
|
14
51
|
checkpointer: BaseCheckpointSaver = None,
|
|
52
|
+
enable_metrics: bool = False, # default to enabling metrics
|
|
53
|
+
metrics_dir: str = ".ursa_metrics", # dir to save metrics, with a default
|
|
54
|
+
autosave_metrics: bool = True,
|
|
55
|
+
thread_id: Optional[str] = None,
|
|
15
56
|
**kwargs,
|
|
16
57
|
):
|
|
17
58
|
match llm:
|
|
@@ -32,10 +73,336 @@ class BaseAgent:
|
|
|
32
73
|
"llm argument must be a string with the provider and model, or a BaseChatModel instance."
|
|
33
74
|
)
|
|
34
75
|
|
|
76
|
+
self.thread_id = thread_id or uuid4().hex
|
|
35
77
|
self.checkpointer = checkpointer
|
|
36
|
-
self.
|
|
78
|
+
self.telemetry = Telemetry(
|
|
79
|
+
enable=enable_metrics,
|
|
80
|
+
output_dir=metrics_dir,
|
|
81
|
+
save_json_default=autosave_metrics,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
@property
|
|
85
|
+
def name(self) -> str:
|
|
86
|
+
"""Agent name."""
|
|
87
|
+
return self.__class__.__name__
|
|
88
|
+
|
|
89
|
+
def add_node(
|
|
90
|
+
self,
|
|
91
|
+
graph: StateGraph,
|
|
92
|
+
f: Callable[..., Mapping[str, Any]],
|
|
93
|
+
node_name: Optional[str] = None,
|
|
94
|
+
agent_name: Optional[str] = None,
|
|
95
|
+
) -> StateGraph:
|
|
96
|
+
"""Add node to graph.
|
|
97
|
+
|
|
98
|
+
This is used to track token usage and is simply the following.
|
|
99
|
+
|
|
100
|
+
```python
|
|
101
|
+
_node_name = node_name or f.__name__
|
|
102
|
+
return graph.add_node(
|
|
103
|
+
_node_name, self._wrap_node(f, _node_name, self.name)
|
|
104
|
+
)
|
|
105
|
+
```
|
|
106
|
+
"""
|
|
107
|
+
_node_name = node_name or f.__name__
|
|
108
|
+
_agent_name = agent_name or _to_snake(self.name)
|
|
109
|
+
wrapped_node = self._wrap_node(f, _node_name, _agent_name)
|
|
110
|
+
return graph.add_node(_node_name, wrapped_node)
|
|
37
111
|
|
|
38
112
|
def write_state(self, filename, state):
|
|
39
113
|
json_state = dumps(state, ensure_ascii=False)
|
|
40
114
|
with open(filename, "w") as f:
|
|
41
115
|
f.write(json_state)
|
|
116
|
+
|
|
117
|
+
# BaseAgent
|
|
118
|
+
def build_config(self, **overrides) -> dict:
|
|
119
|
+
"""
|
|
120
|
+
Build a config dict that includes telemetry callbacks and the thread_id.
|
|
121
|
+
You can pass overrides like recursion_limit=..., configurable={...}, etc.
|
|
122
|
+
"""
|
|
123
|
+
base = {
|
|
124
|
+
"configurable": {"thread_id": self.thread_id},
|
|
125
|
+
"metadata": {
|
|
126
|
+
"thread_id": self.thread_id,
|
|
127
|
+
"telemetry_run_id": self.telemetry.context.get("run_id"),
|
|
128
|
+
},
|
|
129
|
+
# "configurable": {
|
|
130
|
+
# "thread_id": getattr(self, "thread_id", "default")
|
|
131
|
+
# },
|
|
132
|
+
# "metadata": {
|
|
133
|
+
# "thread_id": getattr(self, "thread_id", "default"),
|
|
134
|
+
# "telemetry_run_id": self.telemetry.context.get("run_id"),
|
|
135
|
+
# },
|
|
136
|
+
"tags": [self.name],
|
|
137
|
+
"callbacks": self.telemetry.callbacks,
|
|
138
|
+
}
|
|
139
|
+
# include model name when we can
|
|
140
|
+
model_name = getattr(self, "llm_model", None) or getattr(
|
|
141
|
+
getattr(self, "llm", None), "model", None
|
|
142
|
+
)
|
|
143
|
+
if model_name:
|
|
144
|
+
base["metadata"]["model"] = model_name
|
|
145
|
+
|
|
146
|
+
if "configurable" in overrides and isinstance(
|
|
147
|
+
overrides["configurable"], dict
|
|
148
|
+
):
|
|
149
|
+
base["configurable"].update(overrides.pop("configurable"))
|
|
150
|
+
if "metadata" in overrides and isinstance(overrides["metadata"], dict):
|
|
151
|
+
base["metadata"].update(overrides.pop("metadata"))
|
|
152
|
+
# merge tags if caller provides them
|
|
153
|
+
if "tags" in overrides and isinstance(overrides["tags"], list):
|
|
154
|
+
base["tags"] = base["tags"] + [
|
|
155
|
+
t for t in overrides.pop("tags") if t not in base["tags"]
|
|
156
|
+
]
|
|
157
|
+
base.update(overrides)
|
|
158
|
+
return base
|
|
159
|
+
|
|
160
|
+
# agents will invoke like this:
|
|
161
|
+
# planning_output = planner.invoke(
|
|
162
|
+
# {"messages": [HumanMessage(content=problem)]},
|
|
163
|
+
# config={
|
|
164
|
+
# "recursion_limit": 999_999,
|
|
165
|
+
# "configurable": {"thread_id": planner.thread_id},
|
|
166
|
+
# },
|
|
167
|
+
# )
|
|
168
|
+
# they can also, separately, override these defaults about metrics
|
|
169
|
+
# keys that are NOT inputs; they should not be folded into the inputs mapping
|
|
170
|
+
_TELEMETRY_KW = {
|
|
171
|
+
"raw_debug",
|
|
172
|
+
"save_json",
|
|
173
|
+
"metrics_path",
|
|
174
|
+
"save_raw_snapshot",
|
|
175
|
+
"save_raw_records",
|
|
176
|
+
}
|
|
177
|
+
_CONTROL_KW = {"config", "recursion_limit", "tags", "metadata", "callbacks"}
|
|
178
|
+
|
|
179
|
+
@final
|
|
180
|
+
def invoke(
|
|
181
|
+
self,
|
|
182
|
+
inputs: Optional[InputLike] = None, # sentinel
|
|
183
|
+
/,
|
|
184
|
+
*,
|
|
185
|
+
raw_debug: bool = False,
|
|
186
|
+
save_json: Optional[bool] = None,
|
|
187
|
+
metrics_path: Optional[str] = None,
|
|
188
|
+
save_raw_snapshot: Optional[bool] = None,
|
|
189
|
+
save_raw_records: Optional[bool] = None,
|
|
190
|
+
config: Optional[dict] = None,
|
|
191
|
+
**kwargs: Any, # may contain inputs (keyword-inputs) and/or control kw
|
|
192
|
+
) -> Any:
|
|
193
|
+
depth = _INVOKE_DEPTH.get()
|
|
194
|
+
_INVOKE_DEPTH.set(depth + 1)
|
|
195
|
+
try:
|
|
196
|
+
if depth == 0:
|
|
197
|
+
self.telemetry.begin_run(
|
|
198
|
+
agent=self.name, thread_id=self.thread_id
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
# If no positional inputs were provided, split kwargs into inputs vs control
|
|
202
|
+
if inputs is None:
|
|
203
|
+
kw_inputs: dict[str, Any] = {}
|
|
204
|
+
control_kwargs: dict[str, Any] = {}
|
|
205
|
+
for k, v in kwargs.items():
|
|
206
|
+
if k in self._TELEMETRY_KW or k in self._CONTROL_KW:
|
|
207
|
+
control_kwargs[k] = v
|
|
208
|
+
else:
|
|
209
|
+
kw_inputs[k] = v
|
|
210
|
+
inputs = kw_inputs
|
|
211
|
+
kwargs = control_kwargs # only control kwargs remain
|
|
212
|
+
|
|
213
|
+
# If both positional inputs and extra unknown kwargs-as-inputs are given, forbid merging
|
|
214
|
+
else:
|
|
215
|
+
# keep only control kwargs; anything else would be ambiguous
|
|
216
|
+
for k in kwargs.keys():
|
|
217
|
+
if not (k in self._TELEMETRY_KW or k in self._CONTROL_KW):
|
|
218
|
+
raise TypeError(
|
|
219
|
+
f"Unexpected keyword argument '{k}'. "
|
|
220
|
+
"Pass inputs as a single mapping or omit the positional "
|
|
221
|
+
"inputs and pass them as keyword arguments."
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
# subclasses may translate keys
|
|
225
|
+
normalized = self._normalize_inputs(inputs)
|
|
226
|
+
|
|
227
|
+
# forward config + any control kwargs (e.g., recursion_limit) to the agent
|
|
228
|
+
return self._invoke(normalized, config=config, **kwargs)
|
|
229
|
+
|
|
230
|
+
finally:
|
|
231
|
+
new_depth = _INVOKE_DEPTH.get() - 1
|
|
232
|
+
_INVOKE_DEPTH.set(new_depth)
|
|
233
|
+
if new_depth == 0:
|
|
234
|
+
self.telemetry.render(
|
|
235
|
+
raw=raw_debug,
|
|
236
|
+
save_json=save_json,
|
|
237
|
+
filepath=metrics_path,
|
|
238
|
+
save_raw_snapshot=save_raw_snapshot,
|
|
239
|
+
save_raw_records=save_raw_records,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
def _normalize_inputs(self, inputs: InputLike) -> Mapping[str, Any]:
|
|
243
|
+
if isinstance(inputs, str):
|
|
244
|
+
# Adjust to your message type
|
|
245
|
+
from langchain_core.messages import HumanMessage
|
|
246
|
+
|
|
247
|
+
return {"messages": [HumanMessage(content=inputs)]}
|
|
248
|
+
if isinstance(inputs, Mapping):
|
|
249
|
+
return inputs
|
|
250
|
+
raise TypeError(f"Unsupported input type: {type(inputs)}")
|
|
251
|
+
|
|
252
|
+
@abstractmethod
|
|
253
|
+
def _invoke(self, inputs: Mapping[str, Any], **config: Any) -> Any:
|
|
254
|
+
"""Subclasses implement the actual work against normalized inputs."""
|
|
255
|
+
...
|
|
256
|
+
|
|
257
|
+
def __call__(self, inputs: InputLike, /, **kwargs: Any) -> Any:
|
|
258
|
+
return self.invoke(inputs, **kwargs)
|
|
259
|
+
|
|
260
|
+
# Runtime enforcement: forbid subclasses from overriding invoke
|
|
261
|
+
def __init_subclass__(cls, **kwargs):
|
|
262
|
+
super().__init_subclass__(**kwargs)
|
|
263
|
+
if "invoke" in cls.__dict__:
|
|
264
|
+
raise TypeError(
|
|
265
|
+
f"{cls.__name__} must not override BaseAgent.invoke(); implement _invoke() only."
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
def stream(
|
|
269
|
+
self,
|
|
270
|
+
inputs: InputLike,
|
|
271
|
+
config: Any | None = None, # allow positional/keyword like LangGraph
|
|
272
|
+
/,
|
|
273
|
+
*,
|
|
274
|
+
raw_debug: bool = False,
|
|
275
|
+
save_json: bool | None = None,
|
|
276
|
+
metrics_path: str | None = None,
|
|
277
|
+
save_raw_snapshot: bool | None = None,
|
|
278
|
+
save_raw_records: bool | None = None,
|
|
279
|
+
**kwargs: Any,
|
|
280
|
+
) -> Iterator[Any]:
|
|
281
|
+
"""Public streaming entry point. Telemetry-wrapped."""
|
|
282
|
+
depth = _INVOKE_DEPTH.get()
|
|
283
|
+
_INVOKE_DEPTH.set(depth + 1)
|
|
284
|
+
try:
|
|
285
|
+
if depth == 0:
|
|
286
|
+
self.telemetry.begin_run(
|
|
287
|
+
agent=self.name, thread_id=self.thread_id
|
|
288
|
+
)
|
|
289
|
+
normalized = self._normalize_inputs(inputs)
|
|
290
|
+
yield from self._stream(normalized, config=config, **kwargs)
|
|
291
|
+
finally:
|
|
292
|
+
new_depth = _INVOKE_DEPTH.get() - 1
|
|
293
|
+
_INVOKE_DEPTH.set(new_depth)
|
|
294
|
+
if new_depth == 0:
|
|
295
|
+
self.telemetry.render(
|
|
296
|
+
raw=raw_debug,
|
|
297
|
+
save_json=save_json,
|
|
298
|
+
filepath=metrics_path,
|
|
299
|
+
save_raw_snapshot=save_raw_snapshot,
|
|
300
|
+
save_raw_records=save_raw_records,
|
|
301
|
+
)
|
|
302
|
+
|
|
303
|
+
def _stream(
|
|
304
|
+
self,
|
|
305
|
+
inputs: Mapping[str, Any],
|
|
306
|
+
*,
|
|
307
|
+
config: Any | None = None,
|
|
308
|
+
**kwargs: Any,
|
|
309
|
+
) -> Iterator[Any]:
|
|
310
|
+
raise NotImplementedError(
|
|
311
|
+
f"{self.name} does not support streaming. "
|
|
312
|
+
"Override _stream(...) in your agent to enable it."
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
# def run(
|
|
316
|
+
# self,
|
|
317
|
+
# *args,
|
|
318
|
+
# raw_debug: bool = False,
|
|
319
|
+
# save_json: bool | None = None,
|
|
320
|
+
# metrics_path: str | None = None,
|
|
321
|
+
# save_raw_snapshot: bool | None = None,
|
|
322
|
+
# save_raw_records: bool | None = None,
|
|
323
|
+
# **kwargs
|
|
324
|
+
# ):
|
|
325
|
+
# try:
|
|
326
|
+
# self.telemetry.begin_run(agent=self.name, thread_id=self.thread_id)
|
|
327
|
+
# result = self._run_impl(*args, **kwargs)
|
|
328
|
+
# return result
|
|
329
|
+
# finally:
|
|
330
|
+
# print(self.telemetry.render(
|
|
331
|
+
# raw=raw_debug,
|
|
332
|
+
# save_json=save_json,
|
|
333
|
+
# filepath=metrics_path,
|
|
334
|
+
# save_raw_snapshot=save_raw_snapshot,
|
|
335
|
+
# save_raw_records=save_raw_records,
|
|
336
|
+
# ))
|
|
337
|
+
|
|
338
|
+
# @abstractmethod
|
|
339
|
+
# def _run_impl(self, *args, **kwargs):
|
|
340
|
+
# raise NotImplementedError("Agents must implement _run_impl")
|
|
341
|
+
|
|
342
|
+
def _default_node_tags(
|
|
343
|
+
self, name: str, extra: Sequence[str] | None = None
|
|
344
|
+
) -> list[str]:
|
|
345
|
+
tags = [self.name, "graph", name]
|
|
346
|
+
if extra:
|
|
347
|
+
tags.extend(extra)
|
|
348
|
+
return tags
|
|
349
|
+
|
|
350
|
+
def _as_runnable(self, fn: Any):
|
|
351
|
+
# If it's already runnable (has .with_config/.invoke), return it; else wrap
|
|
352
|
+
return (
|
|
353
|
+
fn
|
|
354
|
+
if hasattr(fn, "with_config") and hasattr(fn, "invoke")
|
|
355
|
+
else RunnableLambda(fn)
|
|
356
|
+
)
|
|
357
|
+
|
|
358
|
+
def _node_cfg(self, name: str, *extra_tags: str) -> dict:
|
|
359
|
+
"""Build a consistent config for a node/runnable so we can reapply it after .map(), subgraph compile, etc."""
|
|
360
|
+
ns = extra_tags[0] if extra_tags else _to_snake(self.name)
|
|
361
|
+
tags = [self.name, "graph", name, *extra_tags]
|
|
362
|
+
return dict(
|
|
363
|
+
run_name="node", # keep "node:" prefixing in the timer; don't fight Rich labels here
|
|
364
|
+
tags=tags,
|
|
365
|
+
metadata={
|
|
366
|
+
"langgraph_node": name,
|
|
367
|
+
"ursa_ns": ns,
|
|
368
|
+
"ursa_agent": self.name,
|
|
369
|
+
},
|
|
370
|
+
)
|
|
371
|
+
|
|
372
|
+
def ns(self, runnable_or_fn, name: str, *extra_tags: str):
|
|
373
|
+
"""Return a runnable with our node config applied. Safe to call on callables or runnables.
|
|
374
|
+
IMPORTANT: call this AGAIN after .map() / subgraph .compile() (they often drop config)."""
|
|
375
|
+
r = self._as_runnable(runnable_or_fn)
|
|
376
|
+
return r.with_config(**self._node_cfg(name, *extra_tags))
|
|
377
|
+
|
|
378
|
+
def _wrap_node(self, fn_or_runnable, name: str, *extra_tags: str):
|
|
379
|
+
return self.ns(fn_or_runnable, name, *extra_tags)
|
|
380
|
+
|
|
381
|
+
def _wrap_cond(self, fn: Any, name: str, *extra_tags: str):
|
|
382
|
+
ns = extra_tags[0] if extra_tags else _to_snake(self.name)
|
|
383
|
+
return RunnableLambda(fn).with_config(
|
|
384
|
+
run_name="node",
|
|
385
|
+
tags=[
|
|
386
|
+
self.name,
|
|
387
|
+
"graph",
|
|
388
|
+
f"route:{name}",
|
|
389
|
+
*extra_tags,
|
|
390
|
+
],
|
|
391
|
+
metadata={
|
|
392
|
+
"langgraph_node": f"route:{name}",
|
|
393
|
+
"ursa_ns": ns,
|
|
394
|
+
"ursa_agent": self.name,
|
|
395
|
+
},
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
def _named(self, runnable: Any, name: str, *extra_tags: str):
|
|
399
|
+
ns = extra_tags[0] if extra_tags else _to_snake(self.name)
|
|
400
|
+
return runnable.with_config(
|
|
401
|
+
run_name=name,
|
|
402
|
+
tags=[self.name, "graph", name, *extra_tags],
|
|
403
|
+
metadata={
|
|
404
|
+
"langgraph_node": name,
|
|
405
|
+
"ursa_ns": ns,
|
|
406
|
+
"ursa_agent": self.name,
|
|
407
|
+
},
|
|
408
|
+
)
|
ursa/agents/code_review_agent.py
CHANGED
|
@@ -270,7 +270,9 @@ def read_file(filename: str, state: Annotated[dict, InjectedState]):
|
|
|
270
270
|
|
|
271
271
|
|
|
272
272
|
@tool
|
|
273
|
-
def write_file(
|
|
273
|
+
def write_file(
|
|
274
|
+
code: str, filename: str, state: Annotated[dict, InjectedState]
|
|
275
|
+
) -> str:
|
|
274
276
|
"""
|
|
275
277
|
Writes text to a file in the given workspace as requested.
|
|
276
278
|
|