praisonaiagents 0.0.22__tar.gz → 0.0.24__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (48) hide show
  1. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/PKG-INFO +1 -1
  2. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/agent/agent.py +22 -33
  3. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/agents/agents.py +18 -4
  4. praisonaiagents-0.0.24/praisonaiagents/tools/__init__.py +167 -0
  5. praisonaiagents-0.0.24/praisonaiagents/tools/arxiv_tools.py +292 -0
  6. praisonaiagents-0.0.24/praisonaiagents/tools/calculator_tools.py +278 -0
  7. praisonaiagents-0.0.24/praisonaiagents/tools/csv_tools.py +266 -0
  8. praisonaiagents-0.0.24/praisonaiagents/tools/duckdb_tools.py +268 -0
  9. praisonaiagents-0.0.24/praisonaiagents/tools/duckduckgo_tools.py +52 -0
  10. praisonaiagents-0.0.24/praisonaiagents/tools/excel_tools.py +310 -0
  11. praisonaiagents-0.0.24/praisonaiagents/tools/file_tools.py +274 -0
  12. praisonaiagents-0.0.24/praisonaiagents/tools/json_tools.py +515 -0
  13. praisonaiagents-0.0.24/praisonaiagents/tools/newspaper_tools.py +354 -0
  14. praisonaiagents-0.0.24/praisonaiagents/tools/pandas_tools.py +326 -0
  15. praisonaiagents-0.0.24/praisonaiagents/tools/python_tools.py +423 -0
  16. praisonaiagents-0.0.24/praisonaiagents/tools/shell_tools.py +278 -0
  17. praisonaiagents-0.0.24/praisonaiagents/tools/spider_tools.py +431 -0
  18. praisonaiagents-0.0.24/praisonaiagents/tools/test.py +56 -0
  19. praisonaiagents-0.0.24/praisonaiagents/tools/tools.py +9 -0
  20. praisonaiagents-0.0.24/praisonaiagents/tools/wikipedia_tools.py +272 -0
  21. praisonaiagents-0.0.24/praisonaiagents/tools/xml_tools.py +498 -0
  22. praisonaiagents-0.0.24/praisonaiagents/tools/yaml_tools.py +417 -0
  23. praisonaiagents-0.0.24/praisonaiagents/tools/yfinance_tools.py +213 -0
  24. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/PKG-INFO +1 -1
  25. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/SOURCES.txt +19 -1
  26. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/pyproject.toml +1 -1
  27. praisonaiagents-0.0.22/praisonaiagents/tools/__init__.py +0 -4
  28. praisonaiagents-0.0.22/praisonaiagents/tools/tools.py +0 -40
  29. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/__init__.py +0 -0
  30. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/agent/__init__.py +0 -0
  31. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/agents/__init__.py +0 -0
  32. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/__init__.py +0 -0
  33. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/agent/__init__.py +0 -0
  34. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/agent/agent.py +0 -0
  35. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/agents/__init__.py +0 -0
  36. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/agents/agents.py +0 -0
  37. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/main.py +0 -0
  38. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/task/__init__.py +0 -0
  39. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/build/lib/praisonaiagents/task/task.py +0 -0
  40. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/main.py +0 -0
  41. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/process/__init__.py +0 -0
  42. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/process/process.py +0 -0
  43. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/task/__init__.py +0 -0
  44. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents/task/task.py +0 -0
  45. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/dependency_links.txt +0 -0
  46. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/requires.txt +0 -0
  47. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/praisonaiagents.egg-info/top_level.txt +0 -0
  48. {praisonaiagents-0.0.22 → praisonaiagents-0.0.24}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: praisonaiagents
3
- Version: 0.0.22
3
+ Version: 0.0.24
4
4
  Summary: Praison AI agents for completing complex tasks with Self Reflection Agents
5
5
  Author: Mervin Praison
6
6
  Requires-Dist: pydantic
@@ -394,7 +394,7 @@ Your Goal: {self.goal}
394
394
  display_error(f"Error in chat completion: {e}")
395
395
  return None
396
396
 
397
- def chat(self, prompt, temperature=0.2, tools=None, output_json=None):
397
+ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None):
398
398
  if self.use_system_prompt:
399
399
  system_prompt = f"""{self.backstory}\n
400
400
  Your Role: {self.role}\n
@@ -402,6 +402,8 @@ Your Goal: {self.goal}
402
402
  """
403
403
  if output_json:
404
404
  system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_json.schema_json()}"
405
+ elif output_pydantic:
406
+ system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_pydantic.schema_json()}"
405
407
  else:
406
408
  system_prompt = None
407
409
 
@@ -410,9 +412,9 @@ Your Goal: {self.goal}
410
412
  messages.append({"role": "system", "content": system_prompt})
411
413
  messages.extend(self.chat_history)
412
414
 
413
- # Modify prompt if output_json is specified
415
+ # Modify prompt if output_json or output_pydantic is specified
414
416
  original_prompt = prompt
415
- if output_json:
417
+ if output_json or output_pydantic:
416
418
  if isinstance(prompt, str):
417
419
  prompt += "\nReturn ONLY a valid JSON object. No other text or explanation."
418
420
  elif isinstance(prompt, list):
@@ -487,23 +489,15 @@ Your Goal: {self.goal}
487
489
  return None
488
490
  response_text = response.choices[0].message.content.strip()
489
491
 
490
- # Handle output_json if specified
491
- if output_json:
492
- try:
493
- # Clean the response text to get only JSON
494
- cleaned_json = self.clean_json_output(response_text)
495
- # Parse into Pydantic model
496
- parsed_model = output_json.model_validate_json(cleaned_json)
497
- # Add to chat history and return
498
- self.chat_history.append({"role": "user", "content": original_prompt})
499
- self.chat_history.append({"role": "assistant", "content": response_text})
500
- if self.verbose:
501
- display_interaction(original_prompt, response_text, markdown=self.markdown,
502
- generation_time=time.time() - start_time, console=self.console)
503
- return parsed_model
504
- except Exception as e:
505
- display_error(f"Failed to parse response as {output_json.__name__}: {e}")
506
- return None
492
+ # Handle output_json or output_pydantic if specified
493
+ if output_json or output_pydantic:
494
+ # Add to chat history and return raw response
495
+ self.chat_history.append({"role": "user", "content": original_prompt})
496
+ self.chat_history.append({"role": "assistant", "content": response_text})
497
+ if self.verbose:
498
+ display_interaction(original_prompt, response_text, markdown=self.markdown,
499
+ generation_time=time.time() - start_time, console=self.console)
500
+ return response_text
507
501
 
508
502
  if not self.self_reflect:
509
503
  self.chat_history.append({"role": "user", "content": original_prompt})
@@ -585,19 +579,21 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
585
579
  cleaned = cleaned[:-3].strip()
586
580
  return cleaned
587
581
 
588
- async def achat(self, prompt, temperature=0.2, tools=None, output_json=None):
582
+ async def achat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None):
589
583
  """Async version of chat method"""
590
584
  try:
591
585
  # Build system prompt
592
586
  system_prompt = self.system_prompt
593
587
  if output_json:
594
588
  system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_json.schema_json()}"
589
+ elif output_pydantic:
590
+ system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_pydantic.schema_json()}"
595
591
 
596
592
  # Build messages
597
593
  if isinstance(prompt, str):
598
594
  messages = [
599
595
  {"role": "system", "content": system_prompt},
600
- {"role": "user", "content": prompt + ("\nReturn ONLY a valid JSON object. No other text or explanation." if output_json else "")}
596
+ {"role": "user", "content": prompt + ("\nReturn ONLY a valid JSON object. No other text or explanation." if (output_json or output_pydantic) else "")}
601
597
  ]
602
598
  else:
603
599
  # For multimodal prompts
@@ -605,7 +601,7 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
605
601
  {"role": "system", "content": system_prompt},
606
602
  {"role": "user", "content": prompt}
607
603
  ]
608
- if output_json:
604
+ if output_json or output_pydantic:
609
605
  # Add JSON instruction to text content
610
606
  for item in messages[-1]["content"]:
611
607
  if item["type"] == "text":
@@ -639,22 +635,15 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
639
635
  tools=formatted_tools
640
636
  )
641
637
  return await self._achat_completion(response, tools)
642
- elif output_json:
638
+ elif output_json or output_pydantic:
643
639
  response = await async_client.chat.completions.create(
644
640
  model=self.llm,
645
641
  messages=messages,
646
642
  temperature=temperature,
647
643
  response_format={"type": "json_object"}
648
644
  )
649
- result = response.choices[0].message.content
650
- # Clean and parse the JSON response
651
- cleaned_json = self.clean_json_output(result)
652
- try:
653
- parsed = json.loads(cleaned_json)
654
- return output_json(**parsed)
655
- except Exception as e:
656
- display_error(f"Error parsing JSON response: {e}")
657
- return None
645
+ # Return the raw response
646
+ return response.choices[0].message.content
658
647
  else:
659
648
  response = await async_client.chat.completions.create(
660
649
  model=self.llm,
@@ -195,10 +195,17 @@ Here are the results of previous tasks that might be useful:\n
195
195
 
196
196
  agent_output = await executor_agent.achat(
197
197
  _get_multimodal_message(task_prompt, task.images),
198
- tools=task.tools
198
+ tools=task.tools,
199
+ output_json=task.output_json,
200
+ output_pydantic=task.output_pydantic
199
201
  )
200
202
  else:
201
- agent_output = await executor_agent.achat(task_prompt, tools=task.tools)
203
+ agent_output = await executor_agent.achat(
204
+ task_prompt,
205
+ tools=task.tools,
206
+ output_json=task.output_json,
207
+ output_pydantic=task.output_pydantic
208
+ )
202
209
 
203
210
  if agent_output:
204
211
  task_output = TaskOutput(
@@ -405,10 +412,17 @@ Here are the results of previous tasks that might be useful:\n
405
412
 
406
413
  agent_output = executor_agent.chat(
407
414
  _get_multimodal_message(task_prompt, task.images),
408
- tools=task.tools
415
+ tools=task.tools,
416
+ output_json=task.output_json,
417
+ output_pydantic=task.output_pydantic
409
418
  )
410
419
  else:
411
- agent_output = executor_agent.chat(task_prompt, tools=task.tools)
420
+ agent_output = executor_agent.chat(
421
+ task_prompt,
422
+ tools=task.tools,
423
+ output_json=task.output_json,
424
+ output_pydantic=task.output_pydantic
425
+ )
412
426
 
413
427
  if agent_output:
414
428
  task_output = TaskOutput(
@@ -0,0 +1,167 @@
1
+ """Tools package for PraisonAI Agents"""
2
+ from importlib import import_module
3
+ from typing import Any
4
+
5
# Tool names grouped by implementing module. Each group is
# (relative module path, class name or None, exported names). Group order and
# name order are significant: TOOL_MAPPINGS preserves them.
_TOOL_GROUPS = (
    # Direct functions
    ('.duckduckgo_tools', None, (
        'internet_search', 'duckduckgo',
    )),
    # Class methods from YFinance
    ('.yfinance_tools', 'YFinanceTools', (
        'get_stock_price', 'get_stock_info', 'get_historical_data', 'yfinance',
    )),
    # File Tools
    ('.file_tools', 'FileTools', (
        'read_file', 'write_file', 'list_files', 'get_file_info',
        'copy_file', 'move_file', 'delete_file', 'file_tools',
    )),
    # CSV Tools
    ('.csv_tools', 'CSVTools', (
        'read_csv', 'write_csv', 'merge_csv', 'analyze_csv', 'split_csv',
        'csv_tools',
    )),
    # JSON Tools
    ('.json_tools', 'JSONTools', (
        'read_json', 'write_json', 'merge_json', 'validate_json',
        'analyze_json', 'transform_json', 'json_tools',
    )),
    # Excel Tools
    ('.excel_tools', 'ExcelTools', (
        'read_excel', 'write_excel', 'merge_excel', 'create_chart',
        'add_chart_to_sheet', 'excel_tools',
    )),
    # XML Tools
    ('.xml_tools', 'XMLTools', (
        'read_xml', 'write_xml', 'transform_xml', 'validate_xml',
        'xml_to_dict', 'dict_to_xml', 'xpath_query', 'xml_tools',
    )),
    # YAML Tools
    ('.yaml_tools', 'YAMLTools', (
        'read_yaml', 'write_yaml', 'merge_yaml', 'validate_yaml',
        'analyze_yaml', 'transform_yaml', 'yaml_tools',
    )),
    # Calculator Tools
    ('.calculator_tools', 'CalculatorTools', (
        'evaluate', 'solve_equation', 'convert_units', 'calculate_statistics',
        'calculate_financial', 'calculator_tools',
    )),
    # Python Tools
    ('.python_tools', 'PythonTools', (
        'execute_code', 'analyze_code', 'format_code', 'lint_code',
        'disassemble_code', 'python_tools',
    )),
    # Pandas Tools
    ('.pandas_tools', 'PandasTools', (
        'filter_data', 'get_summary', 'group_by', 'pivot_table', 'pandas_tools',
    )),
    # Wikipedia Tools
    ('.wikipedia_tools', 'WikipediaTools', (
        'search', 'get_wikipedia_summary', 'get_wikipedia_page',
        'get_random_wikipedia', 'set_wikipedia_language', 'wikipedia_tools',
    )),
    # Newspaper Tools
    ('.newspaper_tools', 'NewspaperTools', (
        'get_article', 'get_news_sources', 'get_articles_from_source',
        'get_trending_topics', 'newspaper_tools',
    )),
    # arXiv Tools
    ('.arxiv_tools', 'ArxivTools', (
        'search_arxiv', 'get_arxiv_paper', 'get_papers_by_author',
        'get_papers_by_category', 'arxiv_tools',
    )),
    # Spider Tools
    ('.spider_tools', 'SpiderTools', (
        'scrape_page', 'extract_links', 'crawl', 'extract_text', 'spider_tools',
    )),
    # DuckDB Tools
    ('.duckdb_tools', 'DuckDBTools', (
        'query', 'create_table', 'load_data', 'export_data', 'get_table_info',
        'analyze_data', 'duckdb_tools',
    )),
    # Shell Tools
    ('.shell_tools', 'ShellTools', (
        'execute_command', 'list_processes', 'kill_process', 'get_system_info',
        'shell_tools',
    )),
)

# Map of function names to their module and class (if any)
TOOL_MAPPINGS = {
    tool_name: (module_path, class_name)
    for module_path, class_name, tool_names in _TOOL_GROUPS
    for tool_name in tool_names
}

_instances = {}  # Cache for class instances
141
+
142
def __getattr__(name: str) -> Any:
    """Lazily resolve a tool name listed in ``TOOL_MAPPINGS``.

    This is the module-level attribute hook: it runs only when ``name`` is
    not already an attribute of the package.

    Resolution order:
      1. Entries with no class name resolve to the tool module itself (for
         the module aliases) or to the attribute called ``name`` in it.
      2. Entries with a class name resolve to the attribute called ``name``
         on a cached instance of that class.
      3. If the instance has no such attribute, fall back to the module
         alias and then to a module-level attribute of the same name. Some
         mapped names exist only as module-level aliases (for example
         ``search_arxiv`` in ``arxiv_tools``), and the ``*_tools`` aliases
         are not methods of their class.

    Args:
        name: Attribute requested on the package.

    Returns:
        The tool module, a module-level attribute, or an attribute of the
        cached tool-class instance.

    Raises:
        AttributeError: If ``name`` is not mapped or cannot be resolved.
    """
    if name not in TOOL_MAPPINGS:
        raise AttributeError(f"module '{__package__}' has no attribute '{name}'")

    module_path, class_name = TOOL_MAPPINGS[name]
    # import_module is idempotent: an already-loaded module comes back from
    # sys.modules, so importing up front costs nothing on later lookups.
    module = import_module(module_path, __package__)

    # Names that stand for a whole tool module rather than a single callable.
    module_aliases = (
        'duckduckgo', 'file_tools', 'pandas_tools', 'wikipedia_tools',
        'newspaper_tools', 'arxiv_tools', 'spider_tools', 'duckdb_tools',
        'csv_tools', 'json_tools', 'excel_tools', 'xml_tools', 'yaml_tools',
        'calculator_tools', 'python_tools', 'shell_tools',
    )

    if class_name is None:
        # Direct function import
        if name in module_aliases:
            return module
        return getattr(module, name)

    # Class method import: one shared instance per tool class.
    if class_name not in _instances:
        _instances[class_name] = getattr(module, class_name)()
    instance = _instances[class_name]

    try:
        # Attribute access on the instance yields a bound method.
        return getattr(instance, name)
    except AttributeError:
        pass

    # The mapped class does not provide this name; try the module instead.
    if name in module_aliases:
        return module
    try:
        return getattr(module, name)
    except AttributeError:
        raise AttributeError(
            f"module '{__package__}' has no attribute '{name}'"
        ) from None
166
+
167
# Public API: every lazily resolvable tool name, in mapping order.
__all__ = [*TOOL_MAPPINGS]
@@ -0,0 +1,292 @@
1
+ """Tools for searching and retrieving papers from arXiv.
2
+
3
+ Usage:
4
+ from praisonaiagents.tools import arxiv_tools
5
+ papers = arxiv_tools.search_arxiv("quantum computing")
6
+ paper = arxiv_tools.get_arxiv_paper("2401.00123")
7
+
8
+ or
9
+ from praisonaiagents.tools import search_arxiv, get_arxiv_paper
10
+ papers = search_arxiv("quantum computing")
11
+ """
12
+
13
+ import logging
14
+ from typing import List, Dict, Union, Optional, Any
15
+ from importlib import util
16
+ import json
17
+
18
# Accepted ``sort_by`` values (lower-cased) -> attribute name on arxiv.SortCriterion
SORT_CRITERIA = {
    "relevance": "Relevance",
    "lastupdateddate": "LastUpdatedDate",
    "submitteddate": "SubmittedDate"
}

# Accepted ``sort_order`` values (lower-cased) -> attribute name on arxiv.SortOrder
SORT_ORDER = {
    "ascending": "Ascending",
    "descending": "Descending"
}


class ArxivTools:
    """Tools for searching and retrieving papers from arXiv."""

    def __init__(self):
        """Initialize ArxivTools and check for arxiv package."""
        self._check_arxiv()

    def _check_arxiv(self):
        """Check if arxiv package is installed.

        Raises:
            ImportError: If the ``arxiv`` package cannot be located.
        """
        if util.find_spec("arxiv") is None:
            raise ImportError("arxiv package is not available. Please install it using: pip install arxiv")
        # Also bind the package to a module-level name.
        global arxiv
        import arxiv

    @staticmethod
    def _resolve_sort_option(value: Any, options: Dict[str, str], label: str) -> str:
        """Translate a user-supplied sort option into an arxiv enum attribute name.

        Args:
            value: Option as given by the caller; matched case-insensitively.
            options: Table of accepted lower-case values to attribute names.
            label: Parameter name to use in the error message.

        Returns:
            The attribute name stored in ``options`` for ``value``.

        Raises:
            ValueError: If ``value`` is not a string or not a key of
                ``options``; the message lists the accepted values.
        """
        key = value.lower() if isinstance(value, str) else None
        if key not in options:
            raise ValueError(
                f"Invalid {label} {value!r}; expected one of: {', '.join(options)}"
            )
        return options[key]

    def search(
        self,
        query: str,
        max_results: int = 10,
        sort_by: str = "relevance",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Search arXiv for papers matching the query.

        Args:
            query: Search query (e.g., "quantum computing", "author:Einstein")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate");
                matched case-insensitively
            sort_order: Sort order ("ascending" or "descending");
                matched case-insensitively
            include_fields: List of fields to include in results. If None, includes all.
                Available fields: ["title", "authors", "summary", "comment", "journal_ref",
                "doi", "primary_category", "categories", "links"]

        Returns:
            List[Dict] or Dict: List of papers or error dict. Failures,
            including an unrecognised ``sort_by`` / ``sort_order``, are
            logged and reported as ``{"error": message}`` rather than raised.
        """
        try:
            # Validate before touching the arxiv package so a bad option is
            # reported with the list of accepted values, not a bare key.
            criterion_name = self._resolve_sort_option(sort_by, SORT_CRITERIA, "sort_by")
            order_name = self._resolve_sort_option(sort_order, SORT_ORDER, "sort_order")

            import arxiv

            # Configure search client
            client = arxiv.Client()

            # Build search query
            search_request = arxiv.Search(
                query=query,
                max_results=max_results,
                sort_by=getattr(arxiv.SortCriterion, criterion_name),
                sort_order=getattr(arxiv.SortOrder, order_name)
            )

            # Execute search, converting each result to a dict with selected fields
            return [
                self._result_to_dict(result, include_fields)
                for result in client.results(search_request)
            ]
        except Exception as e:
            # Tool boundary: callers receive an error dict instead of an exception.
            error_msg = f"Error searching arXiv: {str(e)}"
            logging.error(error_msg)
            return {"error": error_msg}

    def get_paper(
        self,
        paper_id: str,
        include_fields: Optional[List[str]] = None
    ) -> Union[Dict[str, Any], Dict[str, str]]:
        """
        Get details of a specific paper by its arXiv ID.

        Args:
            paper_id: arXiv paper ID (e.g., "2401.00123")
            include_fields: List of fields to include in results. If None, includes all.
                Available fields: ["title", "authors", "summary", "comment", "journal_ref",
                "doi", "primary_category", "categories", "links"]

        Returns:
            Dict: Paper details or error dict
        """
        try:
            import arxiv

            # Configure client
            client = arxiv.Client()

            # Get paper by ID
            matches = list(client.results(arxiv.Search(id_list=[paper_id])))

            if not matches:
                return {"error": f"Paper with ID {paper_id} not found"}

            # Convert to dict with selected fields
            return self._result_to_dict(matches[0], include_fields)
        except Exception as e:
            # Tool boundary: callers receive an error dict instead of an exception.
            error_msg = f"Error getting paper {paper_id}: {str(e)}"
            logging.error(error_msg)
            return {"error": error_msg}

    def get_papers_by_author(
        self,
        author: str,
        max_results: int = 10,
        sort_by: str = "submittedDate",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Get papers by a specific author.

        Delegates to ``search`` with the query ``au:"<author>"``.

        Args:
            author: Author name (e.g., "Einstein")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
            sort_order: Sort order ("ascending" or "descending")
            include_fields: List of fields to include in results

        Returns:
            List[Dict] or Dict: List of papers or error dict
        """
        query = f'au:"{author}"'
        return self.search(query, max_results, sort_by, sort_order, include_fields)

    def get_papers_by_category(
        self,
        category: str,
        max_results: int = 10,
        sort_by: str = "submittedDate",
        sort_order: str = "descending",
        include_fields: Optional[List[str]] = None
    ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
        """
        Get papers from a specific category.

        Delegates to ``search`` with the query ``cat:<category>``.

        Args:
            category: arXiv category (e.g., "cs.AI", "physics.gen-ph")
            max_results: Maximum number of results to return
            sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
            sort_order: Sort order ("ascending" or "descending")
            include_fields: List of fields to include in results

        Returns:
            List[Dict] or Dict: List of papers or error dict
        """
        query = f'cat:{category}'
        return self.search(query, max_results, sort_by, sort_order, include_fields)

    def _result_to_dict(
        self,
        result: Any,
        include_fields: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Convert arxiv.Result to dictionary with selected fields.

        Args:
            result: Object exposing the attributes read below (``entry_id``,
                ``updated``, ``published`` and one attribute per requested field).
            include_fields: Fields to copy; ``None`` selects every supported
                field. ``"links"`` adds ``pdf_url`` and ``abstract_url``.

        Returns:
            Dict with ``arxiv_id``, ``updated`` and ``published`` always
            present, followed by the requested fields.
        """
        # Default fields to include
        if include_fields is None:
            include_fields = [
                "title", "authors", "summary", "comment", "journal_ref",
                "doi", "primary_category", "categories", "links"
            ]

        # Always include these basic fields
        paper = {
            "arxiv_id": result.entry_id.split("/")[-1],
            "updated": result.updated.isoformat() if result.updated else None,
            "published": result.published.isoformat() if result.published else None,
        }

        # Add requested fields in a fixed order; attributes of fields that
        # were not requested are never read.
        for field in (
            "title", "authors", "summary", "comment", "journal_ref",
            "doi", "primary_category", "categories",
        ):
            if field not in include_fields:
                continue
            if field == "authors":
                paper[field] = [str(author) for author in result.authors]
            else:
                paper[field] = getattr(result, field)

        if "links" in include_fields:
            paper["pdf_url"] = result.pdf_url
            paper["abstract_url"] = f"https://arxiv.org/abs/{paper['arxiv_id']}"

        return paper
228
+
229
# Shared module-level instance; its bound methods are re-exported below so
# they can be called as plain functions. Constructing it runs the arxiv
# availability check in ArxivTools.__init__.
_arxiv_tools = ArxivTools()
search_arxiv, get_arxiv_paper = _arxiv_tools.search, _arxiv_tools.get_paper
get_papers_by_author, get_papers_by_category = (
    _arxiv_tools.get_papers_by_author,
    _arxiv_tools.get_papers_by_category,
)
235
+
236
if __name__ == "__main__":
    # Demonstration script: exercises each module-level helper once and
    # prints either the JSON-formatted result or the returned error dict.
    banner = "=================================================="
    rule = "------------------------------"

    print("\n" + banner)
    print("ArxivTools Demonstration")
    print(banner + "\n")

    # 1. Search for papers
    print("1. Searching for Papers")
    print(rule)
    query = "quantum computing"
    papers = search_arxiv(query, max_results=3)
    print(f"Papers about {query}:")
    # A list means success; anything else is the error dict.
    print(json.dumps(papers, indent=2) if isinstance(papers, list) else papers)
    print()

    # 2. Get specific paper (only when the search returned at least one hit)
    print("2. Getting Specific Paper")
    print(rule)
    if isinstance(papers, list) and papers:
        paper_id = papers[0]["arxiv_id"]
        paper = get_arxiv_paper(paper_id)
        print(f"Paper {paper_id}:")
        print(paper if "error" in paper else json.dumps(paper, indent=2))
    print()

    # 3. Get papers by author
    print("3. Getting Papers by Author")
    print(rule)
    author = "Yoshua Bengio"
    author_papers = get_papers_by_author(author, max_results=3)
    print(f"Papers by {author}:")
    print(
        json.dumps(author_papers, indent=2)
        if isinstance(author_papers, list)
        else author_papers
    )
    print()

    # 4. Get papers by category
    print("4. Getting Papers by Category")
    print(rule)
    category = "cs.AI"
    category_papers = get_papers_by_category(category, max_results=3)
    print(f"Papers in category {category}:")
    print(
        json.dumps(category_papers, indent=2)
        if isinstance(category_papers, list)
        else category_papers
    )

    print("\n" + banner)
    print("Demonstration Complete")
    print(banner)