ursa-ai 0.4.2__py3-none-any.whl → 0.6.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ursa-ai might be problematic. Click here for more details.

ursa/cli/hitl.py ADDED
@@ -0,0 +1,426 @@
1
+ import os
2
+ import platform
3
+ import sqlite3
4
+ from cmd import Cmd
5
+ from dataclasses import dataclass
6
+ from functools import cached_property
7
+ from pathlib import Path
8
+ from typing import Callable, Optional
9
+
10
+ import httpx
11
+ from langchain_core.messages import HumanMessage
12
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
13
+ from langgraph.checkpoint.sqlite import SqliteSaver
14
+ from pydantic import SecretStr
15
+ from rich.console import Console
16
+ from rich.markdown import Markdown
17
+ from rich.theme import Theme
18
+ from typer import Typer
19
+
20
+ from ursa.agents import (
21
+ ArxivAgent,
22
+ ExecutionAgent,
23
+ PlanningAgent,
24
+ RecallAgent,
25
+ WebSearchAgent,
26
+ )
27
+ from ursa.util.memory_logger import AgentMemory
28
+
29
# Module-level Typer application object. No commands are registered on it in
# the portion of the file visible here.
+ app = Typer()
30
+
31
# ASCII-art banner; UrsaRepl.run() shows it once (wrapped in "[magenta]")
# before entering the command loop. It is a raw string so the backslashes in
# the art are taken literally.
# NOTE(review): the art's internal spacing looks collapsed in this rendering;
# confirm against the packaged file before relying on the exact bytes.
+ ursa_banner = r"""
32
+ __ ________________ _
33
+ / / / / ___/ ___/ __ `/
34
+ / /_/ / / (__ ) /_/ /
35
+ \__,_/_/ /____/\__,_/
36
+ """
37
+
38
+
39
def make_console():
    """Build a rich ``Console`` carrying the named styles used by the REPL.

    The theme defines the style names ``success``, ``error``, ``dim``,
    ``warn`` and ``emph`` so that markup such as ``[dim]...`` or
    ``[emph]...`` can be used in printed messages.
    """
    named_styles = {
        "success": "green",
        "error": "bold red",
        "dim": "grey50",
        "warn": "yellow",
        "emph": "bold cyan",
    }
    return Console(theme=Theme(named_styles))
49
+
50
+
51
def wrap_api_key(api_key: Optional[str]) -> Optional[SecretStr]:
    """Wrap a plain API key in ``SecretStr``, passing ``None`` through.

    Args:
        api_key: The raw key, or ``None`` when no key was supplied.

    Returns:
        ``None`` if ``api_key`` is ``None``; otherwise ``SecretStr(api_key)``.
    """
    if api_key is None:
        return None
    return SecretStr(api_key)
53
+
54
+
55
+ @dataclass
56
+ class HITL:
57
+ workspace: Path
58
+ llm_model_name: str
59
+ llm_base_url: str
60
+ llm_api_key: Optional[str]
61
+ max_completion_tokens: int
62
+ emb_model_name: str
63
+ emb_base_url: str
64
+ emb_api_key: Optional[str]
65
+ share_key: bool
66
+ arxiv_summarize: bool
67
+ arxiv_process_images: bool
68
+ arxiv_max_results: int
69
+ arxiv_database_path: Optional[Path]
70
+ arxiv_summaries_path: Optional[Path]
71
+ arxiv_vectorstore_path: Optional[Path]
72
+ arxiv_download_papers: bool
73
+ ssl_verify: bool
74
+
75
+ def get_path(self, path: Optional[Path], default_subdir: str) -> str:
76
+ if path is None:
77
+ return str(self.workspace / default_subdir)
78
+ return str(path)
79
+
80
+ def __post_init__(self):
81
+ self.workspace.mkdir(parents=True, exist_ok=True)
82
+
83
+ # Specify API key only once and share for llm and embedder.
84
+ if self.share_key:
85
+ match self.llm_api_key, self.emb_api_key:
86
+ case None, None:
87
+ raise ValueError(
88
+ "When sharing API keys, both llm_api_key and emb_api_key "
89
+ "cannot be empty!"
90
+ )
91
+ case str(), str():
92
+ raise ValueError(
93
+ "When sharing API keys, do not supply both llm_api_key and "
94
+ "emb_api_key."
95
+ )
96
+ case None, str():
97
+ self.llm_api_key = self.emb_api_key
98
+ case str(), None:
99
+ self.emb_api_key = self.llm_api_key
100
+
101
+ llm_api_secret = wrap_api_key(self.llm_api_key)
102
+ emb_api_secret = wrap_api_key(self.emb_api_key)
103
+
104
+ self.model = ChatOpenAI(
105
+ model=self.llm_model_name,
106
+ max_completion_tokens=self.max_completion_tokens,
107
+ base_url=self.llm_base_url,
108
+ api_key=llm_api_secret,
109
+ http_client=None if self.ssl_verify else httpx.Client(verify=False),
110
+ )
111
+
112
+ self.embedding = OpenAIEmbeddings(
113
+ model=self.emb_model_name,
114
+ base_url=self.emb_base_url,
115
+ api_key=emb_api_secret,
116
+ http_client=None if self.ssl_verify else httpx.Client(verify=False),
117
+ )
118
+
119
+ self.memory = AgentMemory(
120
+ embedding_model=self.embedding, path=str(self.workspace / "memory")
121
+ )
122
+
123
+ self.last_agent_result = ""
124
+ self.arxiv_state = []
125
+ self.executor_state = {}
126
+ self.planner_state = {}
127
+ self.websearcher_state = {}
128
+
129
+ def update_last_agent_result(self, result: str):
130
+ self.last_agent_result = result
131
+
132
+ @cached_property
133
+ def arxiv_agent(self) -> ArxivAgent:
134
+ return ArxivAgent(
135
+ llm=self.model,
136
+ summarize=self.arxiv_summarize,
137
+ process_images=self.arxiv_process_images,
138
+ max_results=self.arxiv_max_results,
139
+ # rag_embedding=self.embedding,
140
+ database_path=self.get_path(
141
+ self.arxiv_database_path, "arxiv_downloaded_papers"
142
+ ),
143
+ summaries_path=self.get_path(
144
+ self.arxiv_summaries_path, "arxiv_generated_summaries"
145
+ ),
146
+ vectorstore_path=self.get_path(
147
+ self.arxiv_vectorstore_path, "arxiv_vectorstores"
148
+ ),
149
+ download_papers=self.arxiv_download_papers,
150
+ )
151
+
152
+ @cached_property
153
+ def executor(self) -> ExecutionAgent:
154
+ edb_path = self.workspace / "executor_checkpoint.db"
155
+ edb_path.parent.mkdir(parents=True, exist_ok=True)
156
+ econn = sqlite3.connect(str(edb_path), check_same_thread=False)
157
+ self.executor_checkpointer = SqliteSaver(econn)
158
+ return ExecutionAgent(
159
+ llm=self.model,
160
+ checkpointer=self.executor_checkpointer,
161
+ agent_memory=self.memory,
162
+ )
163
+
164
+ @cached_property
165
+ def planner(self) -> PlanningAgent:
166
+ pdb_path = Path(self.workspace) / "planner_checkpoint.db"
167
+ pdb_path.parent.mkdir(parents=True, exist_ok=True)
168
+ pconn = sqlite3.connect(str(pdb_path), check_same_thread=False)
169
+ self.planner_checkpointer = SqliteSaver(pconn)
170
+ return PlanningAgent(
171
+ llm=self.model,
172
+ checkpointer=self.planner_checkpointer,
173
+ )
174
+
175
+ @cached_property
176
+ def websearcher(self) -> WebSearchAgent:
177
+ rdb_path = Path(self.workspace) / "websearcher_checkpoint.db"
178
+ rdb_path.parent.mkdir(parents=True, exist_ok=True)
179
+ rconn = sqlite3.connect(str(rdb_path), check_same_thread=False)
180
+ self.websearcher_checkpointer = SqliteSaver(rconn)
181
+
182
+ return WebSearchAgent(
183
+ llm=self.model,
184
+ checkpointer=self.websearcher_checkpointer,
185
+ )
186
+
187
+ @cached_property
188
+ def rememberer(self) -> RecallAgent:
189
+ return RecallAgent(llm=self.model, memory=self.memory)
190
+
191
+ def run_arvix(self, prompt: str) -> str:
192
+ llm_search_query = self.model.invoke(
193
+ f"The user stated {prompt}. Generate between 1 and 8 words for a search query to address the users need. Return only the words to search."
194
+ ).content
195
+ print("Searching ArXiv for ", llm_search_query)
196
+
197
+ if isinstance(llm_search_query, str):
198
+ arxiv_result = self.arxiv_agent.invoke(
199
+ arxiv_search_query=llm_search_query,
200
+ context=prompt,
201
+ )
202
+ self.arxiv_state.append(arxiv_result)
203
+ self.update_last_agent_result(arxiv_result)
204
+ return f"[ArXiv Agent Output]:\n {self.last_agent_result}"
205
+ else:
206
+ raise RuntimeError("Unexpected error while running ArxivAgent!")
207
+
208
+ def run_executor(self, prompt: str) -> str:
209
+ if "messages" in self.executor_state and isinstance(
210
+ self.executor_state["messages"], list
211
+ ):
212
+ self.executor_state["messages"].append(
213
+ HumanMessage(
214
+ f"The last agent output was: {self.last_agent_result}\n"
215
+ f"The user stated: {prompt}"
216
+ )
217
+ )
218
+ executor_state = self.executor.invoke(
219
+ self.executor_state,
220
+ )
221
+
222
+ if isinstance(
223
+ content := executor_state["messages"][-1].content, str
224
+ ):
225
+ self.update_last_agent_result(content)
226
+ else:
227
+ raise TypeError(
228
+ f"content is supposed to have type str! Instead, it is {content}"
229
+ )
230
+ else:
231
+ self.executor_state = dict(
232
+ workspace=self.workspace,
233
+ messages=[
234
+ HumanMessage(
235
+ f"The last agent output was: {self.last_agent_result}\n The user stated: {prompt}"
236
+ )
237
+ ],
238
+ )
239
+ self.executor_state = self.executor.invoke(
240
+ self.executor_state,
241
+ )
242
+ self.update_last_agent_result(
243
+ self.executor_state["messages"][-1].content
244
+ )
245
+ return f"[Executor Agent Output]:\n {self.last_agent_result}"
246
+
247
+ def run_rememberer(self, prompt: str) -> str:
248
+ memory_output = self.rememberer.remember(prompt)
249
+ return f"[Rememberer Output]:\n {memory_output}"
250
+
251
+ def run_chatter(self, prompt: str) -> str:
252
+ chat_output = self.model.invoke(
253
+ f"The last agent output was: {self.last_agent_result}\n The user stated: {prompt}"
254
+ )
255
+
256
+ if not isinstance(chat_output.content, str):
257
+ raise TypeError(
258
+ f"chat_output is not a str! Instead, it is: {chat_output}."
259
+ )
260
+
261
+ self.update_last_agent_result(chat_output.content)
262
+ # return f"[{self.model.model_name}]: {self.last_agent_result}"
263
+ return f"{self.last_agent_result}"
264
+
265
+ def run_planner(self, prompt: str) -> str:
266
+ self.planner_state.setdefault("messages", [])
267
+ self.planner_state["messages"].append(
268
+ HumanMessage(
269
+ f"The last agent output was: {self.last_agent_result}\n"
270
+ f"The user stated: {prompt}"
271
+ )
272
+ )
273
+ self.planner_state = self.planner.invoke(
274
+ self.planner_state,
275
+ )
276
+
277
+ plan = "\n\n\n".join(
278
+ f"## {step['id']} -- {step['name']}\n\n"
279
+ + "\n\n".join(
280
+ f"* {key}\n * {value}" for key, value in step.items()
281
+ )
282
+ for step in self.planner_state["plan_steps"]
283
+ )
284
+ self.update_last_agent_result(plan)
285
+ return f"[Planner Agent Output]:\n {self.last_agent_result}"
286
+
287
+ def run_websearcher(self, prompt: str) -> str:
288
+ if self.websearcher_state:
289
+ self.websearcher_state["messages"].append(
290
+ HumanMessage(
291
+ f"The last agent output was: {self.last_agent_result}\n"
292
+ f"The user stated: {prompt}"
293
+ )
294
+ )
295
+ self.websearcher_state = self.websearcher.invoke(
296
+ self.websearcher_state,
297
+ )
298
+ self.update_last_agent_result(
299
+ self.websearcher_state["messages"][-1].content
300
+ )
301
+ else:
302
+ self.websearcher_state = {
303
+ "messages": [
304
+ HumanMessage(
305
+ f"The last agent output was: {self.last_agent_result}\n"
306
+ f"The user stated: {prompt}"
307
+ )
308
+ ]
309
+ }
310
+ self.websearcher_state = self.websearcher.invoke(
311
+ self.websearcher_state,
312
+ )
313
+ self.update_last_agent_result(
314
+ self.websearcher_state["messages"][-1].content
315
+ )
316
+ return f"[Planner Agent Output]:\n {self.last_agent_result}"
317
+
318
+
319
class UrsaRepl(Cmd):
    """Interactive shell that forwards user input to a ``HITL`` session.

    Lines that are not a known command go to ``HITL.run_chatter``; the
    ``do_*`` commands each ask for a prompt and hand it to one agent.
    The ``do_*`` docstrings double as the shell's help text.
    """

    console = make_console()
    exit_message: str = "[dim]Exiting ursa..."
    _help_message: str = "[dim]For help, type: ? or help. Exit with Ctrl+d."
    prompt: str = "ursa> "

    def __init__(self, hitl: HITL, **kwargs):
        self.hitl = hitl
        super().__init__(**kwargs)

    # ------------------------------------------------------------------
    # Output / input helpers
    # ------------------------------------------------------------------
    def show(self, msg: str, markdown: bool = True, **kwargs):
        """Print ``msg`` on the console, rendered as Markdown unless disabled."""
        renderable = msg
        if markdown:
            renderable = Markdown(msg)
        self.console.print(renderable, **kwargs)

    def get_input(self, msg: str, end: str = "", **kwargs):
        """Read a line from the user, using rich-rendered ``msg`` as prompt."""
        # NOTE: Printing in rich with Prompt somehow gets removed when
        # backspacing. This is a workaround that captures the print output and
        # converts it to the proper string format for your terminal.
        with self.console.capture() as captured:
            self.console.print(msg, end=end, **kwargs)
        rendered_prompt = captured.get()
        return input(rendered_prompt)

    def run_agent(self, agent: str, run: Callable[[str], str]):
        """Ask for a prompt for ``agent`` and return ``run(prompt)``.

        A status spinner is shown while ``run`` executes.
        """
        user_prompt = input(f"Enter your prompt for {agent}: ")
        with self.console.status("Generating response"):
            outcome = run(user_prompt)
        return outcome

    def _announce_exit(self) -> bool:
        # Shared by the two exit commands: print the farewell, stop the loop.
        self.show(self.exit_message, markdown=False)
        return True

    # ------------------------------------------------------------------
    # cmd.Cmd hooks
    # ------------------------------------------------------------------
    def default(self, prompt: str):
        """Treat any unrecognised line as a chat message for the model."""
        with self.console.status("Generating response"):
            reply = self.hitl.run_chatter(prompt)
        self.show(reply)

    def postcmd(self, stop: bool, line: str):
        """Emit a blank line after every command; pass ``stop`` through."""
        print()
        return stop

    def emptyline(self):
        """Do nothing when an empty line is entered"""
        return None

    # ------------------------------------------------------------------
    # Shell commands
    # ------------------------------------------------------------------
    def do_exit(self, _: str):
        """Exit shell."""
        return self._announce_exit()

    def do_EOF(self, _: str):
        """Exit on Ctrl+D."""
        return self._announce_exit()

    def do_clear(self, _: str):
        """Clear the screen. Same as pressing Ctrl+L."""
        on_windows = platform.system() == "Windows"
        os.system("cls" if on_windows else "clear")

    def do_arxiv(self, _: str):
        """Run ArxivAgent"""
        output = self.run_agent("Arxiv Agent", self.hitl.run_arvix)
        self.show(output)

    def do_plan(self, _: str):
        """Run PlanningAgent"""
        output = self.run_agent("Planning Agent", self.hitl.run_planner)
        self.show(output)

    def do_execute(self, _: str):
        """Run ExecutionAgent"""
        output = self.run_agent("Execution Agent", self.hitl.run_executor)
        self.show(output)

    def do_web(self, _: str):
        """Run WebSearchAgent"""
        output = self.run_agent("Websearch Agent", self.hitl.run_websearcher)
        self.show(output)

    def do_recall(self, _: str):
        """Run RecallAgent"""
        output = self.run_agent("Recall Agent", self.hitl.run_rememberer)
        self.show(output)

    def do_models(self, _: str):
        """List models and base urls"""
        llm_line = (
            f"[dim]*[/] LLM: [emph]{self.hitl.model.model_name} "
            f"[dim]{self.hitl.llm_base_url}"
        )
        self.show(llm_line, markdown=False)

        embedding_line = (
            f"[dim]*[/] Embedding Model: [emph]{self.hitl.embedding.model} "
            f"[dim]{self.hitl.emb_base_url}"
        )
        self.show(embedding_line, markdown=False)

    # ------------------------------------------------------------------
    # Entry point
    # ------------------------------------------------------------------
    def run(self):
        """Show the banner, then run the command loop until it exits.

        Ctrl+C does not quit: the interruption is reported and the command
        loop is started again.
        """
        # Print intro only once.
        self.show(f"[magenta]{ursa_banner}", markdown=False)
        self.show(self._help_message, markdown=False)

        finished = False
        while not finished:
            try:
                self.cmdloop()
            except KeyboardInterrupt:
                print(
                    "\n(Interrupted) Press Ctrl+D to exit or continue typing."
                )
            else:
                # cmdloop returned normally (exit / EOF), so stop looping.
                finished = True
420
+
421
+
422
+ # TODO:
423
+ # * Add option to swap models in REPL
424
+ # * Add option for seed setting via flags
425
+ # * Name change: --llm-model-name -> llm
426
+ # * Name change: --emb-model-name -> emb