ursa-ai 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
ursa/cli/hitl.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import platform
|
|
3
|
+
import sqlite3
|
|
4
|
+
from cmd import Cmd
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from functools import cached_property
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Callable, Optional
|
|
9
|
+
|
|
10
|
+
import httpx
|
|
11
|
+
from langchain.chat_models import init_chat_model
|
|
12
|
+
from langchain.embeddings import init_embeddings
|
|
13
|
+
from langchain_core.messages import HumanMessage
|
|
14
|
+
from langgraph.checkpoint.sqlite import SqliteSaver
|
|
15
|
+
from rich.console import Console
|
|
16
|
+
from rich.markdown import Markdown
|
|
17
|
+
from rich.theme import Theme
|
|
18
|
+
from typer import Typer
|
|
19
|
+
|
|
20
|
+
from ursa.agents import (
|
|
21
|
+
ArxivAgent,
|
|
22
|
+
ChatAgent,
|
|
23
|
+
ExecutionAgent,
|
|
24
|
+
HypothesizerAgent,
|
|
25
|
+
PlanningAgent,
|
|
26
|
+
RecallAgent,
|
|
27
|
+
WebSearchAgent,
|
|
28
|
+
)
|
|
29
|
+
from ursa.util.memory_logger import AgentMemory
|
|
30
|
+
|
|
31
|
+
# Typer application object for this module.
# NOTE(review): no commands are registered on it in this file; presumably the
# CLI entry point elsewhere in ursa.cli attaches them — confirm.
app = Typer()

# ASCII-art "ursa" logo; UrsaRepl.run prints it once (in magenta) at start-up.
ursa_banner = r"""
__ ________________ _
/ / / / ___/ ___/ __ `/
/ /_/ / / (__ ) /_/ /
\__,_/_/ /____/\__,_/
"""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def make_console():
    """Create the rich Console used by the REPL.

    The theme registers named styles (``success``, ``error``, ``dim``,
    ``warn``, ``emph``) so rich markup such as ``[dim]...`` or
    ``[emph]...[/]`` can refer to them by name.
    """
    named_styles = dict(
        success="green",
        error="bold red",
        dim="grey50",
        warn="yellow",
        emph="bold cyan",
    )
    repl_theme = Theme(named_styles)
    return Console(theme=repl_theme)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
|
|
54
|
+
class HITL:
|
|
55
|
+
workspace: Path
|
|
56
|
+
llm_model_name: str
|
|
57
|
+
llm_base_url: Optional[str]
|
|
58
|
+
llm_api_key: Optional[str]
|
|
59
|
+
max_completion_tokens: int
|
|
60
|
+
emb_model_name: Optional[str]
|
|
61
|
+
emb_base_url: Optional[str]
|
|
62
|
+
emb_api_key: Optional[str]
|
|
63
|
+
share_key: bool
|
|
64
|
+
thread_id: str
|
|
65
|
+
safe_codes: list[str]
|
|
66
|
+
arxiv_summarize: bool
|
|
67
|
+
arxiv_process_images: bool
|
|
68
|
+
arxiv_max_results: int
|
|
69
|
+
arxiv_database_path: Optional[Path]
|
|
70
|
+
arxiv_summaries_path: Optional[Path]
|
|
71
|
+
arxiv_vectorstore_path: Optional[Path]
|
|
72
|
+
arxiv_download_papers: bool
|
|
73
|
+
ssl_verify_llm: bool
|
|
74
|
+
ssl_verify_emb: bool
|
|
75
|
+
|
|
76
|
+
def _make_kwargs(self, **kwargs):
|
|
77
|
+
# NOTE: This is required instead of setting to None because of
|
|
78
|
+
# strangeness in init_chat_model.
|
|
79
|
+
return {
|
|
80
|
+
key: value for key, value in kwargs.items() if value is not None
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
def get_path(self, path: Optional[Path], default_subdir: str) -> str:
|
|
84
|
+
if path is None:
|
|
85
|
+
return str(self.workspace / default_subdir)
|
|
86
|
+
return str(path)
|
|
87
|
+
|
|
88
|
+
def __post_init__(self):
|
|
89
|
+
self.workspace.mkdir(parents=True, exist_ok=True)
|
|
90
|
+
|
|
91
|
+
# Specify API key only once and share for llm and embedder.
|
|
92
|
+
if self.share_key:
|
|
93
|
+
match self.llm_api_key, self.emb_api_key:
|
|
94
|
+
case None, None:
|
|
95
|
+
raise ValueError(
|
|
96
|
+
"When sharing API keys, both llm_api_key and emb_api_key "
|
|
97
|
+
"cannot be empty!"
|
|
98
|
+
)
|
|
99
|
+
case str(), str():
|
|
100
|
+
raise ValueError(
|
|
101
|
+
"When sharing API keys, do not supply both llm_api_key and "
|
|
102
|
+
"emb_api_key."
|
|
103
|
+
)
|
|
104
|
+
case None, str():
|
|
105
|
+
self.llm_api_key = self.emb_api_key
|
|
106
|
+
case str(), None:
|
|
107
|
+
self.emb_api_key = self.llm_api_key
|
|
108
|
+
|
|
109
|
+
self.model = init_chat_model(
|
|
110
|
+
model=self.llm_model_name,
|
|
111
|
+
max_completion_tokens=self.max_completion_tokens,
|
|
112
|
+
**self._make_kwargs(
|
|
113
|
+
http_client=None
|
|
114
|
+
if self.ssl_verify_llm
|
|
115
|
+
else httpx.Client(verify=False),
|
|
116
|
+
base_url=self.llm_base_url,
|
|
117
|
+
api_key=self.llm_api_key,
|
|
118
|
+
),
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
self.embedding = (
|
|
122
|
+
init_embeddings(
|
|
123
|
+
model=self.emb_model_name,
|
|
124
|
+
**self._make_kwargs(
|
|
125
|
+
http_client=None
|
|
126
|
+
if self.ssl_verify_emb
|
|
127
|
+
else httpx.Client(verify=False),
|
|
128
|
+
base_url=self.emb_base_url,
|
|
129
|
+
api_key=self.emb_api_key,
|
|
130
|
+
),
|
|
131
|
+
)
|
|
132
|
+
if self.emb_model_name
|
|
133
|
+
else None
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
self.memory = (
|
|
137
|
+
AgentMemory(
|
|
138
|
+
embedding_model=self.embedding,
|
|
139
|
+
path=str(self.workspace / "memory"),
|
|
140
|
+
)
|
|
141
|
+
if self.embedding
|
|
142
|
+
else None
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
self.last_agent_result = ""
|
|
146
|
+
self.arxiv_state = []
|
|
147
|
+
self.chatter_state = {"messages": []}
|
|
148
|
+
self.executor_state = {}
|
|
149
|
+
self.hypothesizer_state = {}
|
|
150
|
+
self.planner_state = {}
|
|
151
|
+
self.websearcher_state = []
|
|
152
|
+
|
|
153
|
+
def update_last_agent_result(self, result: str):
|
|
154
|
+
self.last_agent_result = result
|
|
155
|
+
|
|
156
|
+
@cached_property
|
|
157
|
+
def arxiv_agent(self) -> ArxivAgent:
|
|
158
|
+
return ArxivAgent(
|
|
159
|
+
llm=self.model,
|
|
160
|
+
summarize=self.arxiv_summarize,
|
|
161
|
+
process_images=self.arxiv_process_images,
|
|
162
|
+
max_results=self.arxiv_max_results,
|
|
163
|
+
# rag_embedding=self.embedding,
|
|
164
|
+
database_path=self.get_path(
|
|
165
|
+
self.arxiv_database_path, "arxiv_downloaded_papers"
|
|
166
|
+
),
|
|
167
|
+
summaries_path=self.get_path(
|
|
168
|
+
self.arxiv_summaries_path, "arxiv_generated_summaries"
|
|
169
|
+
),
|
|
170
|
+
vectorstore_path=self.get_path(
|
|
171
|
+
self.arxiv_vectorstore_path, "arxiv_vectorstores"
|
|
172
|
+
),
|
|
173
|
+
download=self.arxiv_download_papers,
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
@cached_property
|
|
177
|
+
def chatter(self) -> ChatAgent:
|
|
178
|
+
edb_path = self.workspace / "checkpoint.db"
|
|
179
|
+
edb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
180
|
+
econn = sqlite3.connect(str(edb_path), check_same_thread=False)
|
|
181
|
+
self.chatter_checkpointer = SqliteSaver(econn)
|
|
182
|
+
return ChatAgent(
|
|
183
|
+
llm=self.model,
|
|
184
|
+
checkpointer=self.chatter_checkpointer,
|
|
185
|
+
thread_id=self.thread_id + "_chatter",
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
@cached_property
|
|
189
|
+
def executor(self) -> ExecutionAgent:
|
|
190
|
+
edb_path = self.workspace / "checkpoint.db"
|
|
191
|
+
edb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
192
|
+
econn = sqlite3.connect(str(edb_path), check_same_thread=False)
|
|
193
|
+
self.executor_checkpointer = SqliteSaver(econn)
|
|
194
|
+
return ExecutionAgent(
|
|
195
|
+
llm=self.model,
|
|
196
|
+
checkpointer=self.executor_checkpointer,
|
|
197
|
+
agent_memory=self.memory,
|
|
198
|
+
thread_id=self.thread_id + "_executor",
|
|
199
|
+
safe_codes=self.safe_codes,
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
@cached_property
|
|
203
|
+
def hypothesizer(self) -> HypothesizerAgent:
|
|
204
|
+
edb_path = self.workspace / "checkpoint.db"
|
|
205
|
+
edb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
206
|
+
econn = sqlite3.connect(str(edb_path), check_same_thread=False)
|
|
207
|
+
self.executor_checkpointer = SqliteSaver(econn)
|
|
208
|
+
return HypothesizerAgent(
|
|
209
|
+
llm=self.model,
|
|
210
|
+
checkpointer=self.executor_checkpointer,
|
|
211
|
+
thread_id=self.thread_id + "_hypothesizer",
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
@cached_property
|
|
215
|
+
def planner(self) -> PlanningAgent:
|
|
216
|
+
pdb_path = Path(self.workspace) / "checkpoint.db"
|
|
217
|
+
pdb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
218
|
+
pconn = sqlite3.connect(str(pdb_path), check_same_thread=False)
|
|
219
|
+
self.planner_checkpointer = SqliteSaver(pconn)
|
|
220
|
+
return PlanningAgent(
|
|
221
|
+
llm=self.model,
|
|
222
|
+
checkpointer=self.planner_checkpointer,
|
|
223
|
+
thread_id=self.thread_id + "_planner",
|
|
224
|
+
)
|
|
225
|
+
|
|
226
|
+
@cached_property
|
|
227
|
+
def websearcher(self) -> WebSearchAgent:
|
|
228
|
+
rdb_path = Path(self.workspace) / "checkpoint.db"
|
|
229
|
+
rdb_path.parent.mkdir(parents=True, exist_ok=True)
|
|
230
|
+
rconn = sqlite3.connect(str(rdb_path), check_same_thread=False)
|
|
231
|
+
self.websearcher_checkpointer = SqliteSaver(rconn)
|
|
232
|
+
|
|
233
|
+
return WebSearchAgent(
|
|
234
|
+
llm=self.model,
|
|
235
|
+
max_results=10,
|
|
236
|
+
database_path="web_db",
|
|
237
|
+
summaries_path="web_summaries",
|
|
238
|
+
checkpointer=self.websearcher_checkpointer,
|
|
239
|
+
thread_id=self.thread_id + "_websearch",
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
@cached_property
|
|
243
|
+
def rememberer(self) -> RecallAgent:
|
|
244
|
+
return (
|
|
245
|
+
RecallAgent(llm=self.model, memory=self.memory)
|
|
246
|
+
if self.memory
|
|
247
|
+
else None
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
def run_arxiv(self, prompt: str) -> str:
|
|
251
|
+
llm_search_query = self.model.invoke(
|
|
252
|
+
f"The user stated {prompt}. Generate between 1 and 8 words for a search query to address the users need. Return only the words to search."
|
|
253
|
+
).content
|
|
254
|
+
print("Searching ArXiv for ", llm_search_query)
|
|
255
|
+
|
|
256
|
+
if isinstance(llm_search_query, str):
|
|
257
|
+
arxiv_result = self.arxiv_agent.invoke(
|
|
258
|
+
arxiv_search_query=llm_search_query,
|
|
259
|
+
context=prompt,
|
|
260
|
+
)
|
|
261
|
+
self.arxiv_state.append(arxiv_result)
|
|
262
|
+
self.update_last_agent_result(arxiv_result)
|
|
263
|
+
return f"[ArXiv Agent Output]:\n {self.last_agent_result}"
|
|
264
|
+
else:
|
|
265
|
+
raise RuntimeError("Unexpected error while running ArxivAgent!")
|
|
266
|
+
|
|
267
|
+
def run_executor(self, prompt: str) -> str:
|
|
268
|
+
if "messages" in self.executor_state and isinstance(
|
|
269
|
+
self.executor_state["messages"], list
|
|
270
|
+
):
|
|
271
|
+
self.executor_state["messages"].append(
|
|
272
|
+
HumanMessage(
|
|
273
|
+
f"The last agent output was: {self.last_agent_result}\n"
|
|
274
|
+
f"The user stated: {prompt}"
|
|
275
|
+
)
|
|
276
|
+
)
|
|
277
|
+
executor_state = self.executor.invoke(
|
|
278
|
+
self.executor_state,
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
if isinstance(
|
|
282
|
+
content := executor_state["messages"][-1].content, str
|
|
283
|
+
):
|
|
284
|
+
self.update_last_agent_result(content)
|
|
285
|
+
else:
|
|
286
|
+
raise TypeError(
|
|
287
|
+
f"content is supposed to have type str! Instead, it is {content}"
|
|
288
|
+
)
|
|
289
|
+
else:
|
|
290
|
+
self.executor_state = dict(
|
|
291
|
+
workspace=self.workspace,
|
|
292
|
+
messages=[
|
|
293
|
+
HumanMessage(
|
|
294
|
+
f"The last agent output was: {self.last_agent_result}\n The user stated: {prompt}"
|
|
295
|
+
)
|
|
296
|
+
],
|
|
297
|
+
)
|
|
298
|
+
self.executor_state = self.executor.invoke(
|
|
299
|
+
self.executor_state,
|
|
300
|
+
)
|
|
301
|
+
self.update_last_agent_result(
|
|
302
|
+
self.executor_state["messages"][-1].content
|
|
303
|
+
)
|
|
304
|
+
return f"[Executor Agent Output]:\n {self.last_agent_result}"
|
|
305
|
+
|
|
306
|
+
def run_rememberer(self, prompt: str) -> str:
|
|
307
|
+
memory_output = self.rememberer.invoke(prompt) if self.memory else None
|
|
308
|
+
return f"[Rememberer Output]:\n {memory_output}"
|
|
309
|
+
|
|
310
|
+
def run_chatter(self, prompt: str) -> str:
|
|
311
|
+
self.chatter_state["messages"].append(
|
|
312
|
+
HumanMessage(
|
|
313
|
+
content=f"The last agent output was: {self.last_agent_result}\n The user stated: {prompt}"
|
|
314
|
+
)
|
|
315
|
+
)
|
|
316
|
+
self.chatter_state = self.chatter.invoke(
|
|
317
|
+
self.chatter_state,
|
|
318
|
+
)
|
|
319
|
+
chat_output = self.chatter_state["messages"][-1]
|
|
320
|
+
|
|
321
|
+
if not isinstance(chat_output.content, str):
|
|
322
|
+
raise TypeError(
|
|
323
|
+
f"chat_output is not a str! Instead, it is: {type(chat_output.content)}."
|
|
324
|
+
)
|
|
325
|
+
|
|
326
|
+
self.update_last_agent_result(chat_output.content)
|
|
327
|
+
# return f"[{self.model.model_name}]: {self.last_agent_result}"
|
|
328
|
+
return f"{self.last_agent_result}"
|
|
329
|
+
|
|
330
|
+
def run_hypothesizer(self, prompt: str) -> str:
|
|
331
|
+
question = f"The last agent output was: {self.last_agent_result}\n\nThe user stated: {prompt}"
|
|
332
|
+
|
|
333
|
+
self.hypothesizer_state = self.hypothesizer.invoke(
|
|
334
|
+
prompt=question,
|
|
335
|
+
max_iterations=2,
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
solution = self.hypothesizer_state.get(
|
|
339
|
+
"solution", "Hypothesizer failed to return a solution"
|
|
340
|
+
)
|
|
341
|
+
self.update_last_agent_result(solution)
|
|
342
|
+
return f"[Hypothesizer Agent Output]:\n {self.last_agent_result}"
|
|
343
|
+
|
|
344
|
+
def run_planner(self, prompt: str) -> str:
|
|
345
|
+
self.planner_state.setdefault("messages", [])
|
|
346
|
+
self.planner_state["messages"].append(
|
|
347
|
+
HumanMessage(
|
|
348
|
+
f"The last agent output was: {self.last_agent_result}\n"
|
|
349
|
+
f"The user stated: {prompt}"
|
|
350
|
+
)
|
|
351
|
+
)
|
|
352
|
+
self.planner_state = self.planner.invoke(
|
|
353
|
+
self.planner_state,
|
|
354
|
+
)
|
|
355
|
+
|
|
356
|
+
plan = "\n\n\n".join(
|
|
357
|
+
f"## {step['id']} -- {step['name']}\n\n"
|
|
358
|
+
+ "\n\n".join(
|
|
359
|
+
f"* {key}\n * {value}" for key, value in step.items()
|
|
360
|
+
)
|
|
361
|
+
for step in self.planner_state["plan_steps"]
|
|
362
|
+
)
|
|
363
|
+
self.update_last_agent_result(plan)
|
|
364
|
+
return f"[Planner Agent Output]:\n {self.last_agent_result}"
|
|
365
|
+
|
|
366
|
+
def run_websearcher(self, prompt: str) -> str:
|
|
367
|
+
llm_search_query = self.model.invoke(
|
|
368
|
+
f"The user stated {prompt}. Generate between 1 and 8 words for a search query to address the users need. Return only the words to search."
|
|
369
|
+
).content
|
|
370
|
+
print("Searching Web for ", llm_search_query)
|
|
371
|
+
if isinstance(llm_search_query, str):
|
|
372
|
+
web_result = self.websearcher.invoke(
|
|
373
|
+
query=llm_search_query,
|
|
374
|
+
context=prompt,
|
|
375
|
+
)
|
|
376
|
+
self.websearcher_state.append(web_result)
|
|
377
|
+
self.update_last_agent_result(web_result)
|
|
378
|
+
return f"[WebSearch Agent Output]:\n {self.last_agent_result}"
|
|
379
|
+
else:
|
|
380
|
+
raise RuntimeError("Unexpected error while running WebSearchAgent!")
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
class UrsaRepl(Cmd):
    """Interactive shell that routes user input to the agents of a HITL session.

    Input that is not a command goes to the chat agent (``default``). The
    ``arxiv``, ``plan``, ``execute``, ``web``, ``recall`` and ``hypothesize``
    commands ask for a prompt and run the matching agent.
    """

    console = make_console()
    exit_message: str = "[dim]Exiting ursa..."
    _help_message: str = "[dim]For help, type: ? or help. Exit with Ctrl+d."
    prompt: str = "ursa> "

    def get_input(self, msg: str, end: str = "", **kwargs):
        """Show rich-formatted ``msg`` as a prompt and return the line the user types."""
        # NOTE: Printing in rich with Prompt somehow gets removed when
        # backspacing. This is a workaround that captures the print output and
        # converts it to the proper string format for your terminal.
        with self.console.capture() as capture:
            self.console.print(msg, end=end, **kwargs)
        return input(capture.get())

    def __init__(self, hitl: HITL, **kwargs):
        self.hitl = hitl
        super().__init__(**kwargs)

    def show(self, msg: str, markdown: bool = True, **kwargs):
        """Print ``msg`` rendered as Markdown, or as rich markup when ``markdown`` is False."""
        self.console.print(Markdown(msg) if markdown else msg, **kwargs)

    def default(self, prompt: str):
        """Send any non-command input to the chat agent and display the reply."""
        with self.console.status("Generating response"):
            response = self.hitl.run_chatter(prompt)
        self.show(response)

    def postcmd(self, stop: bool, line: str):
        """Print a blank line after every command; keep ``stop`` unchanged."""
        print()
        return stop

    def do_exit(self, _: str):
        """Exit shell."""
        self.show(self.exit_message, markdown=False)
        return True

    def do_EOF(self, _: str):
        """Exit on Ctrl+D."""
        self.show(self.exit_message, markdown=False)
        return True

    def do_clear(self, _: str):
        """Clear the screen. Same as pressing Ctrl+L."""
        os.system("cls" if platform.system() == "Windows" else "clear")

    def emptyline(self):
        """Do nothing when an empty line is entered"""
        pass

    def run_agent(self, agent: str, run: Callable[[str], str]):
        """Ask for a prompt for ``agent``, run ``run(prompt)`` under a spinner, return its output."""
        # prompt = self.get_input(f"Enter your prompt for [emph]{agent}[/]: ")
        prompt = input(f"Enter your prompt for {agent}: ")
        with self.console.status("Generating response"):
            return run(prompt)

    def do_arxiv(self, _: str):
        """Run ArxivAgent"""
        self.show(self.run_agent("Arxiv Agent", self.hitl.run_arxiv))

    def do_plan(self, _: str):
        """Run PlanningAgent"""
        self.show(self.run_agent("Planning Agent", self.hitl.run_planner))

    def do_execute(self, _: str):
        """Run ExecutionAgent"""
        self.show(self.run_agent("Execution Agent", self.hitl.run_executor))

    def do_web(self, _: str):
        """Run WebSearchAgent"""
        self.show(self.run_agent("Websearch Agent", self.hitl.run_websearcher))

    def do_recall(self, _: str):
        """Run RecallAgent"""
        self.show(self.run_agent("Recall Agent", self.hitl.run_rememberer))

    def do_hypothesize(self, _: str):
        """Run HypothesizerAgent"""
        self.show(
            self.run_agent("Hypothesizer Agent", self.hitl.run_hypothesizer)
        )

    def run(self):
        """Run the command loop until exit/EOF.

        Ctrl+C and errors raised by an agent are reported and the loop
        resumes instead of ending the session.
        """
        # Print intro only once.
        self.show(f"[magenta]{ursa_banner}", markdown=False)
        self.show(self._help_message, markdown=False)

        while True:
            try:
                self.cmdloop()
                break  # Allows breaking out of loop if EOF is triggered.
            except KeyboardInterrupt:
                print(
                    "\n(Interrupted) Press Ctrl+D to exit or continue typing."
                )
            except Exception as exc:
                # An agent failure (network error, unexpected model output,
                # ...) must not tear down the interactive session. markup=False
                # keeps brackets in the message from being parsed as rich markup.
                self.console.print(
                    f"Error ({type(exc).__name__}): {exc}",
                    style="error",
                    markup=False,
                )

    def do_models(self, _: str):
        """List models and base urls"""
        llm_provider, llm_name = get_provider_and_model(
            self.hitl.llm_model_name
        )
        self.show(
            f"[dim]*[/] LLM: [emph]{llm_name} "
            f"[dim]{self.hitl.llm_base_url or llm_provider}",
            markdown=False,
        )

        emb_provider, emb_name = get_provider_and_model(
            self.hitl.emb_model_name
        )
        self.show(
            f"[dim]*[/] Embedding Model: [emph]{emb_name} "
            f"[dim]{self.hitl.emb_base_url or emb_provider}",
            markdown=False,
        )
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def get_provider_and_model(model_str: Optional[str]):
    """Split a ``"provider:model"`` spec into ``(provider, model)``.

    Only the first ``:`` separates the two parts, so the model name may
    itself contain colons. A spec without ``:`` is attributed to
    ``"openai"``, and ``None`` yields ``("none", "none")``.
    """
    if model_str is None:
        return "none", "none"

    head, separator, tail = model_str.partition(":")
    if not separator:
        return "openai", model_str
    return head, tail
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
# TODO:
|
|
513
|
+
# * Add option to swap models in REPL
|
|
514
|
+
# * Add option for seed setting via flags
|
|
515
|
+
# * Name change: --llm-model-name -> llm
|
|
516
|
+
# * Name change: --emb-model-name -> emb
|
ursa/cli/hitl_api.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import threading
from typing import Annotated, Literal

from fastapi import Depends, FastAPI, HTTPException, Request
from pydantic import BaseModel, Field

from ursa import __version__
|
|
7
|
+
|
|
8
|
+
# FastAPI application exposing URSA over HTTP so it can be registered as an
# MCP tool; title/description/version populate the generated OpenAPI metadata.
mcp_app = FastAPI(
    version=__version__,
    title="URSA Server",
    description="Micro-service for hosting URSA to integrate as an MCP tool.",
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class QueryRequest(BaseModel):
    # Request body of POST /run.
    # Which agent handles the query; pydantic rejects any other value.
    agent: Literal["arxiv", "plan", "execute", "web", "recall", "chat", "hypothesize"]
    # Text forwarded to the chosen agent; the example is shown in the API schema.
    query: Annotated[str, Field(examples=["Write the first 1000 prime numbers to a text file."])]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class QueryResponse(BaseModel):
    # Response body of POST /run: the formatted output string of the chosen agent.
    response: str
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_hitl(req: Request):
    # FastAPI dependency: every request receives the same pre-created HITL
    # instance stored on ``app.state.hitl``.
    # NOTE(review): nothing in this module assigns ``app.state.hitl``; whatever
    # launches the server must set it before requests arrive — confirm.
    shared_state = req.app.state
    return shared_state.hitl
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@mcp_app.post("/run", response_model=QueryResponse)
|
|
35
|
+
def run_ursa(req: QueryRequest, hitl=Depends(get_hitl)):
|
|
36
|
+
"""
|
|
37
|
+
Queries the URSA Agentic AI Workflow to request that one of the URSA Agents
|
|
38
|
+
address a query. The available agents are:
|
|
39
|
+
ArxivAgent: Search for papers on ArXiv and summarize them in the context of the query
|
|
40
|
+
PlanningAgent: Builds a structured step-by-step plan to attempt to solve the users problem
|
|
41
|
+
ExecuteAgent: Runs a ReAct agent to write/edit code and run commands to attempt to solve the user query
|
|
42
|
+
WebSearchAgent: Search the web for information on a query and summarize the results given that context
|
|
43
|
+
RecallAgent: Perform RAG on previous ExecutionAgent steps saved in a memory database
|
|
44
|
+
HypothesizerAgent: Perform detailed reasoning to propose an approach to solve a given user problem/query
|
|
45
|
+
ChatAgent: Query the hosted LLM as a straightforward chatbot.
|
|
46
|
+
|
|
47
|
+
Arguments:
|
|
48
|
+
agent: str, one of: arxiv, plan, execute, web, recall, hypothesize, or chat. Directs the query to the corresponding agent
|
|
49
|
+
query: str, query to send to the requested agent for processing
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
response: str, summary of the agent output. The Execute agent may also write code and generate artifacts in the ursa_mcp workspace
|
|
53
|
+
"""
|
|
54
|
+
try:
|
|
55
|
+
match req.agent:
|
|
56
|
+
case "arxiv":
|
|
57
|
+
response = hitl.run_arxiv(req.query)
|
|
58
|
+
case "plan":
|
|
59
|
+
response = hitl.run_planner(req.query)
|
|
60
|
+
case "execute":
|
|
61
|
+
response = hitl.run_executor(req.query)
|
|
62
|
+
case "web":
|
|
63
|
+
response = hitl.run_websearcher(req.query)
|
|
64
|
+
case "recall":
|
|
65
|
+
response = hitl.run_rememberer(req.query)
|
|
66
|
+
case "hypothesize":
|
|
67
|
+
response = hitl.run_hypothesizer(req.query)
|
|
68
|
+
case "chat":
|
|
69
|
+
response = hitl.run_chatter(req.query)
|
|
70
|
+
case _:
|
|
71
|
+
response = f"Agent '{req.agent}' not found."
|
|
72
|
+
return QueryResponse(response=response)
|
|
73
|
+
except Exception as exc:
|
|
74
|
+
# Surface a readable error message for upstream agents
|
|
75
|
+
raise HTTPException(status_code=500, detail=str(exc)) from exc
|