praisonaiagents 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- praisonaiagents/agent/agent.py +22 -33
- praisonaiagents/agents/agents.py +18 -4
- praisonaiagents/tools/__init__.py +165 -2
- praisonaiagents/tools/arxiv_tools.py +292 -0
- praisonaiagents/tools/calculator_tools.py +278 -0
- praisonaiagents/tools/csv_tools.py +266 -0
- praisonaiagents/tools/duckdb_tools.py +268 -0
- praisonaiagents/tools/duckduckgo_tools.py +52 -0
- praisonaiagents/tools/excel_tools.py +310 -0
- praisonaiagents/tools/file_tools.py +274 -0
- praisonaiagents/tools/json_tools.py +515 -0
- praisonaiagents/tools/newspaper_tools.py +354 -0
- praisonaiagents/tools/pandas_tools.py +326 -0
- praisonaiagents/tools/python_tools.py +423 -0
- praisonaiagents/tools/shell_tools.py +278 -0
- praisonaiagents/tools/spider_tools.py +431 -0
- praisonaiagents/tools/test.py +56 -0
- praisonaiagents/tools/tools.py +5 -36
- praisonaiagents/tools/wikipedia_tools.py +272 -0
- praisonaiagents/tools/xml_tools.py +498 -0
- praisonaiagents/tools/yaml_tools.py +417 -0
- praisonaiagents/tools/yfinance_tools.py +213 -0
- {praisonaiagents-0.0.22.dist-info → praisonaiagents-0.0.24.dist-info}/METADATA +1 -1
- praisonaiagents-0.0.24.dist-info/RECORD +42 -0
- praisonaiagents-0.0.22.dist-info/RECORD +0 -24
- {praisonaiagents-0.0.22.dist-info → praisonaiagents-0.0.24.dist-info}/WHEEL +0 -0
- {praisonaiagents-0.0.22.dist-info → praisonaiagents-0.0.24.dist-info}/top_level.txt +0 -0
praisonaiagents/agent/agent.py
CHANGED
@@ -394,7 +394,7 @@ Your Goal: {self.goal}
|
|
394
394
|
display_error(f"Error in chat completion: {e}")
|
395
395
|
return None
|
396
396
|
|
397
|
-
def chat(self, prompt, temperature=0.2, tools=None, output_json=None):
|
397
|
+
def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None):
|
398
398
|
if self.use_system_prompt:
|
399
399
|
system_prompt = f"""{self.backstory}\n
|
400
400
|
Your Role: {self.role}\n
|
@@ -402,6 +402,8 @@ Your Goal: {self.goal}
|
|
402
402
|
"""
|
403
403
|
if output_json:
|
404
404
|
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_json.schema_json()}"
|
405
|
+
elif output_pydantic:
|
406
|
+
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_pydantic.schema_json()}"
|
405
407
|
else:
|
406
408
|
system_prompt = None
|
407
409
|
|
@@ -410,9 +412,9 @@ Your Goal: {self.goal}
|
|
410
412
|
messages.append({"role": "system", "content": system_prompt})
|
411
413
|
messages.extend(self.chat_history)
|
412
414
|
|
413
|
-
# Modify prompt if output_json is specified
|
415
|
+
# Modify prompt if output_json or output_pydantic is specified
|
414
416
|
original_prompt = prompt
|
415
|
-
if output_json:
|
417
|
+
if output_json or output_pydantic:
|
416
418
|
if isinstance(prompt, str):
|
417
419
|
prompt += "\nReturn ONLY a valid JSON object. No other text or explanation."
|
418
420
|
elif isinstance(prompt, list):
|
@@ -487,23 +489,15 @@ Your Goal: {self.goal}
|
|
487
489
|
return None
|
488
490
|
response_text = response.choices[0].message.content.strip()
|
489
491
|
|
490
|
-
# Handle output_json if specified
|
491
|
-
if output_json:
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
self.chat_history.append({"role": "assistant", "content": response_text})
|
500
|
-
if self.verbose:
|
501
|
-
display_interaction(original_prompt, response_text, markdown=self.markdown,
|
502
|
-
generation_time=time.time() - start_time, console=self.console)
|
503
|
-
return parsed_model
|
504
|
-
except Exception as e:
|
505
|
-
display_error(f"Failed to parse response as {output_json.__name__}: {e}")
|
506
|
-
return None
|
492
|
+
# Handle output_json or output_pydantic if specified
|
493
|
+
if output_json or output_pydantic:
|
494
|
+
# Add to chat history and return raw response
|
495
|
+
self.chat_history.append({"role": "user", "content": original_prompt})
|
496
|
+
self.chat_history.append({"role": "assistant", "content": response_text})
|
497
|
+
if self.verbose:
|
498
|
+
display_interaction(original_prompt, response_text, markdown=self.markdown,
|
499
|
+
generation_time=time.time() - start_time, console=self.console)
|
500
|
+
return response_text
|
507
501
|
|
508
502
|
if not self.self_reflect:
|
509
503
|
self.chat_history.append({"role": "user", "content": original_prompt})
|
@@ -585,19 +579,21 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
|
|
585
579
|
cleaned = cleaned[:-3].strip()
|
586
580
|
return cleaned
|
587
581
|
|
588
|
-
async def achat(self, prompt, temperature=0.2, tools=None, output_json=None):
|
582
|
+
async def achat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None):
|
589
583
|
"""Async version of chat method"""
|
590
584
|
try:
|
591
585
|
# Build system prompt
|
592
586
|
system_prompt = self.system_prompt
|
593
587
|
if output_json:
|
594
588
|
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_json.schema_json()}"
|
589
|
+
elif output_pydantic:
|
590
|
+
system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_pydantic.schema_json()}"
|
595
591
|
|
596
592
|
# Build messages
|
597
593
|
if isinstance(prompt, str):
|
598
594
|
messages = [
|
599
595
|
{"role": "system", "content": system_prompt},
|
600
|
-
{"role": "user", "content": prompt + ("\nReturn ONLY a valid JSON object. No other text or explanation." if output_json else "")}
|
596
|
+
{"role": "user", "content": prompt + ("\nReturn ONLY a valid JSON object. No other text or explanation." if (output_json or output_pydantic) else "")}
|
601
597
|
]
|
602
598
|
else:
|
603
599
|
# For multimodal prompts
|
@@ -605,7 +601,7 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
|
|
605
601
|
{"role": "system", "content": system_prompt},
|
606
602
|
{"role": "user", "content": prompt}
|
607
603
|
]
|
608
|
-
if output_json:
|
604
|
+
if output_json or output_pydantic:
|
609
605
|
# Add JSON instruction to text content
|
610
606
|
for item in messages[-1]["content"]:
|
611
607
|
if item["type"] == "text":
|
@@ -639,22 +635,15 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
|
|
639
635
|
tools=formatted_tools
|
640
636
|
)
|
641
637
|
return await self._achat_completion(response, tools)
|
642
|
-
elif output_json:
|
638
|
+
elif output_json or output_pydantic:
|
643
639
|
response = await async_client.chat.completions.create(
|
644
640
|
model=self.llm,
|
645
641
|
messages=messages,
|
646
642
|
temperature=temperature,
|
647
643
|
response_format={"type": "json_object"}
|
648
644
|
)
|
649
|
-
|
650
|
-
|
651
|
-
cleaned_json = self.clean_json_output(result)
|
652
|
-
try:
|
653
|
-
parsed = json.loads(cleaned_json)
|
654
|
-
return output_json(**parsed)
|
655
|
-
except Exception as e:
|
656
|
-
display_error(f"Error parsing JSON response: {e}")
|
657
|
-
return None
|
645
|
+
# Return the raw response
|
646
|
+
return response.choices[0].message.content
|
658
647
|
else:
|
659
648
|
response = await async_client.chat.completions.create(
|
660
649
|
model=self.llm,
|
praisonaiagents/agents/agents.py
CHANGED
@@ -195,10 +195,17 @@ Here are the results of previous tasks that might be useful:\n
|
|
195
195
|
|
196
196
|
agent_output = await executor_agent.achat(
|
197
197
|
_get_multimodal_message(task_prompt, task.images),
|
198
|
-
tools=task.tools
|
198
|
+
tools=task.tools,
|
199
|
+
output_json=task.output_json,
|
200
|
+
output_pydantic=task.output_pydantic
|
199
201
|
)
|
200
202
|
else:
|
201
|
-
agent_output = await executor_agent.achat(
|
203
|
+
agent_output = await executor_agent.achat(
|
204
|
+
task_prompt,
|
205
|
+
tools=task.tools,
|
206
|
+
output_json=task.output_json,
|
207
|
+
output_pydantic=task.output_pydantic
|
208
|
+
)
|
202
209
|
|
203
210
|
if agent_output:
|
204
211
|
task_output = TaskOutput(
|
@@ -405,10 +412,17 @@ Here are the results of previous tasks that might be useful:\n
|
|
405
412
|
|
406
413
|
agent_output = executor_agent.chat(
|
407
414
|
_get_multimodal_message(task_prompt, task.images),
|
408
|
-
tools=task.tools
|
415
|
+
tools=task.tools,
|
416
|
+
output_json=task.output_json,
|
417
|
+
output_pydantic=task.output_pydantic
|
409
418
|
)
|
410
419
|
else:
|
411
|
-
agent_output = executor_agent.chat(
|
420
|
+
agent_output = executor_agent.chat(
|
421
|
+
task_prompt,
|
422
|
+
tools=task.tools,
|
423
|
+
output_json=task.output_json,
|
424
|
+
output_pydantic=task.output_pydantic
|
425
|
+
)
|
412
426
|
|
413
427
|
if agent_output:
|
414
428
|
task_output = TaskOutput(
|
@@ -1,4 +1,167 @@
|
|
1
1
|
"""Tools package for PraisonAI Agents"""
|
2
|
-
from
|
2
|
+
from importlib import import_module
|
3
|
+
from typing import Any
|
3
4
|
|
4
|
-
|
5
|
+
# Map of function names to their module and class (if any)
|
6
|
+
TOOL_MAPPINGS = {
|
7
|
+
# Direct functions
|
8
|
+
'internet_search': ('.duckduckgo_tools', None),
|
9
|
+
'duckduckgo': ('.duckduckgo_tools', None),
|
10
|
+
|
11
|
+
# Class methods from YFinance
|
12
|
+
'get_stock_price': ('.yfinance_tools', 'YFinanceTools'),
|
13
|
+
'get_stock_info': ('.yfinance_tools', 'YFinanceTools'),
|
14
|
+
'get_historical_data': ('.yfinance_tools', 'YFinanceTools'),
|
15
|
+
'yfinance': ('.yfinance_tools', 'YFinanceTools'),
|
16
|
+
|
17
|
+
# File Tools
|
18
|
+
'read_file': ('.file_tools', 'FileTools'),
|
19
|
+
'write_file': ('.file_tools', 'FileTools'),
|
20
|
+
'list_files': ('.file_tools', 'FileTools'),
|
21
|
+
'get_file_info': ('.file_tools', 'FileTools'),
|
22
|
+
'copy_file': ('.file_tools', 'FileTools'),
|
23
|
+
'move_file': ('.file_tools', 'FileTools'),
|
24
|
+
'delete_file': ('.file_tools', 'FileTools'),
|
25
|
+
'file_tools': ('.file_tools', 'FileTools'),
|
26
|
+
|
27
|
+
# CSV Tools
|
28
|
+
'read_csv': ('.csv_tools', 'CSVTools'),
|
29
|
+
'write_csv': ('.csv_tools', 'CSVTools'),
|
30
|
+
'merge_csv': ('.csv_tools', 'CSVTools'),
|
31
|
+
'analyze_csv': ('.csv_tools', 'CSVTools'),
|
32
|
+
'split_csv': ('.csv_tools', 'CSVTools'),
|
33
|
+
'csv_tools': ('.csv_tools', 'CSVTools'),
|
34
|
+
|
35
|
+
# JSON Tools
|
36
|
+
'read_json': ('.json_tools', 'JSONTools'),
|
37
|
+
'write_json': ('.json_tools', 'JSONTools'),
|
38
|
+
'merge_json': ('.json_tools', 'JSONTools'),
|
39
|
+
'validate_json': ('.json_tools', 'JSONTools'),
|
40
|
+
'analyze_json': ('.json_tools', 'JSONTools'),
|
41
|
+
'transform_json': ('.json_tools', 'JSONTools'),
|
42
|
+
'json_tools': ('.json_tools', 'JSONTools'),
|
43
|
+
|
44
|
+
# Excel Tools
|
45
|
+
'read_excel': ('.excel_tools', 'ExcelTools'),
|
46
|
+
'write_excel': ('.excel_tools', 'ExcelTools'),
|
47
|
+
'merge_excel': ('.excel_tools', 'ExcelTools'),
|
48
|
+
'create_chart': ('.excel_tools', 'ExcelTools'),
|
49
|
+
'add_chart_to_sheet': ('.excel_tools', 'ExcelTools'),
|
50
|
+
'excel_tools': ('.excel_tools', 'ExcelTools'),
|
51
|
+
|
52
|
+
# XML Tools
|
53
|
+
'read_xml': ('.xml_tools', 'XMLTools'),
|
54
|
+
'write_xml': ('.xml_tools', 'XMLTools'),
|
55
|
+
'transform_xml': ('.xml_tools', 'XMLTools'),
|
56
|
+
'validate_xml': ('.xml_tools', 'XMLTools'),
|
57
|
+
'xml_to_dict': ('.xml_tools', 'XMLTools'),
|
58
|
+
'dict_to_xml': ('.xml_tools', 'XMLTools'),
|
59
|
+
'xpath_query': ('.xml_tools', 'XMLTools'),
|
60
|
+
'xml_tools': ('.xml_tools', 'XMLTools'),
|
61
|
+
|
62
|
+
# YAML Tools
|
63
|
+
'read_yaml': ('.yaml_tools', 'YAMLTools'),
|
64
|
+
'write_yaml': ('.yaml_tools', 'YAMLTools'),
|
65
|
+
'merge_yaml': ('.yaml_tools', 'YAMLTools'),
|
66
|
+
'validate_yaml': ('.yaml_tools', 'YAMLTools'),
|
67
|
+
'analyze_yaml': ('.yaml_tools', 'YAMLTools'),
|
68
|
+
'transform_yaml': ('.yaml_tools', 'YAMLTools'),
|
69
|
+
'yaml_tools': ('.yaml_tools', 'YAMLTools'),
|
70
|
+
|
71
|
+
# Calculator Tools
|
72
|
+
'evaluate': ('.calculator_tools', 'CalculatorTools'),
|
73
|
+
'solve_equation': ('.calculator_tools', 'CalculatorTools'),
|
74
|
+
'convert_units': ('.calculator_tools', 'CalculatorTools'),
|
75
|
+
'calculate_statistics': ('.calculator_tools', 'CalculatorTools'),
|
76
|
+
'calculate_financial': ('.calculator_tools', 'CalculatorTools'),
|
77
|
+
'calculator_tools': ('.calculator_tools', 'CalculatorTools'),
|
78
|
+
|
79
|
+
# Python Tools
|
80
|
+
'execute_code': ('.python_tools', 'PythonTools'),
|
81
|
+
'analyze_code': ('.python_tools', 'PythonTools'),
|
82
|
+
'format_code': ('.python_tools', 'PythonTools'),
|
83
|
+
'lint_code': ('.python_tools', 'PythonTools'),
|
84
|
+
'disassemble_code': ('.python_tools', 'PythonTools'),
|
85
|
+
'python_tools': ('.python_tools', 'PythonTools'),
|
86
|
+
|
87
|
+
# Pandas Tools
|
88
|
+
'filter_data': ('.pandas_tools', 'PandasTools'),
|
89
|
+
'get_summary': ('.pandas_tools', 'PandasTools'),
|
90
|
+
'group_by': ('.pandas_tools', 'PandasTools'),
|
91
|
+
'pivot_table': ('.pandas_tools', 'PandasTools'),
|
92
|
+
'pandas_tools': ('.pandas_tools', 'PandasTools'),
|
93
|
+
|
94
|
+
# Wikipedia Tools
|
95
|
+
'search': ('.wikipedia_tools', 'WikipediaTools'),
|
96
|
+
'get_wikipedia_summary': ('.wikipedia_tools', 'WikipediaTools'),
|
97
|
+
'get_wikipedia_page': ('.wikipedia_tools', 'WikipediaTools'),
|
98
|
+
'get_random_wikipedia': ('.wikipedia_tools', 'WikipediaTools'),
|
99
|
+
'set_wikipedia_language': ('.wikipedia_tools', 'WikipediaTools'),
|
100
|
+
'wikipedia_tools': ('.wikipedia_tools', 'WikipediaTools'),
|
101
|
+
|
102
|
+
# Newspaper Tools
|
103
|
+
'get_article': ('.newspaper_tools', 'NewspaperTools'),
|
104
|
+
'get_news_sources': ('.newspaper_tools', 'NewspaperTools'),
|
105
|
+
'get_articles_from_source': ('.newspaper_tools', 'NewspaperTools'),
|
106
|
+
'get_trending_topics': ('.newspaper_tools', 'NewspaperTools'),
|
107
|
+
'newspaper_tools': ('.newspaper_tools', 'NewspaperTools'),
|
108
|
+
|
109
|
+
# arXiv Tools
|
110
|
+
'search_arxiv': ('.arxiv_tools', 'ArxivTools'),
|
111
|
+
'get_arxiv_paper': ('.arxiv_tools', 'ArxivTools'),
|
112
|
+
'get_papers_by_author': ('.arxiv_tools', 'ArxivTools'),
|
113
|
+
'get_papers_by_category': ('.arxiv_tools', 'ArxivTools'),
|
114
|
+
'arxiv_tools': ('.arxiv_tools', 'ArxivTools'),
|
115
|
+
|
116
|
+
# Spider Tools
|
117
|
+
'scrape_page': ('.spider_tools', 'SpiderTools'),
|
118
|
+
'extract_links': ('.spider_tools', 'SpiderTools'),
|
119
|
+
'crawl': ('.spider_tools', 'SpiderTools'),
|
120
|
+
'extract_text': ('.spider_tools', 'SpiderTools'),
|
121
|
+
'spider_tools': ('.spider_tools', 'SpiderTools'),
|
122
|
+
|
123
|
+
# DuckDB Tools
|
124
|
+
'query': ('.duckdb_tools', 'DuckDBTools'),
|
125
|
+
'create_table': ('.duckdb_tools', 'DuckDBTools'),
|
126
|
+
'load_data': ('.duckdb_tools', 'DuckDBTools'),
|
127
|
+
'export_data': ('.duckdb_tools', 'DuckDBTools'),
|
128
|
+
'get_table_info': ('.duckdb_tools', 'DuckDBTools'),
|
129
|
+
'analyze_data': ('.duckdb_tools', 'DuckDBTools'),
|
130
|
+
'duckdb_tools': ('.duckdb_tools', 'DuckDBTools'),
|
131
|
+
|
132
|
+
# Shell Tools
|
133
|
+
'execute_command': ('.shell_tools', 'ShellTools'),
|
134
|
+
'list_processes': ('.shell_tools', 'ShellTools'),
|
135
|
+
'kill_process': ('.shell_tools', 'ShellTools'),
|
136
|
+
'get_system_info': ('.shell_tools', 'ShellTools'),
|
137
|
+
'shell_tools': ('.shell_tools', 'ShellTools'),
|
138
|
+
}
|
139
|
+
|
140
|
+
_instances = {} # Cache for class instances
|
141
|
+
|
142
|
+
def __getattr__(name: str) -> Any:
    """Lazily resolve a tool name listed in ``TOOL_MAPPINGS`` (PEP 562 hook).

    Resolution order for a mapped ``(module_path, class_name)`` entry:

    1. ``'duckduckgo'`` and any key equal to the tool module's own name
       (e.g. ``'arxiv_tools'`` for ``'.arxiv_tools'``) return the imported
       module itself.
    2. Entries without a class return the module attribute called ``name``.
    3. Entries with a class return the attribute called ``name`` on a cached
       instance of that class, falling back to a module-level attribute of
       the same name when the instance does not define it.

    Args:
        name: Attribute requested on the tools package.

    Returns:
        The tool module, a module-level function, or a bound method.

    Raises:
        AttributeError: If ``name`` is not in ``TOOL_MAPPINGS`` or cannot be
            found on either the class instance or the module.
    """
    if name not in TOOL_MAPPINGS:
        raise AttributeError(f"module '{__package__}' has no attribute '{name}'")

    module_path, class_name = TOOL_MAPPINGS[name]
    # import_module is served from sys.modules after the first call, so
    # importing unconditionally keeps the lookup lazy but cheap.
    module = import_module(module_path, __package__)

    # Module aliases: the mapping pairs these keys with a class name, but the
    # class instances have no attribute named after their own module, so the
    # instance lookup used to raise AttributeError for them.
    if name == 'duckduckgo' or name == module_path.lstrip('.'):
        return module

    if class_name is None:
        # Direct function import
        return getattr(module, name)

    # Class method import: create the tool instance once and reuse it.
    if class_name not in _instances:
        _instances[class_name] = getattr(module, class_name)()

    # Prefer the bound method; fall back to a module-level alias. Example:
    # ArxivTools defines ``search``/``get_paper`` while the mapping keys are
    # ``search_arxiv``/``get_arxiv_paper``, which exist only as module-level
    # names in arxiv_tools.
    missing = object()
    for owner in (_instances[class_name], module):
        attr = getattr(owner, name, missing)
        if attr is not missing:
            return attr

    raise AttributeError(f"module '{__package__}' has no attribute '{name}'")
|
166
|
+
|
167
|
+
__all__ = list(TOOL_MAPPINGS.keys())
|
@@ -0,0 +1,292 @@
|
|
1
|
+
"""Tools for searching and retrieving papers from arXiv.
|
2
|
+
|
3
|
+
Usage:
|
4
|
+
from praisonaiagents.tools import arxiv_tools
|
5
|
+
papers = arxiv_tools.search_arxiv("quantum computing")
|
6
|
+
paper = arxiv_tools.get_arxiv_paper("2401.00123")
|
7
|
+
|
8
|
+
or
|
9
|
+
from praisonaiagents.tools import search_arxiv, get_arxiv_paper
|
10
|
+
papers = search_arxiv("quantum computing")
|
11
|
+
"""
|
12
|
+
|
13
|
+
import logging
|
14
|
+
from typing import List, Dict, Union, Optional, Any
|
15
|
+
from importlib import util
|
16
|
+
import json
|
17
|
+
|
18
|
+
# Map sort criteria to arxiv.SortCriterion
|
19
|
+
SORT_CRITERIA = {
|
20
|
+
"relevance": "Relevance",
|
21
|
+
"lastupdateddate": "LastUpdatedDate",
|
22
|
+
"submitteddate": "SubmittedDate"
|
23
|
+
}
|
24
|
+
|
25
|
+
# Map sort order to arxiv.SortOrder
|
26
|
+
SORT_ORDER = {
|
27
|
+
"ascending": "Ascending",
|
28
|
+
"descending": "Descending"
|
29
|
+
}
|
30
|
+
|
31
|
+
class ArxivTools:
    """Search arXiv and return paper metadata as plain dictionaries.

    Public methods never propagate errors raised while querying arXiv: the
    error is logged and returned as ``{"error": message}`` instead.
    """

    def __init__(self):
        """Initialize ArxivTools and check for arxiv package.

        Raises:
            ImportError: If the ``arxiv`` package is not installed.
        """
        self._check_arxiv()

    def _check_arxiv(self):
        """Verify the ``arxiv`` package is installed and import it.

        Raises:
            ImportError: If the ``arxiv`` package cannot be found.
        """
        if util.find_spec("arxiv") is None:
            raise ImportError("arxiv package is not available. Please install it using: pip install arxiv")
        # Kept for backward compatibility: binds ``arxiv`` as a module global.
        global arxiv
        import arxiv

    def search(
        self,
        query: str,
        max_results: int = 10,
        sort_by: str = "relevance",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Search arXiv for papers matching the query.

        Args:
            query: Search query (e.g., "quantum computing", "author:Einstein")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate");
                matched case-insensitively
            sort_order: Sort order ("ascending" or "descending");
                matched case-insensitively
            include_fields: List of fields to include in results. If None, includes all.
                Available fields: ["title", "authors", "summary", "comment", "journal_ref",
                "doi", "primary_category", "categories", "links"]

        Returns:
            List[Dict] or Dict: List of papers or error dict
        """
        try:
            import arxiv

            # Reject unknown sort options with a readable message instead of
            # letting a bare KeyError repr end up in the error dict.
            sort_key = sort_by.lower()
            if sort_key not in SORT_CRITERIA:
                raise ValueError(
                    f"invalid sort_by '{sort_by}'; expected one of "
                    "'relevance', 'lastUpdatedDate', 'submittedDate'"
                )
            order_key = sort_order.lower()
            if order_key not in SORT_ORDER:
                raise ValueError(
                    f"invalid sort_order '{sort_order}'; expected 'ascending' or 'descending'"
                )

            # Map the validated names onto the arxiv enums
            sort_by_enum = getattr(arxiv.SortCriterion, SORT_CRITERIA[sort_key])
            sort_order_enum = getattr(arxiv.SortOrder, SORT_ORDER[order_key])

            client = arxiv.Client()
            search = arxiv.Search(
                query=query,
                max_results=max_results,
                sort_by=sort_by_enum,
                sort_order=sort_order_enum
            )

            # Execute the search and convert each result to a dict
            return [
                self._result_to_dict(result, include_fields)
                for result in client.results(search)
            ]
        except Exception as e:
            error_msg = f"Error searching arXiv: {str(e)}"
            logging.error(error_msg)
            return {"error": error_msg}

    def get_paper(
        self,
        paper_id: str,
        include_fields: Optional[List[str]] = None
    ) -> Union[Dict[str, Any], Dict[str, str]]:
        """
        Get details of a specific paper by its arXiv ID.

        Args:
            paper_id: arXiv paper ID (e.g., "2401.00123")
            include_fields: List of fields to include in results. If None, includes all.
                Available fields: ["title", "authors", "summary", "comment", "journal_ref",
                "doi", "primary_category", "categories", "links"]

        Returns:
            Dict: Paper details or error dict
        """
        try:
            import arxiv

            client = arxiv.Client()
            search = arxiv.Search(id_list=[paper_id])
            results = list(client.results(search))

            if not results:
                return {"error": f"Paper with ID {paper_id} not found"}

            return self._result_to_dict(results[0], include_fields)
        except Exception as e:
            error_msg = f"Error getting paper {paper_id}: {str(e)}"
            logging.error(error_msg)
            return {"error": error_msg}

    def get_papers_by_author(
        self,
        author: str,
        max_results: int = 10,
        sort_by: str = "submittedDate",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Get papers by a specific author.

        Builds an ``au:"<author>"`` query and delegates to :meth:`search`.

        Args:
            author: Author name (e.g., "Einstein")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
            sort_order: Sort order ("ascending" or "descending")
            include_fields: List of fields to include in results

        Returns:
            List[Dict] or Dict: List of papers or error dict
        """
        query = f'au:"{author}"'
        return self.search(query, max_results, sort_by, sort_order, include_fields)

    def get_papers_by_category(
        self,
        category: str,
        max_results: int = 10,
        sort_by: str = "submittedDate",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Get papers from a specific category.

        Builds a ``cat:<category>`` query and delegates to :meth:`search`.

        Args:
            category: arXiv category (e.g., "cs.AI", "physics.gen-ph")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
            sort_order: Sort order ("ascending" or "descending")
            include_fields: List of fields to include in results

        Returns:
            List[Dict] or Dict: List of papers or error dict
        """
        query = f'cat:{category}'
        return self.search(query, max_results, sort_by, sort_order, include_fields)

    def _result_to_dict(
        self,
        result: Any,
        include_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Convert arxiv.Result to dictionary with selected fields.

        ``arxiv_id``, ``updated`` and ``published`` are always present. The
        ``"links"`` field expands to ``pdf_url`` and ``abstract_url``.

        Args:
            result: Object exposing the arxiv.Result attributes read below.
            include_fields: Optional fields to copy; None selects all of them.

        Returns:
            Dict with the always-present keys plus the requested fields.
        """
        if include_fields is None:
            include_fields = [
                "title", "authors", "summary", "comment", "journal_ref",
                "doi", "primary_category", "categories", "links"
            ]

        # entry_id is a URL such as "http://arxiv.org/abs/2401.00123v1".
        # Old-style identifiers contain a slash ("hep-th/9901001v1"), so take
        # everything after "/abs/" rather than only the last path segment,
        # which would drop the archive prefix.
        entry_id = result.entry_id
        if "/abs/" in entry_id:
            arxiv_id = entry_id.split("/abs/")[-1]
        else:
            arxiv_id = entry_id.split("/")[-1]

        paper = {
            "arxiv_id": arxiv_id,
            "updated": result.updated.isoformat() if result.updated else None,
            "published": result.published.isoformat() if result.published else None,
        }

        # Copy requested fields in a fixed order so the output layout is stable
        for field in ("title", "authors", "summary", "comment", "journal_ref",
                      "doi", "primary_category", "categories"):
            if field not in include_fields:
                continue
            if field == "authors":
                # Author objects are stringified for JSON-friendly output
                paper["authors"] = [str(author) for author in result.authors]
            else:
                paper[field] = getattr(result, field)

        if "links" in include_fields:
            paper["pdf_url"] = result.pdf_url
            paper["abstract_url"] = f"https://arxiv.org/abs/{paper['arxiv_id']}"

        return paper
|
228
|
+
|
229
|
+
# Shared module-level instance; its bound methods are re-exported below so
# callers can use plain functions instead of instantiating ArxivTools.
_arxiv_tools = ArxivTools()
search_arxiv, get_arxiv_paper = _arxiv_tools.search, _arxiv_tools.get_paper
get_papers_by_author, get_papers_by_category = (
    _arxiv_tools.get_papers_by_author,
    _arxiv_tools.get_papers_by_category,
)
|
235
|
+
|
236
|
+
if __name__ == "__main__":
|
237
|
+
# Example usage
|
238
|
+
print("\n==================================================")
|
239
|
+
print("ArxivTools Demonstration")
|
240
|
+
print("==================================================\n")
|
241
|
+
|
242
|
+
# 1. Search for papers
|
243
|
+
print("1. Searching for Papers")
|
244
|
+
print("------------------------------")
|
245
|
+
query = "quantum computing"
|
246
|
+
papers = search_arxiv(query, max_results=3)
|
247
|
+
print(f"Papers about {query}:")
|
248
|
+
if isinstance(papers, list):
|
249
|
+
print(json.dumps(papers, indent=2))
|
250
|
+
else:
|
251
|
+
print(papers) # Show error
|
252
|
+
print()
|
253
|
+
|
254
|
+
# 2. Get specific paper
|
255
|
+
print("2. Getting Specific Paper")
|
256
|
+
print("------------------------------")
|
257
|
+
if isinstance(papers, list) and papers:
|
258
|
+
paper_id = papers[0]["arxiv_id"]
|
259
|
+
paper = get_arxiv_paper(paper_id)
|
260
|
+
print(f"Paper {paper_id}:")
|
261
|
+
if "error" not in paper:
|
262
|
+
print(json.dumps(paper, indent=2))
|
263
|
+
else:
|
264
|
+
print(paper) # Show error
|
265
|
+
print()
|
266
|
+
|
267
|
+
# 3. Get papers by author
|
268
|
+
print("3. Getting Papers by Author")
|
269
|
+
print("------------------------------")
|
270
|
+
author = "Yoshua Bengio"
|
271
|
+
author_papers = get_papers_by_author(author, max_results=3)
|
272
|
+
print(f"Papers by {author}:")
|
273
|
+
if isinstance(author_papers, list):
|
274
|
+
print(json.dumps(author_papers, indent=2))
|
275
|
+
else:
|
276
|
+
print(author_papers) # Show error
|
277
|
+
print()
|
278
|
+
|
279
|
+
# 4. Get papers by category
|
280
|
+
print("4. Getting Papers by Category")
|
281
|
+
print("------------------------------")
|
282
|
+
category = "cs.AI"
|
283
|
+
category_papers = get_papers_by_category(category, max_results=3)
|
284
|
+
print(f"Papers in category {category}:")
|
285
|
+
if isinstance(category_papers, list):
|
286
|
+
print(json.dumps(category_papers, indent=2))
|
287
|
+
else:
|
288
|
+
print(category_papers) # Show error
|
289
|
+
|
290
|
+
print("\n==================================================")
|
291
|
+
print("Demonstration Complete")
|
292
|
+
print("==================================================")
|