ursa_ai-0.9.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. ursa/__init__.py +3 -0
  2. ursa/agents/__init__.py +32 -0
  3. ursa/agents/acquisition_agents.py +812 -0
  4. ursa/agents/arxiv_agent.py +429 -0
  5. ursa/agents/base.py +728 -0
  6. ursa/agents/chat_agent.py +60 -0
  7. ursa/agents/code_review_agent.py +341 -0
  8. ursa/agents/execution_agent.py +915 -0
  9. ursa/agents/hypothesizer_agent.py +614 -0
  10. ursa/agents/lammps_agent.py +465 -0
  11. ursa/agents/mp_agent.py +204 -0
  12. ursa/agents/optimization_agent.py +410 -0
  13. ursa/agents/planning_agent.py +219 -0
  14. ursa/agents/rag_agent.py +304 -0
  15. ursa/agents/recall_agent.py +54 -0
  16. ursa/agents/websearch_agent.py +196 -0
  17. ursa/cli/__init__.py +363 -0
  18. ursa/cli/hitl.py +516 -0
  19. ursa/cli/hitl_api.py +75 -0
  20. ursa/observability/metrics_charts.py +1279 -0
  21. ursa/observability/metrics_io.py +11 -0
  22. ursa/observability/metrics_session.py +750 -0
  23. ursa/observability/pricing.json +97 -0
  24. ursa/observability/pricing.py +321 -0
  25. ursa/observability/timing.py +1466 -0
  26. ursa/prompt_library/__init__.py +0 -0
  27. ursa/prompt_library/code_review_prompts.py +51 -0
  28. ursa/prompt_library/execution_prompts.py +50 -0
  29. ursa/prompt_library/hypothesizer_prompts.py +17 -0
  30. ursa/prompt_library/literature_prompts.py +11 -0
  31. ursa/prompt_library/optimization_prompts.py +131 -0
  32. ursa/prompt_library/planning_prompts.py +79 -0
  33. ursa/prompt_library/websearch_prompts.py +131 -0
  34. ursa/tools/__init__.py +0 -0
  35. ursa/tools/feasibility_checker.py +114 -0
  36. ursa/tools/feasibility_tools.py +1075 -0
  37. ursa/tools/run_command.py +27 -0
  38. ursa/tools/write_code.py +42 -0
  39. ursa/util/__init__.py +0 -0
  40. ursa/util/diff_renderer.py +128 -0
  41. ursa/util/helperFunctions.py +142 -0
  42. ursa/util/logo_generator.py +625 -0
  43. ursa/util/memory_logger.py +183 -0
  44. ursa/util/optimization_schema.py +78 -0
  45. ursa/util/parse.py +405 -0
  46. ursa_ai-0.9.1.dist-info/METADATA +304 -0
  47. ursa_ai-0.9.1.dist-info/RECORD +51 -0
  48. ursa_ai-0.9.1.dist-info/WHEEL +5 -0
  49. ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
  50. ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
  51. ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
ursa/agents/arxiv_agent.py
@@ -0,0 +1,429 @@
+ import base64
+ import os
+ import re
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from io import BytesIO
+ from typing import Any, Mapping, TypedDict
+ from urllib.parse import quote
+
+ import feedparser
+ import pymupdf
+ import requests
+ from langchain.chat_models import BaseChatModel
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langgraph.graph import StateGraph
+ from PIL import Image
+ from tqdm import tqdm
+
+ from ursa.agents.base import BaseAgent
+ from ursa.agents.rag_agent import RAGAgent
+
+ try:
+     from openai import OpenAI
+ except Exception:
+     pass
+
+
+ class PaperMetadata(TypedDict):
+     arxiv_id: str
+     full_text: str
+
+
+ class PaperState(TypedDict, total=False):
+     query: str
+     context: str
+     papers: list[PaperMetadata]
+     summaries: list[str]
+     final_summary: str
+
+
+ def describe_image(image: Image.Image) -> str:
+     if "OpenAI" not in globals():
+         print(
+             "Vision transformer for summarizing images currently only implemented for OpenAI API."
+         )
+         return ""
+     client = OpenAI()
+
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     img_base64 = base64.b64encode(buffered.getvalue()).decode()
+
+     response = client.chat.completions.create(
+         model="gpt-4-vision-preview",
+         messages=[
+             {
+                 "role": "system",
+                 "content": "You are a scientific assistant who explains plots and scientific diagrams.",
+             },
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "text",
+                         "text": "Describe this scientific image or plot in detail.",
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/png;base64,{img_base64}"
+                         },
+                     },
+                 ],
+             },
+         ],
+         max_tokens=500,
+     )
+     return response.choices[0].message.content.strip()
+
+
+ def extract_and_describe_images(
+     pdf_path: str, max_images: int = 5
+ ) -> list[str]:
+     doc = pymupdf.open(pdf_path)
+     descriptions = []
+     image_count = 0
+
+     for page_index in range(len(doc)):
+         if image_count >= max_images:
+             break
+         page = doc[page_index]
+         images = page.get_images(full=True)
+
+         for img_index, img in enumerate(images):
+             if image_count >= max_images:
+                 break
+             xref = img[0]
+             base_image = doc.extract_image(xref)
+             image_bytes = base_image["image"]
+             image = Image.open(BytesIO(image_bytes))
+
+             try:
+                 desc = describe_image(image)
+                 descriptions.append(
+                     f"Page {page_index + 1}, Image {img_index + 1}: {desc}"
+                 )
+             except Exception as e:
+                 descriptions.append(
+                     f"Page {page_index + 1}, Image {img_index + 1}: [Error: {e}]"
+                 )
+             image_count += 1
+
+     return descriptions
+
+
+ def remove_surrogates(text: str) -> str:
+     return re.sub(r"[\ud800-\udfff]", "", text)
+
+
+ class ArxivAgentLegacy(BaseAgent):
+     def __init__(
+         self,
+         llm: BaseChatModel,
+         summarize: bool = True,
+         process_images=True,
+         max_results: int = 3,
+         download_papers: bool = True,
+         rag_embedding=None,
+         database_path="arxiv_papers",
+         summaries_path="arxiv_generated_summaries",
+         vectorstore_path="arxiv_vectorstores",
+         **kwargs,
+     ):
+         super().__init__(llm, **kwargs)
+         self.summarize = summarize
+         self.process_images = process_images
+         self.max_results = max_results
+         self.database_path = database_path
+         self.summaries_path = summaries_path
+         self.vectorstore_path = vectorstore_path
+         self.download_papers = download_papers
+         self.rag_embedding = rag_embedding
+
+         self._action = self._build_graph()
+
+         os.makedirs(self.database_path, exist_ok=True)
+
+         os.makedirs(self.summaries_path, exist_ok=True)
+
+     def _fetch_papers(self, query: str) -> list[PaperMetadata]:
+         if self.download_papers:
+             encoded_query = quote(query)
+             url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results={self.max_results}"
+             # print(f"URL is {url}") # if verbose
+             entries = []
+             try:
+                 response = requests.get(url, timeout=10)
+                 response.raise_for_status()
+
+                 feed = feedparser.parse(response.content)
+                 # print(f"parsed response status is {feed.status}") # if verbose
+                 entries = feed.entries
+                 if feed.bozo:
+                     raise Exception("Feed from arXiv looks like garbage =(")
+             except requests.exceptions.Timeout:
+                 print("Request timed out while fetching papers.")
+             except requests.exceptions.RequestException as e:
+                 print(f"Request error encountered while fetching papers: {e}")
+             except ValueError as ve:
+                 print(f"Value error occurred while fetching papers: {ve}")
+             except Exception as e:
+                 print(
+                     f"An unexpected error occurred while fetching papers: {e}"
+                 )
+
+             for i, entry in enumerate(entries):
+                 full_id = entry.id.split("/abs/")[-1]
+                 arxiv_id = full_id.split("/")[-1]
+                 title = entry.title.strip()
+                 # authors = ", ".join(author.name for author in entry.authors)
+                 pdf_url = f"https://arxiv.org/pdf/{full_id}.pdf"
+                 pdf_filename = os.path.join(
+                     self.database_path, f"{arxiv_id}.pdf"
+                 )
+
+                 if os.path.exists(pdf_filename):
+                     print(
+                         f"Paper # {i + 1}, Title: {title}, already exists in database"
+                     )
+                 else:
+                     print(f"Downloading paper # {i + 1}, Title: {title}")
+                     response = requests.get(pdf_url)
+                     with open(pdf_filename, "wb") as f:
+                         f.write(response.content)
+
+         papers = []
+
+         pdf_files = [
+             f
+             for f in os.listdir(self.database_path)
+             if f.lower().endswith(".pdf")
+         ]
+
+         for i, pdf_filename in enumerate(pdf_files):
+             full_text = ""
+             arxiv_id = pdf_filename.split(".pdf")[0]
+             vec_save_loc = self.vectorstore_path + "/" + arxiv_id
+
+             if self.summarize and not os.path.exists(vec_save_loc):
+                 try:
+                     loader = PyPDFLoader(
+                         os.path.join(self.database_path, pdf_filename)
+                     )
+                     pages = loader.load()
+                     full_text = "\n".join([p.page_content for p in pages])
+
+                     if self.process_images:
+                         image_descriptions = extract_and_describe_images(
+                             os.path.join(self.database_path, pdf_filename)
+                         )
+                         full_text += (
+                             "\n\n[Image Interpretations]\n"
+                             + "\n".join(image_descriptions)
+                         )
+
+                 except Exception as e:
+                     full_text = f"Error loading paper: {e}"
+
+             papers.append({
+                 "arxiv_id": arxiv_id,
+                 "full_text": full_text,
+             })
+
+         return papers
+
+     def _fetch_node(self, state: PaperState) -> PaperState:
+         papers = self._fetch_papers(state["query"])
+         return {**state, "papers": papers}
+
+     def _summarize_node(self, state: PaperState) -> PaperState:
+         prompt = ChatPromptTemplate.from_template("""
+         You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
+
+         Summarize the retrieved scientific content below.
+
+         {retrieved_content}
+         """)
+
+         chain = prompt | self.llm | StrOutputParser()
+
+         summaries = [None] * len(state["papers"])
+         relevancy_scores = [0.0] * len(state["papers"])
+
+         def process_paper(i, paper):
+             arxiv_id = paper["arxiv_id"]
+             summary_filename = os.path.join(
+                 self.summaries_path, f"{arxiv_id}_summary.txt"
+             )
+
+             try:
+                 cleaned_text = remove_surrogates(paper["full_text"])
+                 summary = chain.invoke(
+                     {
+                         "retrieved_content": cleaned_text,
+                         "context": state["context"],
+                     },
+                     config=self.build_config(tags=["arxiv", "summarize_each"]),
+                 )
+
+             except Exception as e:
+                 summary = f"Error summarizing paper: {e}"
+                 relevancy_scores[i] = 0.0
+
+             with open(summary_filename, "w") as f:
+                 f.write(summary)
+
+             return i, summary
+
+         if "papers" not in state or len(state["papers"]) == 0:
+             print(
+                 "No papers retrieved - bad query or network connection to ArXiv?"
+             )
+             return {**state, "summaries": None}
+
+         with ThreadPoolExecutor(
+             max_workers=min(32, len(state["papers"]))
+         ) as executor:
+             futures = [
+                 executor.submit(process_paper, i, paper)
+                 for i, paper in enumerate(state["papers"])
+             ]
+
+             for future in tqdm(
+                 as_completed(futures),
+                 total=len(futures),
+                 desc="Summarizing Papers",
+             ):
+                 i, result = future.result()
+                 summaries[i] = result
+
+         return {**state, "summaries": summaries}
+
+     def _rag_node(self, state: PaperState) -> PaperState:
+         new_state = state.copy()
+         rag_agent = RAGAgent(
+             llm=self.llm,
+             embedding=self.rag_embedding,
+             database_path=self.database_path,
+         )
+         new_state["final_summary"] = rag_agent.invoke(context=state["context"])[
+             "summary"
+         ]
+         return new_state
+
+     def _aggregate_node(self, state: PaperState) -> PaperState:
+         # Guard before indexing into the state so a missing or empty
+         # summarization result cannot raise a KeyError below.
+         if (
+             "summaries" not in state
+             or state["summaries"] is None
+             or "papers" not in state
+             or state["papers"] is None
+         ):
+             return {**state, "final_summary": None}
+
+         summaries = state["summaries"]
+         papers = state["papers"]
+         formatted = []
+
+         for i, (paper, summary) in enumerate(zip(papers, summaries)):
+             citation = f"[{i + 1}] Arxiv ID: {paper['arxiv_id']}"
+             formatted.append(f"{citation}\n\nSummary:\n{summary}")
+
+         combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(formatted)
+
+         with open(self.summaries_path + "/summaries_combined.txt", "w") as f:
+             f.write(combined)
+
+         prompt = ChatPromptTemplate.from_template("""
+         You are a scientific assistant helping extract insights from summaries of research papers.
+
+         Here are the summaries of a large number of extracts from scientific papers:
+
+         {Summaries}
+
+         Your task is to read all the summaries and provide a response to this task: {context}
+         """)
+
+         chain = prompt | self.llm | StrOutputParser()
+
+         final_summary = chain.invoke(
+             {
+                 "Summaries": combined,
+                 "context": state["context"],
+             },
+             config=self.build_config(tags=["arxiv", "aggregate"]),
+         )
+
+         with open(self.summaries_path + "/final_summary.txt", "w") as f:
+             f.write(final_summary)
+
+         return {**state, "final_summary": final_summary}
+
+     def _build_graph(self):
+         graph = StateGraph(PaperState)
+
+         self.add_node(graph, self._fetch_node)
+         if self.summarize:
+             if self.rag_embedding:
+                 self.add_node(graph, self._rag_node)
+                 graph.set_entry_point("_fetch_node")
+                 graph.add_edge("_fetch_node", "_rag_node")
+                 graph.set_finish_point("_rag_node")
+             else:
+                 self.add_node(graph, self._summarize_node)
+                 self.add_node(graph, self._aggregate_node)
+
+                 graph.set_entry_point("_fetch_node")
+                 graph.add_edge("_fetch_node", "_summarize_node")
+                 graph.add_edge("_summarize_node", "_aggregate_node")
+                 graph.set_finish_point("_aggregate_node")
+         else:
+             graph.set_entry_point("_fetch_node")
+             graph.set_finish_point("_fetch_node")
+
+         return graph.compile(checkpointer=self.checkpointer)
+
+     def _invoke(
+         self,
+         inputs: Mapping[str, Any],
+         *,
+         summarize: bool | None = None,
+         recursion_limit: int = 1000,
+         **_,
+     ) -> str:
+         config = self.build_config(
+             recursion_limit=recursion_limit, tags=["graph"]
+         )
+
+         # Historically this value was referred to both as 'query' and as
+         # 'arxiv_search_query', so accept either for compatibility.
+         # Aliasing: accept arxiv_search_query -> query.
+         if "query" not in inputs:
+             if "arxiv_search_query" in inputs:
+                 # Make a shallow copy and rename the key.
+                 inputs = dict(inputs)
+                 inputs["query"] = inputs.pop("arxiv_search_query")
+             else:
+                 raise KeyError(
+                     "Missing 'query' in inputs (alias 'arxiv_search_query' also accepted)."
+                 )
+
+         result = self._action.invoke(inputs, config)
+
+         use_summary = self.summarize if summarize is None else summarize
+
+         return (
+             result.get("final_summary", "No summary generated.")
+             if use_summary
+             else "\n\nFinished Fetching papers!"
+         )
+
+
+ # NOTE: Run the tests in `tests/agents/test_arxiv_agent/test_arxiv_agent.py` via:
+ #
+ #     pytest -s tests/agents/test_arxiv_agent
+ #
+ # OR
+ #
+ #     uv run pytest -s tests/agents/test_arxiv_agent