praisonaiagents 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -394,7 +394,7 @@ Your Goal: {self.goal}
394
394
  display_error(f"Error in chat completion: {e}")
395
395
  return None
396
396
 
397
- def chat(self, prompt, temperature=0.2, tools=None, output_json=None):
397
+ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None):
398
398
  if self.use_system_prompt:
399
399
  system_prompt = f"""{self.backstory}\n
400
400
  Your Role: {self.role}\n
@@ -402,6 +402,8 @@ Your Goal: {self.goal}
402
402
  """
403
403
  if output_json:
404
404
  system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_json.schema_json()}"
405
+ elif output_pydantic:
406
+ system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_pydantic.schema_json()}"
405
407
  else:
406
408
  system_prompt = None
407
409
 
@@ -410,9 +412,9 @@ Your Goal: {self.goal}
410
412
  messages.append({"role": "system", "content": system_prompt})
411
413
  messages.extend(self.chat_history)
412
414
 
413
- # Modify prompt if output_json is specified
415
+ # Modify prompt if output_json or output_pydantic is specified
414
416
  original_prompt = prompt
415
- if output_json:
417
+ if output_json or output_pydantic:
416
418
  if isinstance(prompt, str):
417
419
  prompt += "\nReturn ONLY a valid JSON object. No other text or explanation."
418
420
  elif isinstance(prompt, list):
@@ -487,23 +489,15 @@ Your Goal: {self.goal}
487
489
  return None
488
490
  response_text = response.choices[0].message.content.strip()
489
491
 
490
- # Handle output_json if specified
491
- if output_json:
492
- try:
493
- # Clean the response text to get only JSON
494
- cleaned_json = self.clean_json_output(response_text)
495
- # Parse into Pydantic model
496
- parsed_model = output_json.model_validate_json(cleaned_json)
497
- # Add to chat history and return
498
- self.chat_history.append({"role": "user", "content": original_prompt})
499
- self.chat_history.append({"role": "assistant", "content": response_text})
500
- if self.verbose:
501
- display_interaction(original_prompt, response_text, markdown=self.markdown,
502
- generation_time=time.time() - start_time, console=self.console)
503
- return parsed_model
504
- except Exception as e:
505
- display_error(f"Failed to parse response as {output_json.__name__}: {e}")
506
- return None
492
+ # Handle output_json or output_pydantic if specified
493
+ if output_json or output_pydantic:
494
+ # Add to chat history and return raw response
495
+ self.chat_history.append({"role": "user", "content": original_prompt})
496
+ self.chat_history.append({"role": "assistant", "content": response_text})
497
+ if self.verbose:
498
+ display_interaction(original_prompt, response_text, markdown=self.markdown,
499
+ generation_time=time.time() - start_time, console=self.console)
500
+ return response_text
507
501
 
508
502
  if not self.self_reflect:
509
503
  self.chat_history.append({"role": "user", "content": original_prompt})
@@ -585,19 +579,21 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
585
579
  cleaned = cleaned[:-3].strip()
586
580
  return cleaned
587
581
 
588
- async def achat(self, prompt, temperature=0.2, tools=None, output_json=None):
582
+ async def achat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pydantic=None):
589
583
  """Async version of chat method"""
590
584
  try:
591
585
  # Build system prompt
592
586
  system_prompt = self.system_prompt
593
587
  if output_json:
594
588
  system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_json.schema_json()}"
589
+ elif output_pydantic:
590
+ system_prompt += f"\nReturn ONLY a JSON object that matches this Pydantic model: {output_pydantic.schema_json()}"
595
591
 
596
592
  # Build messages
597
593
  if isinstance(prompt, str):
598
594
  messages = [
599
595
  {"role": "system", "content": system_prompt},
600
- {"role": "user", "content": prompt + ("\nReturn ONLY a valid JSON object. No other text or explanation." if output_json else "")}
596
+ {"role": "user", "content": prompt + ("\nReturn ONLY a valid JSON object. No other text or explanation." if (output_json or output_pydantic) else "")}
601
597
  ]
602
598
  else:
603
599
  # For multimodal prompts
@@ -605,7 +601,7 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
605
601
  {"role": "system", "content": system_prompt},
606
602
  {"role": "user", "content": prompt}
607
603
  ]
608
- if output_json:
604
+ if output_json or output_pydantic:
609
605
  # Add JSON instruction to text content
610
606
  for item in messages[-1]["content"]:
611
607
  if item["type"] == "text":
@@ -639,22 +635,15 @@ Output MUST be JSON with 'reflection' and 'satisfactory'.
639
635
  tools=formatted_tools
640
636
  )
641
637
  return await self._achat_completion(response, tools)
642
- elif output_json:
638
+ elif output_json or output_pydantic:
643
639
  response = await async_client.chat.completions.create(
644
640
  model=self.llm,
645
641
  messages=messages,
646
642
  temperature=temperature,
647
643
  response_format={"type": "json_object"}
648
644
  )
649
- result = response.choices[0].message.content
650
- # Clean and parse the JSON response
651
- cleaned_json = self.clean_json_output(result)
652
- try:
653
- parsed = json.loads(cleaned_json)
654
- return output_json(**parsed)
655
- except Exception as e:
656
- display_error(f"Error parsing JSON response: {e}")
657
- return None
645
+ # Return the raw response
646
+ return response.choices[0].message.content
658
647
  else:
659
648
  response = await async_client.chat.completions.create(
660
649
  model=self.llm,
@@ -195,10 +195,17 @@ Here are the results of previous tasks that might be useful:\n
195
195
 
196
196
  agent_output = await executor_agent.achat(
197
197
  _get_multimodal_message(task_prompt, task.images),
198
- tools=task.tools
198
+ tools=task.tools,
199
+ output_json=task.output_json,
200
+ output_pydantic=task.output_pydantic
199
201
  )
200
202
  else:
201
- agent_output = await executor_agent.achat(task_prompt, tools=task.tools)
203
+ agent_output = await executor_agent.achat(
204
+ task_prompt,
205
+ tools=task.tools,
206
+ output_json=task.output_json,
207
+ output_pydantic=task.output_pydantic
208
+ )
202
209
 
203
210
  if agent_output:
204
211
  task_output = TaskOutput(
@@ -405,10 +412,17 @@ Here are the results of previous tasks that might be useful:\n
405
412
 
406
413
  agent_output = executor_agent.chat(
407
414
  _get_multimodal_message(task_prompt, task.images),
408
- tools=task.tools
415
+ tools=task.tools,
416
+ output_json=task.output_json,
417
+ output_pydantic=task.output_pydantic
409
418
  )
410
419
  else:
411
- agent_output = executor_agent.chat(task_prompt, tools=task.tools)
420
+ agent_output = executor_agent.chat(
421
+ task_prompt,
422
+ tools=task.tools,
423
+ output_json=task.output_json,
424
+ output_pydantic=task.output_pydantic
425
+ )
412
426
 
413
427
  if agent_output:
414
428
  task_output = TaskOutput(
@@ -1,4 +1,167 @@
1
1
  """Tools package for PraisonAI Agents"""
2
- from .tools import Tools
2
+ from importlib import import_module
3
+ from typing import Any
3
4
 
4
- __all__ = ['Tools']
5
+ # Map of function names to their module and class (if any)
6
+ TOOL_MAPPINGS = {
7
+ # Direct functions
8
+ 'internet_search': ('.duckduckgo_tools', None),
9
+ 'duckduckgo': ('.duckduckgo_tools', None),
10
+
11
+ # Class methods from YFinance
12
+ 'get_stock_price': ('.yfinance_tools', 'YFinanceTools'),
13
+ 'get_stock_info': ('.yfinance_tools', 'YFinanceTools'),
14
+ 'get_historical_data': ('.yfinance_tools', 'YFinanceTools'),
15
+ 'yfinance': ('.yfinance_tools', 'YFinanceTools'),
16
+
17
+ # File Tools
18
+ 'read_file': ('.file_tools', 'FileTools'),
19
+ 'write_file': ('.file_tools', 'FileTools'),
20
+ 'list_files': ('.file_tools', 'FileTools'),
21
+ 'get_file_info': ('.file_tools', 'FileTools'),
22
+ 'copy_file': ('.file_tools', 'FileTools'),
23
+ 'move_file': ('.file_tools', 'FileTools'),
24
+ 'delete_file': ('.file_tools', 'FileTools'),
25
+ 'file_tools': ('.file_tools', 'FileTools'),
26
+
27
+ # CSV Tools
28
+ 'read_csv': ('.csv_tools', 'CSVTools'),
29
+ 'write_csv': ('.csv_tools', 'CSVTools'),
30
+ 'merge_csv': ('.csv_tools', 'CSVTools'),
31
+ 'analyze_csv': ('.csv_tools', 'CSVTools'),
32
+ 'split_csv': ('.csv_tools', 'CSVTools'),
33
+ 'csv_tools': ('.csv_tools', 'CSVTools'),
34
+
35
+ # JSON Tools
36
+ 'read_json': ('.json_tools', 'JSONTools'),
37
+ 'write_json': ('.json_tools', 'JSONTools'),
38
+ 'merge_json': ('.json_tools', 'JSONTools'),
39
+ 'validate_json': ('.json_tools', 'JSONTools'),
40
+ 'analyze_json': ('.json_tools', 'JSONTools'),
41
+ 'transform_json': ('.json_tools', 'JSONTools'),
42
+ 'json_tools': ('.json_tools', 'JSONTools'),
43
+
44
+ # Excel Tools
45
+ 'read_excel': ('.excel_tools', 'ExcelTools'),
46
+ 'write_excel': ('.excel_tools', 'ExcelTools'),
47
+ 'merge_excel': ('.excel_tools', 'ExcelTools'),
48
+ 'create_chart': ('.excel_tools', 'ExcelTools'),
49
+ 'add_chart_to_sheet': ('.excel_tools', 'ExcelTools'),
50
+ 'excel_tools': ('.excel_tools', 'ExcelTools'),
51
+
52
+ # XML Tools
53
+ 'read_xml': ('.xml_tools', 'XMLTools'),
54
+ 'write_xml': ('.xml_tools', 'XMLTools'),
55
+ 'transform_xml': ('.xml_tools', 'XMLTools'),
56
+ 'validate_xml': ('.xml_tools', 'XMLTools'),
57
+ 'xml_to_dict': ('.xml_tools', 'XMLTools'),
58
+ 'dict_to_xml': ('.xml_tools', 'XMLTools'),
59
+ 'xpath_query': ('.xml_tools', 'XMLTools'),
60
+ 'xml_tools': ('.xml_tools', 'XMLTools'),
61
+
62
+ # YAML Tools
63
+ 'read_yaml': ('.yaml_tools', 'YAMLTools'),
64
+ 'write_yaml': ('.yaml_tools', 'YAMLTools'),
65
+ 'merge_yaml': ('.yaml_tools', 'YAMLTools'),
66
+ 'validate_yaml': ('.yaml_tools', 'YAMLTools'),
67
+ 'analyze_yaml': ('.yaml_tools', 'YAMLTools'),
68
+ 'transform_yaml': ('.yaml_tools', 'YAMLTools'),
69
+ 'yaml_tools': ('.yaml_tools', 'YAMLTools'),
70
+
71
+ # Calculator Tools
72
+ 'evaluate': ('.calculator_tools', 'CalculatorTools'),
73
+ 'solve_equation': ('.calculator_tools', 'CalculatorTools'),
74
+ 'convert_units': ('.calculator_tools', 'CalculatorTools'),
75
+ 'calculate_statistics': ('.calculator_tools', 'CalculatorTools'),
76
+ 'calculate_financial': ('.calculator_tools', 'CalculatorTools'),
77
+ 'calculator_tools': ('.calculator_tools', 'CalculatorTools'),
78
+
79
+ # Python Tools
80
+ 'execute_code': ('.python_tools', 'PythonTools'),
81
+ 'analyze_code': ('.python_tools', 'PythonTools'),
82
+ 'format_code': ('.python_tools', 'PythonTools'),
83
+ 'lint_code': ('.python_tools', 'PythonTools'),
84
+ 'disassemble_code': ('.python_tools', 'PythonTools'),
85
+ 'python_tools': ('.python_tools', 'PythonTools'),
86
+
87
+ # Pandas Tools
88
+ 'filter_data': ('.pandas_tools', 'PandasTools'),
89
+ 'get_summary': ('.pandas_tools', 'PandasTools'),
90
+ 'group_by': ('.pandas_tools', 'PandasTools'),
91
+ 'pivot_table': ('.pandas_tools', 'PandasTools'),
92
+ 'pandas_tools': ('.pandas_tools', 'PandasTools'),
93
+
94
+ # Wikipedia Tools
95
+ 'search': ('.wikipedia_tools', 'WikipediaTools'),
96
+ 'get_wikipedia_summary': ('.wikipedia_tools', 'WikipediaTools'),
97
+ 'get_wikipedia_page': ('.wikipedia_tools', 'WikipediaTools'),
98
+ 'get_random_wikipedia': ('.wikipedia_tools', 'WikipediaTools'),
99
+ 'set_wikipedia_language': ('.wikipedia_tools', 'WikipediaTools'),
100
+ 'wikipedia_tools': ('.wikipedia_tools', 'WikipediaTools'),
101
+
102
+ # Newspaper Tools
103
+ 'get_article': ('.newspaper_tools', 'NewspaperTools'),
104
+ 'get_news_sources': ('.newspaper_tools', 'NewspaperTools'),
105
+ 'get_articles_from_source': ('.newspaper_tools', 'NewspaperTools'),
106
+ 'get_trending_topics': ('.newspaper_tools', 'NewspaperTools'),
107
+ 'newspaper_tools': ('.newspaper_tools', 'NewspaperTools'),
108
+
109
+ # arXiv Tools
110
+ 'search_arxiv': ('.arxiv_tools', 'ArxivTools'),
111
+ 'get_arxiv_paper': ('.arxiv_tools', 'ArxivTools'),
112
+ 'get_papers_by_author': ('.arxiv_tools', 'ArxivTools'),
113
+ 'get_papers_by_category': ('.arxiv_tools', 'ArxivTools'),
114
+ 'arxiv_tools': ('.arxiv_tools', 'ArxivTools'),
115
+
116
+ # Spider Tools
117
+ 'scrape_page': ('.spider_tools', 'SpiderTools'),
118
+ 'extract_links': ('.spider_tools', 'SpiderTools'),
119
+ 'crawl': ('.spider_tools', 'SpiderTools'),
120
+ 'extract_text': ('.spider_tools', 'SpiderTools'),
121
+ 'spider_tools': ('.spider_tools', 'SpiderTools'),
122
+
123
+ # DuckDB Tools
124
+ 'query': ('.duckdb_tools', 'DuckDBTools'),
125
+ 'create_table': ('.duckdb_tools', 'DuckDBTools'),
126
+ 'load_data': ('.duckdb_tools', 'DuckDBTools'),
127
+ 'export_data': ('.duckdb_tools', 'DuckDBTools'),
128
+ 'get_table_info': ('.duckdb_tools', 'DuckDBTools'),
129
+ 'analyze_data': ('.duckdb_tools', 'DuckDBTools'),
130
+ 'duckdb_tools': ('.duckdb_tools', 'DuckDBTools'),
131
+
132
+ # Shell Tools
133
+ 'execute_command': ('.shell_tools', 'ShellTools'),
134
+ 'list_processes': ('.shell_tools', 'ShellTools'),
135
+ 'kill_process': ('.shell_tools', 'ShellTools'),
136
+ 'get_system_info': ('.shell_tools', 'ShellTools'),
137
+ 'shell_tools': ('.shell_tools', 'ShellTools'),
138
+ }
139
+
140
+ _instances = {} # Cache for class instances
141
+
142
+ def __getattr__(name: str) -> Any:
143
+ """Smart lazy loading of tools with class method support."""
144
+ if name not in TOOL_MAPPINGS:
145
+ raise AttributeError(f"module '{__package__}' has no attribute '{name}'")
146
+
147
+ module_path, class_name = TOOL_MAPPINGS[name]
148
+
149
+ if class_name is None:
150
+ # Direct function import
151
+ module = import_module(module_path, __package__)
152
+ if name in ['duckduckgo', 'file_tools', 'pandas_tools', 'wikipedia_tools',
153
+ 'newspaper_tools', 'arxiv_tools', 'spider_tools', 'duckdb_tools', 'csv_tools', 'json_tools', 'excel_tools', 'xml_tools', 'yaml_tools', 'calculator_tools', 'python_tools', 'shell_tools']:
154
+ return module # Returns the callable module
155
+ return getattr(module, name)
156
+ else:
157
+ # Class method import
158
+ if class_name not in _instances:
159
+ module = import_module(module_path, __package__)
160
+ class_ = getattr(module, class_name)
161
+ _instances[class_name] = class_()
162
+
163
+ # Get the method and bind it to the instance
164
+ method = getattr(_instances[class_name], name)
165
+ return method
166
+
167
+ __all__ = list(TOOL_MAPPINGS.keys())
@@ -0,0 +1,292 @@
1
+ """Tools for searching and retrieving papers from arXiv.
2
+
3
+ Usage:
4
+ from praisonaiagents.tools import arxiv_tools
5
+ papers = arxiv_tools.search("quantum computing")
6
+ paper = arxiv_tools.get_paper("2401.00123")
7
+
8
+ or
9
+ from praisonaiagents.tools import search_arxiv, get_arxiv_paper
10
+ papers = search_arxiv("quantum computing")
11
+ """
12
+
13
+ import logging
14
+ from typing import List, Dict, Union, Optional, Any
15
+ from importlib import util
16
+ import json
17
+
18
+ # Map sort criteria to arxiv.SortCriterion
19
+ SORT_CRITERIA = {
20
+ "relevance": "Relevance",
21
+ "lastupdateddate": "LastUpdatedDate",
22
+ "submitteddate": "SubmittedDate"
23
+ }
24
+
25
+ # Map sort order to arxiv.SortOrder
26
+ SORT_ORDER = {
27
+ "ascending": "Ascending",
28
+ "descending": "Descending"
29
+ }
30
+
31
+ class ArxivTools:
32
+ """Tools for searching and retrieving papers from arXiv."""
33
+
34
+ def __init__(self):
35
+ """Initialize ArxivTools and check for arxiv package."""
36
+ self._check_arxiv()
37
+
38
+ def _check_arxiv(self):
39
+ """Check if arxiv package is installed."""
40
+ if util.find_spec("arxiv") is None:
41
+ raise ImportError("arxiv package is not available. Please install it using: pip install arxiv")
42
+ global arxiv
43
+ import arxiv
44
+
45
+ def search(
46
+ self,
47
+ query: str,
48
+ max_results: int = 10,
49
+ sort_by: str = "relevance",
50
+ sort_order: str = "descending",
51
+ include_fields: Optional[List[str]] = None
52
+ ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
53
+ """
54
+ Search arXiv for papers matching the query.
55
+
56
+ Args:
57
+ query: Search query (e.g., "quantum computing", "author:Einstein")
58
+ max_results: Maximum number of results to return
59
+ sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
60
+ sort_order: Sort order ("ascending" or "descending")
61
+ include_fields: List of fields to include in results. If None, includes all.
62
+ Available fields: ["title", "authors", "summary", "comment", "journal_ref",
63
+ "doi", "primary_category", "categories", "links"]
64
+
65
+ Returns:
66
+ List[Dict] or Dict: List of papers or error dict
67
+ """
68
+ try:
69
+ import arxiv
70
+
71
+ # Configure search client
72
+ client = arxiv.Client()
73
+
74
+ # Map sort criteria
75
+ sort_by_enum = getattr(arxiv.SortCriterion, SORT_CRITERIA[sort_by.lower()])
76
+ sort_order_enum = getattr(arxiv.SortOrder, SORT_ORDER[sort_order.lower()])
77
+
78
+ # Build search query
79
+ search = arxiv.Search(
80
+ query=query,
81
+ max_results=max_results,
82
+ sort_by=sort_by_enum,
83
+ sort_order=sort_order_enum
84
+ )
85
+
86
+ # Execute search
87
+ results = []
88
+ for result in client.results(search):
89
+ # Convert to dict with selected fields
90
+ paper = self._result_to_dict(result, include_fields)
91
+ results.append(paper)
92
+
93
+ return results
94
+ except Exception as e:
95
+ error_msg = f"Error searching arXiv: {str(e)}"
96
+ logging.error(error_msg)
97
+ return {"error": error_msg}
98
+
99
+ def get_paper(
100
+ self,
101
+ paper_id: str,
102
+ include_fields: Optional[List[str]] = None
103
+ ) -> Union[Dict[str, Any], Dict[str, str]]:
104
+ """
105
+ Get details of a specific paper by its arXiv ID.
106
+
107
+ Args:
108
+ paper_id: arXiv paper ID (e.g., "2401.00123")
109
+ include_fields: List of fields to include in results. If None, includes all.
110
+ Available fields: ["title", "authors", "summary", "comment", "journal_ref",
111
+ "doi", "primary_category", "categories", "links"]
112
+
113
+ Returns:
114
+ Dict: Paper details or error dict
115
+ """
116
+ try:
117
+ import arxiv
118
+
119
+ # Configure client
120
+ client = arxiv.Client()
121
+
122
+ # Get paper by ID
123
+ search = arxiv.Search(id_list=[paper_id])
124
+ results = list(client.results(search))
125
+
126
+ if not results:
127
+ return {"error": f"Paper with ID {paper_id} not found"}
128
+
129
+ # Convert to dict with selected fields
130
+ paper = self._result_to_dict(results[0], include_fields)
131
+ return paper
132
+ except Exception as e:
133
+ error_msg = f"Error getting paper {paper_id}: {str(e)}"
134
+ logging.error(error_msg)
135
+ return {"error": error_msg}
136
+
137
+ def get_papers_by_author(
138
+ self,
139
+ author: str,
140
+ max_results: int = 10,
141
+ sort_by: str = "submittedDate",
142
+ sort_order: str = "descending",
143
+ include_fields: Optional[List[str]] = None
144
+ ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
145
+ """
146
+ Get papers by a specific author.
147
+
148
+ Args:
149
+ author: Author name (e.g., "Einstein")
150
+ max_results: Maximum number of results to return
151
+ sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
152
+ sort_order: Sort order ("ascending" or "descending")
153
+ include_fields: List of fields to include in results
154
+
155
+ Returns:
156
+ List[Dict] or Dict: List of papers or error dict
157
+ """
158
+ query = f'au:"{author}"'
159
+ return self.search(query, max_results, sort_by, sort_order, include_fields)
160
+
161
+ def get_papers_by_category(
162
+ self,
163
+ category: str,
164
+ max_results: int = 10,
165
+ sort_by: str = "submittedDate",
166
+ sort_order: str = "descending",
167
+ include_fields: Optional[List[str]] = None
168
+ ) -> Union[List[Dict[str, Any]], Dict[str, str]]:
169
+ """
170
+ Get papers from a specific category.
171
+
172
+ Args:
173
+ category: arXiv category (e.g., "cs.AI", "physics.gen-ph")
174
+ max_results: Maximum number of results to return
175
+ sort_by: Sort results by ("relevance", "lastUpdatedDate", "submittedDate")
176
+ sort_order: Sort order ("ascending" or "descending")
177
+ include_fields: List of fields to include in results
178
+
179
+ Returns:
180
+ List[Dict] or Dict: List of papers or error dict
181
+ """
182
+ query = f'cat:{category}'
183
+ return self.search(query, max_results, sort_by, sort_order, include_fields)
184
+
185
+ def _result_to_dict(
186
+ self,
187
+ result: Any,
188
+ include_fields: Optional[List[str]] = None
189
+ ) -> Dict[str, Any]:
190
+ """Convert arxiv.Result to dictionary with selected fields."""
191
+ # Default fields to include
192
+ if include_fields is None:
193
+ include_fields = [
194
+ "title", "authors", "summary", "comment", "journal_ref",
195
+ "doi", "primary_category", "categories", "links"
196
+ ]
197
+
198
+ # Build paper dict with selected fields
199
+ paper = {}
200
+
201
+ # Always include these basic fields
202
+ paper["arxiv_id"] = result.entry_id.split("/")[-1]
203
+ paper["updated"] = result.updated.isoformat() if result.updated else None
204
+ paper["published"] = result.published.isoformat() if result.published else None
205
+
206
+ # Add requested fields
207
+ if "title" in include_fields:
208
+ paper["title"] = result.title
209
+ if "authors" in include_fields:
210
+ paper["authors"] = [str(author) for author in result.authors]
211
+ if "summary" in include_fields:
212
+ paper["summary"] = result.summary
213
+ if "comment" in include_fields:
214
+ paper["comment"] = result.comment
215
+ if "journal_ref" in include_fields:
216
+ paper["journal_ref"] = result.journal_ref
217
+ if "doi" in include_fields:
218
+ paper["doi"] = result.doi
219
+ if "primary_category" in include_fields:
220
+ paper["primary_category"] = result.primary_category
221
+ if "categories" in include_fields:
222
+ paper["categories"] = result.categories
223
+ if "links" in include_fields:
224
+ paper["pdf_url"] = result.pdf_url
225
+ paper["abstract_url"] = f"https://arxiv.org/abs/{paper['arxiv_id']}"
226
+
227
+ return paper
228
+
229
+ # Create instance for direct function access
230
+ _arxiv_tools = ArxivTools()
231
+ search_arxiv = _arxiv_tools.search
232
+ get_arxiv_paper = _arxiv_tools.get_paper
233
+ get_papers_by_author = _arxiv_tools.get_papers_by_author
234
+ get_papers_by_category = _arxiv_tools.get_papers_by_category
235
+
236
+ if __name__ == "__main__":
237
+ # Example usage
238
+ print("\n==================================================")
239
+ print("ArxivTools Demonstration")
240
+ print("==================================================\n")
241
+
242
+ # 1. Search for papers
243
+ print("1. Searching for Papers")
244
+ print("------------------------------")
245
+ query = "quantum computing"
246
+ papers = search_arxiv(query, max_results=3)
247
+ print(f"Papers about {query}:")
248
+ if isinstance(papers, list):
249
+ print(json.dumps(papers, indent=2))
250
+ else:
251
+ print(papers) # Show error
252
+ print()
253
+
254
+ # 2. Get specific paper
255
+ print("2. Getting Specific Paper")
256
+ print("------------------------------")
257
+ if isinstance(papers, list) and papers:
258
+ paper_id = papers[0]["arxiv_id"]
259
+ paper = get_arxiv_paper(paper_id)
260
+ print(f"Paper {paper_id}:")
261
+ if "error" not in paper:
262
+ print(json.dumps(paper, indent=2))
263
+ else:
264
+ print(paper) # Show error
265
+ print()
266
+
267
+ # 3. Get papers by author
268
+ print("3. Getting Papers by Author")
269
+ print("------------------------------")
270
+ author = "Yoshua Bengio"
271
+ author_papers = get_papers_by_author(author, max_results=3)
272
+ print(f"Papers by {author}:")
273
+ if isinstance(author_papers, list):
274
+ print(json.dumps(author_papers, indent=2))
275
+ else:
276
+ print(author_papers) # Show error
277
+ print()
278
+
279
+ # 4. Get papers by category
280
+ print("4. Getting Papers by Category")
281
+ print("------------------------------")
282
+ category = "cs.AI"
283
+ category_papers = get_papers_by_category(category, max_results=3)
284
+ print(f"Papers in category {category}:")
285
+ if isinstance(category_papers, list):
286
+ print(json.dumps(category_papers, indent=2))
287
+ else:
288
+ print(category_papers) # Show error
289
+
290
+ print("\n==================================================")
291
+ print("Demonstration Complete")
292
+ print("==================================================")