praisonaiagents 0.0.22__tar.gz → 0.0.24__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/PKG-INFO +1 -1
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/agent/agent.py +22 -33
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/agents/agents.py +18 -4
- praisonaiagents-0.0.24/praisonaiagents/tools/__init__.py +167 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/arxiv_tools.py +292 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/calculator_tools.py +278 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/csv_tools.py +266 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/duckdb_tools.py +268 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/duckduckgo_tools.py +52 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/excel_tools.py +310 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/file_tools.py +274 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/json_tools.py +515 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/newspaper_tools.py +354 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/pandas_tools.py +326 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/python_tools.py +423 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/shell_tools.py +278 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/spider_tools.py +431 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/test.py +56 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/tools.py +9 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/wikipedia_tools.py +272 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/xml_tools.py +498 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/yaml_tools.py +417 -0
- praisonaiagents-0.0.24/praisonaiagents/tools/yfinance_tools.py +213 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/PKG-INFO +1 -1
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/SOURCES.txt +19 -1
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/pyproject.toml +1 -1
- praisonaiagents-0.0.22/praisonaiagents/tools/__init__.py +0 -4
- praisonaiagents-0.0.22/praisonaiagents/tools/tools.py +0 -40
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/agent/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/agents/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/agent/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/agent/agent.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/agents/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/agents/agents.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/main.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/task/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/task/task.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/main.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/process/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/process/process.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/task/__init__.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/task/task.py +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/dependency_links.txt +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/requires.txt +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/top_level.txt +0 -0
- {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/setup.cfg +0 -0
@@ -394,7 +394,7 @@ Your Goal: {self.goal}
|
|
394
394
|
display_error(f"Error in chat completion: {e}")
|
395
395
|
return None
|
396
396
|
|
397
|
-
def chat(self, prompt, temperature=0.2, tools=None, output_json=None):
|
397
|
+
def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None):
|
398
398
|
if self.use_system_prompt:
|
399
399
|
system_prompt = f"""{self.backstory}\n
|
400
400
|
Your Role: {self.role}\n
|
@@ -402,6 +402,8 @@ Your Goal: {self.goal}
|
|
402
402
|
"""
|
403
403
|
if output_json:
|
404
404
|
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_json.schema_json()}"
|
405
|
+
elif output_pydantic:
|
406
|
+
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_pydantic.schema_json()}"
|
405
407
|
else:
|
406
408
|
system_prompt = None
|
407
409
|
|
@@ -410,9 +412,9 @@ Your Goal: {self.goal}
|
|
410
412
|
messages.append({"role": "system", "content": system_prompt})
|
411
413
|
messages.extend(self.chat_history)
|
412
414
|
|
413
|
-
# Modify prompt if output_json is specified
|
415
|
+
# Modify prompt if output_json or output_pydantic is specified
|
414
416
|
original_prompt = prompt
|
415
|
-
if output_json:
|
417
|
+
if output_json or output_pydantic:
|
416
418
|
if isinstance(prompt, str):
|
417
419
|
prompt += "\nReturn ONLY a valid JSON object. No other text or explanation."
|
418
420
|
elif isinstance(prompt, list):
|
@@ -487,23 +489,15 @@ Your Goal: {self.goal}
|
|
487
489
|
return None
|
488
490
|
response_text = response.choices[0].message.content.strip()
|
489
491
|
|
490
|
-
# Handle output_json if specified
|
491
|
-
if output_json:
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
self.chat_history.append({"role": "assistant", "content": response_text})
|
500
|
-
if self.verbose:
|
501
|
-
display_interaction(original_prompt, response_text, markdown=self.markdown,
|
502
|
-
generation_time=time.time() - start_time, console=self.console)
|
503
|
-
return parsed_model
|
504
|
-
except Exception as e:
|
505
|
-
display_error(f"Failed to parse response as {output_json.__name__}: {e}")
|
506
|
-
return None
|
492
|
+
# Handle output_json or output_pydantic if specified
|
493
|
+
if output_json or output_pydantic:
|
494
|
+
# Add to chat history and return raw response
|
495
|
+
self.chat_history.append({"role": "user", "content": original_prompt})
|
496
|
+
self.chat_history.append({"role": "assistant", "content": response_text})
|
497
|
+
if self.verbose:
|
498
|
+
display_interaction(original_prompt, response_text, markdown=self.markdown,
|
499
|
+
generation_time=time.time() - start_time, console=self.console)
|
500
|
+
return response_text
|
507
501
|
|
508
502
|
if not self.self_reflect:
|
509
503
|
self.chat_history.append({"role": "user", "content": original_prompt})
|
@@ -585,19 +579,21 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
|
|
585
579
|
cleaned = cleaned[:-3].strip()
|
586
580
|
return cleaned
|
587
581
|
|
588
|
-
async def achat(self, prompt, temperature=0.2, tools=None, output_json=None):
|
582
|
+
async def achat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None):
|
589
583
|
"""Async version of chat method"""
|
590
584
|
try:
|
591
585
|
# Build system prompt
|
592
586
|
system_prompt = self.system_prompt
|
593
587
|
if output_json:
|
594
588
|
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_json.schema_json()}"
|
589
|
+
elif output_pydantic:
|
590
|
+
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_pydantic.schema_json()}"
|
595
591
|
|
596
592
|
# Build messages
|
597
593
|
if isinstance(prompt, str):
|
598
594
|
messages = [
|
599
595
|
{"role": "system", "content": system_prompt},
|
600
|
-
{"role": "user", "content": prompt + ("\nReturn ONLY a valid JSON object. No other text or explanation." if output_json else "")}
|
596
|
+
{"role": "user", "content": prompt + ("\nReturn ONLY a valid JSON object. No other text or explanation." if (output_json or output_pydantic) else "")}
|
601
597
|
]
|
602
598
|
else:
|
603
599
|
# For multimodal prompts
|
@@ -605,7 +601,7 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
|
|
605
601
|
{"role": "system", "content": system_prompt},
|
606
602
|
{"role": "user", "content": prompt}
|
607
603
|
]
|
608
|
-
if output_json:
|
604
|
+
if output_json or output_pydantic:
|
609
605
|
# Add JSON instruction to text content
|
610
606
|
for item in messages[-1]["content"]:
|
611
607
|
if item["type"] == "text":
|
@@ -639,22 +635,15 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
|
|
639
635
|
tools=formatted_tools
|
640
636
|
)
|
641
637
|
return await self._achat_completion(response, tools)
|
642
|
-
elif output_json:
|
638
|
+
elif output_json or output_pydantic:
|
643
639
|
response = await async_client.chat.completions.create(
|
644
640
|
model=self.llm,
|
645
641
|
messages=messages,
|
646
642
|
temperature=temperature,
|
647
643
|
response_format={"type": "json_object"}
|
648
644
|
)
|
649
|
-
|
650
|
-
|
651
|
-
cleaned_json = self.clean_json_output(result)
|
652
|
-
try:
|
653
|
-
parsed = json.loads(cleaned_json)
|
654
|
-
return output_json(**parsed)
|
655
|
-
except Exception as e:
|
656
|
-
display_error(f"Error parsing JSON response: {e}")
|
657
|
-
return None
|
645
|
+
# Return the raw response
|
646
|
+
return response.choices[0].message.content
|
658
647
|
else:
|
659
648
|
response = await async_client.chat.completions.create(
|
660
649
|
model=self.llm,
|
@@ -195,10 +195,17 @@ Here are the results of previous tasks that might be useful:\n
|
|
195
195
|
|
196
196
|
agent_output = await executor_agent.achat(
|
197
197
|
_get_multimodal_message(task_prompt, task.images),
|
198
|
-
tools=task.tools
|
198
|
+
tools=task.tools,
|
199
|
+
output_json=task.output_json,
|
200
|
+
output_pydantic=task.output_pydantic
|
199
201
|
)
|
200
202
|
else:
|
201
|
-
agent_output = await executor_agent.achat(
|
203
|
+
agent_output = await executor_agent.achat(
|
204
|
+
task_prompt,
|
205
|
+
tools=task.tools,
|
206
|
+
output_json=task.output_json,
|
207
|
+
output_pydantic=task.output_pydantic
|
208
|
+
)
|
202
209
|
|
203
210
|
if agent_output:
|
204
211
|
task_output = TaskOutput(
|
@@ -405,10 +412,17 @@ Here are the results of previous tasks that might be useful:\n
|
|
405
412
|
|
406
413
|
agent_output = executor_agent.chat(
|
407
414
|
_get_multimodal_message(task_prompt, task.images),
|
408
|
-
tools=task.tools
|
415
|
+
tools=task.tools,
|
416
|
+
output_json=task.output_json,
|
417
|
+
output_pydantic=task.output_pydantic
|
409
418
|
)
|
410
419
|
else:
|
411
|
-
agent_output = executor_agent.chat(
|
420
|
+
agent_output = executor_agent.chat(
|
421
|
+
task_prompt,
|
422
|
+
tools=task.tools,
|
423
|
+
output_json=task.output_json,
|
424
|
+
output_pydantic=task.output_pydantic
|
425
|
+
)
|
412
426
|
|
413
427
|
if agent_output:
|
414
428
|
task_output = TaskOutput(
|
@@ -0,0 +1,167 @@
|
|
1
|
+
"""Tools package for PraisonAI Agents"""
|
2
|
+
from importlib import import_module
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
# Registry consumed by this package's module-level ``__getattr__`` for lazy
# imports: public tool name -> (relative module path, class owning the tool).
# A class of ``None`` means the name is looked up directly on the module.
# Entries are grouped per module; ``dict.fromkeys`` keeps insertion order, so
# the key order (and therefore ``__all__``) follows the groups top to bottom.
TOOL_MAPPINGS = {
    # Direct functions
    **dict.fromkeys(
        ('internet_search', 'duckduckgo'),
        ('.duckduckgo_tools', None),
    ),
    # Class methods from YFinance
    **dict.fromkeys(
        ('get_stock_price', 'get_stock_info', 'get_historical_data', 'yfinance'),
        ('.yfinance_tools', 'YFinanceTools'),
    ),
    # File Tools
    **dict.fromkeys(
        ('read_file', 'write_file', 'list_files', 'get_file_info',
         'copy_file', 'move_file', 'delete_file', 'file_tools'),
        ('.file_tools', 'FileTools'),
    ),
    # CSV Tools
    **dict.fromkeys(
        ('read_csv', 'write_csv', 'merge_csv', 'analyze_csv', 'split_csv', 'csv_tools'),
        ('.csv_tools', 'CSVTools'),
    ),
    # JSON Tools
    **dict.fromkeys(
        ('read_json', 'write_json', 'merge_json', 'validate_json',
         'analyze_json', 'transform_json', 'json_tools'),
        ('.json_tools', 'JSONTools'),
    ),
    # Excel Tools
    **dict.fromkeys(
        ('read_excel', 'write_excel', 'merge_excel', 'create_chart',
         'add_chart_to_sheet', 'excel_tools'),
        ('.excel_tools', 'ExcelTools'),
    ),
    # XML Tools
    **dict.fromkeys(
        ('read_xml', 'write_xml', 'transform_xml', 'validate_xml',
         'xml_to_dict', 'dict_to_xml', 'xpath_query', 'xml_tools'),
        ('.xml_tools', 'XMLTools'),
    ),
    # YAML Tools
    **dict.fromkeys(
        ('read_yaml', 'write_yaml', 'merge_yaml', 'validate_yaml',
         'analyze_yaml', 'transform_yaml', 'yaml_tools'),
        ('.yaml_tools', 'YAMLTools'),
    ),
    # Calculator Tools
    **dict.fromkeys(
        ('evaluate', 'solve_equation', 'convert_units', 'calculate_statistics',
         'calculate_financial', 'calculator_tools'),
        ('.calculator_tools', 'CalculatorTools'),
    ),
    # Python Tools
    **dict.fromkeys(
        ('execute_code', 'analyze_code', 'format_code', 'lint_code',
         'disassemble_code', 'python_tools'),
        ('.python_tools', 'PythonTools'),
    ),
    # Pandas Tools
    **dict.fromkeys(
        ('filter_data', 'get_summary', 'group_by', 'pivot_table', 'pandas_tools'),
        ('.pandas_tools', 'PandasTools'),
    ),
    # Wikipedia Tools
    **dict.fromkeys(
        ('search', 'get_wikipedia_summary', 'get_wikipedia_page',
         'get_random_wikipedia', 'set_wikipedia_language', 'wikipedia_tools'),
        ('.wikipedia_tools', 'WikipediaTools'),
    ),
    # Newspaper Tools
    **dict.fromkeys(
        ('get_article', 'get_news_sources', 'get_articles_from_source',
         'get_trending_topics', 'newspaper_tools'),
        ('.newspaper_tools', 'NewspaperTools'),
    ),
    # arXiv Tools
    **dict.fromkeys(
        ('search_arxiv', 'get_arxiv_paper', 'get_papers_by_author',
         'get_papers_by_category', 'arxiv_tools'),
        ('.arxiv_tools', 'ArxivTools'),
    ),
    # Spider Tools
    **dict.fromkeys(
        ('scrape_page', 'extract_links', 'crawl', 'extract_text', 'spider_tools'),
        ('.spider_tools', 'SpiderTools'),
    ),
    # DuckDB Tools
    **dict.fromkeys(
        ('query', 'create_table', 'load_data', 'export_data',
         'get_table_info', 'analyze_data', 'duckdb_tools'),
        ('.duckdb_tools', 'DuckDBTools'),
    ),
    # Shell Tools
    **dict.fromkeys(
        ('execute_command', 'list_processes', 'kill_process',
         'get_system_info', 'shell_tools'),
        ('.shell_tools', 'ShellTools'),
    ),
}
|
139
|
+
|
140
|
+
_instances = {} # Cache for class instances
|
141
|
+
|
142
|
+
def __getattr__(name: str) -> Any:
|
143
|
+
"""Smart lazy loading of tools with class method support."""
|
144
|
+
if name not in TOOL_MAPPINGS:
|
145
|
+
raise AttributeError(f"module '{__package__}' has no attribute '{name}'")
|
146
|
+
|
147
|
+
module_path, class_name = TOOL_MAPPINGS[name]
|
148
|
+
|
149
|
+
if class_name is None:
|
150
|
+
# Direct function import
|
151
|
+
module = import_module(module_path, __package__)
|
152
|
+
if name in ['duckduckgo', 'file_tools', 'pandas_tools', 'wikipedia_tools',
|
153
|
+
'newspaper_tools', 'arxiv_tools', 'spider_tools', 'duckdb_tools', 'csv_tools', 'json_tools', 'excel_tools', 'xml_tools', 'yaml_tools', 'calculator_tools', 'python_tools', 'shell_tools']:
|
154
|
+
return module # Returns the callable module
|
155
|
+
return getattr(module, name)
|
156
|
+
else:
|
157
|
+
# Class method import
|
158
|
+
if class_name not in _instances:
|
159
|
+
module = import_module(module_path, __package__)
|
160
|
+
class_ = getattr(module, class_name)
|
161
|
+
_instances[class_name] = class_()
|
162
|
+
|
163
|
+
# Get the method and bind it to the instance
|
164
|
+
method = getattr(_instances[class_name], name)
|
165
|
+
return method
|
166
|
+
|
167
|
+
# Public API: every lazily resolvable tool name, in registry order.
__all__ = [*TOOL_MAPPINGS]
|
@@ -0,0 +1,292 @@
|
|
1
|
+
"""Tools for searching and retrieving papers from arXiv.
|
2
|
+
|
3
|
+
Usage:
|
4
|
+
from praisonaiagents.tools import arxiv_tools
|
5
|
+
papers = arxiv_tools.search("quantum computing")
|
6
|
+
paper = arxiv_tools.get_paper("2401.00123")
|
7
|
+
|
8
|
+
or
|
9
|
+
from praisonaiagents.tools import search_arxiv, get_arxiv_paper
|
10
|
+
papers = search_arxiv("quantum computing")
|
11
|
+
"""
|
12
|
+
|
13
|
+
import logging
|
14
|
+
from typing import List, Dict, Union, Optional, Any
|
15
|
+
from importlib import util
|
16
|
+
import json
|
17
|
+
|
18
|
+
# Lower-cased user-facing sort names -> attribute names on arxiv.SortCriterion.
SORT_CRITERIA = {
    member.lower(): member
    for member in ("Relevance", "LastUpdatedDate", "SubmittedDate")
}

# Lower-cased user-facing order names -> attribute names on arxiv.SortOrder.
SORT_ORDER = {member.lower(): member for member in ("Ascending", "Descending")}
|
30
|
+
|
31
|
+
class ArxivTools:
    """Tools for searching and retrieving papers from arXiv.

    Public query methods do not raise: failures are logged with
    ``logging.error`` and reported as ``{"error": "<message>"}``.
    """

    def __init__(self):
        """Initialize ArxivTools and check for arxiv package.

        Raises:
            ImportError: If the ``arxiv`` package is not installed.
        """
        self._check_arxiv()

    def _check_arxiv(self):
        """Check that the ``arxiv`` package is installed and import it.

        Raises:
            ImportError: If ``arxiv`` cannot be found.
        """
        if util.find_spec("arxiv") is None:
            raise ImportError("arxiv package is not available. Please install it using: pip install arxiv")
        # Also binds ``arxiv`` as a module-level global.
        global arxiv
        import arxiv

    def search(
        self,
        query: str,
        max_results: int = 10,
        sort_by: str = "relevance",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Search arXiv for papers matching the query.

        Args:
            query: Search query (e.g., "quantum computing", "author:Einstein")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate");
                case-insensitive
            sort_order: Sort order ("ascending" or "descending"); case-insensitive
            include_fields: List of fields to include in results. If None, includes all.
                Available fields: ["title", "authors", "summary", "comment", "journal_ref",
                                 "doi", "primary_category", "categories", "links"]

        Returns:
            List[Dict] or Dict: List of papers, or ``{"error": ...}`` on failure
            (including an unknown ``sort_by`` / ``sort_order``).
        """
        try:
            # Validate sort options before querying arXiv so that a typo produces
            # a readable message instead of a bare KeyError repr.
            criterion = SORT_CRITERIA.get(sort_by.lower())
            if criterion is None:
                raise ValueError(
                    f"Invalid sort_by {sort_by!r}; expected one of: {', '.join(SORT_CRITERIA)}"
                )
            order = SORT_ORDER.get(sort_order.lower())
            if order is None:
                raise ValueError(
                    f"Invalid sort_order {sort_order!r}; expected one of: {', '.join(SORT_ORDER)}"
                )

            import arxiv

            # Configure search client
            client = arxiv.Client()

            # Build search query
            search = arxiv.Search(
                query=query,
                max_results=max_results,
                sort_by=getattr(arxiv.SortCriterion, criterion),
                sort_order=getattr(arxiv.SortOrder, order)
            )

            # Execute search, converting each result to a dict with selected fields
            return [
                self._result_to_dict(result, include_fields)
                for result in client.results(search)
            ]
        except Exception as e:
            error_msg = f"Error searching arXiv: {str(e)}"
            logging.error(error_msg)
            return {"error": error_msg}

    def get_paper(
        self,
        paper_id: str,
        include_fields: Optional[List[str]] = None
    ) -> Union[Dict[str, Any], Dict[str, str]]:
        """
        Get details of a specific paper by its arXiv ID.

        Args:
            paper_id: arXiv paper ID (e.g., "2401.00123")
            include_fields: List of fields to include in results. If None, includes all.
                Available fields: ["title", "authors", "summary", "comment", "journal_ref",
                                 "doi", "primary_category", "categories", "links"]

        Returns:
            Dict: Paper details, or ``{"error": ...}`` if not found or on failure
        """
        try:
            import arxiv

            # Configure client
            client = arxiv.Client()

            # Get paper by ID
            search = arxiv.Search(id_list=[paper_id])
            results = list(client.results(search))

            if not results:
                return {"error": f"Paper with ID {paper_id} not found"}

            # Convert to dict with selected fields
            return self._result_to_dict(results[0], include_fields)
        except Exception as e:
            error_msg = f"Error getting paper {paper_id}: {str(e)}"
            logging.error(error_msg)
            return {"error": error_msg}

    def get_papers_by_author(
        self,
        author: str,
        max_results: int = 10,
        sort_by: str = "submittedDate",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Get papers by a specific author.

        Builds an ``au:"<author>"`` query and delegates to :meth:`search`.

        Args:
            author: Author name (e.g., "Einstein")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
            sort_order: Sort order ("ascending" or "descending")
            include_fields: List of fields to include in results

        Returns:
            List[Dict] or Dict: List of papers or error dict
        """
        query = f'au:"{author}"'
        return self.search(query, max_results, sort_by, sort_order, include_fields)

    def get_papers_by_category(
        self,
        category: str,
        max_results: int = 10,
        sort_by: str = "submittedDate",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Get papers from a specific category.

        Builds a ``cat:<category>`` query and delegates to :meth:`search`.

        Args:
            category: arXiv category (e.g., "cs.AI", "physics.gen-ph")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
            sort_order: Sort order ("ascending" or "descending")
            include_fields: List of fields to include in results

        Returns:
            List[Dict] or Dict: List of papers or error dict
        """
        query = f'cat:{category}'
        return self.search(query, max_results, sort_by, sort_order, include_fields)

    def _result_to_dict(
        self,
        result: Any,
        include_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Convert an arxiv.Result to a dictionary with selected fields.

        ``arxiv_id``, ``updated`` and ``published`` are always present. The
        ``"links"`` field expands to ``pdf_url`` and ``abstract_url``.
        """
        # Default fields to include
        if include_fields is None:
            include_fields = [
                "title", "authors", "summary", "comment", "journal_ref",
                "doi", "primary_category", "categories", "links"
            ]

        # entry_id looks like "http://arxiv.org/abs/<id>". Old-style IDs contain
        # a slash themselves (e.g. "hep-th/9901001v1"), so take everything after
        # the "arxiv.org/abs/" marker instead of only the last path segment.
        entry_id = result.entry_id
        marker = "arxiv.org/abs/"
        if marker in entry_id:
            arxiv_id = entry_id.split(marker, 1)[1]
        else:
            arxiv_id = entry_id.split("/")[-1]

        # Always include these basic fields
        paper = {
            "arxiv_id": arxiv_id,
            "updated": result.updated.isoformat() if result.updated else None,
            "published": result.published.isoformat() if result.published else None,
        }

        # Add requested fields (insertion order kept stable for JSON output)
        if "title" in include_fields:
            paper["title"] = result.title
        if "authors" in include_fields:
            paper["authors"] = [str(author) for author in result.authors]
        if "summary" in include_fields:
            paper["summary"] = result.summary
        if "comment" in include_fields:
            paper["comment"] = result.comment
        if "journal_ref" in include_fields:
            paper["journal_ref"] = result.journal_ref
        if "doi" in include_fields:
            paper["doi"] = result.doi
        if "primary_category" in include_fields:
            paper["primary_category"] = result.primary_category
        if "categories" in include_fields:
            paper["categories"] = result.categories
        if "links" in include_fields:
            paper["pdf_url"] = result.pdf_url
            paper["abstract_url"] = f"https://arxiv.org/abs/{arxiv_id}"

        return paper
|
228
|
+
|
229
|
+
# Shared module-level instance whose bound methods double as plain functions.
# Constructing it here means importing this module raises ImportError when the
# `arxiv` package is not installed (see ArxivTools._check_arxiv).
_arxiv_tools = ArxivTools()
search_arxiv, get_arxiv_paper = _arxiv_tools.search, _arxiv_tools.get_paper
get_papers_by_author, get_papers_by_category = (
    _arxiv_tools.get_papers_by_author,
    _arxiv_tools.get_papers_by_category,
)
|
235
|
+
|
236
|
+
if __name__ == "__main__":
    # Example usage: a manual demo of the four module-level arXiv helpers.
    # Each step prints the result as JSON on success, or the error dict otherwise.
    print("\n==================================================")
    print("ArxivTools Demonstration")
    print("==================================================\n")

    # 1. Search for papers
    print("1. Searching for Papers")
    print("------------------------------")
    query = "quantum computing"
    papers = search_arxiv(query, max_results=3)
    print(f"Papers about {query}:")
    # search_arxiv returns a list on success and an {"error": ...} dict on failure.
    if isinstance(papers, list):
        print(json.dumps(papers, indent=2))
    else:
        print(papers)  # Show error
    print()

    # 2. Get specific paper
    print("2. Getting Specific Paper")
    print("------------------------------")
    # Re-fetch the first search hit by its ID; skipped if step 1 failed or found nothing.
    if isinstance(papers, list) and papers:
        paper_id = papers[0]["arxiv_id"]
        paper = get_arxiv_paper(paper_id)
        print(f"Paper {paper_id}:")
        if "error" not in paper:
            print(json.dumps(paper, indent=2))
        else:
            print(paper)  # Show error
    print()

    # 3. Get papers by author
    print("3. Getting Papers by Author")
    print("------------------------------")
    author = "Yoshua Bengio"
    author_papers = get_papers_by_author(author, max_results=3)
    print(f"Papers by {author}:")
    if isinstance(author_papers, list):
        print(json.dumps(author_papers, indent=2))
    else:
        print(author_papers)  # Show error
    print()

    # 4. Get papers by category
    print("4. Getting Papers by Category")
    print("------------------------------")
    category = "cs.AI"
    category_papers = get_papers_by_category(category, max_results=3)
    print(f"Papers in category {category}:")
    if isinstance(category_papers, list):
        print(json.dumps(category_papers, indent=2))
    else:
        print(category_papers)  # Show error

    print("\n==================================================")
    print("Demonstration Complete")
    print("==================================================")
|