ai-parrot 0.8.3__cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-parrot might be problematic. Click here for more details.
- ai_parrot-0.8.3.dist-info/LICENSE +21 -0
- ai_parrot-0.8.3.dist-info/METADATA +306 -0
- ai_parrot-0.8.3.dist-info/RECORD +128 -0
- ai_parrot-0.8.3.dist-info/WHEEL +6 -0
- ai_parrot-0.8.3.dist-info/top_level.txt +2 -0
- parrot/__init__.py +30 -0
- parrot/bots/__init__.py +5 -0
- parrot/bots/abstract.py +1115 -0
- parrot/bots/agent.py +492 -0
- parrot/bots/basic.py +9 -0
- parrot/bots/bose.py +17 -0
- parrot/bots/chatbot.py +271 -0
- parrot/bots/cody.py +17 -0
- parrot/bots/copilot.py +117 -0
- parrot/bots/data.py +730 -0
- parrot/bots/dataframe.py +103 -0
- parrot/bots/hrbot.py +15 -0
- parrot/bots/interfaces/__init__.py +1 -0
- parrot/bots/interfaces/retrievers.py +12 -0
- parrot/bots/notebook.py +619 -0
- parrot/bots/odoo.py +17 -0
- parrot/bots/prompts/__init__.py +41 -0
- parrot/bots/prompts/agents.py +91 -0
- parrot/bots/prompts/data.py +214 -0
- parrot/bots/retrievals/__init__.py +1 -0
- parrot/bots/retrievals/constitutional.py +19 -0
- parrot/bots/retrievals/multi.py +122 -0
- parrot/bots/retrievals/retrieval.py +610 -0
- parrot/bots/tools/__init__.py +7 -0
- parrot/bots/tools/eda.py +325 -0
- parrot/bots/tools/pdf.py +50 -0
- parrot/bots/tools/plot.py +48 -0
- parrot/bots/troc.py +16 -0
- parrot/conf.py +170 -0
- parrot/crew/__init__.py +3 -0
- parrot/crew/tools/__init__.py +22 -0
- parrot/crew/tools/bing.py +13 -0
- parrot/crew/tools/config.py +43 -0
- parrot/crew/tools/duckgo.py +62 -0
- parrot/crew/tools/file.py +24 -0
- parrot/crew/tools/google.py +168 -0
- parrot/crew/tools/gtrends.py +16 -0
- parrot/crew/tools/md2pdf.py +25 -0
- parrot/crew/tools/rag.py +42 -0
- parrot/crew/tools/search.py +32 -0
- parrot/crew/tools/url.py +21 -0
- parrot/exceptions.cpython-39-x86_64-linux-gnu.so +0 -0
- parrot/handlers/__init__.py +4 -0
- parrot/handlers/agents.py +292 -0
- parrot/handlers/bots.py +196 -0
- parrot/handlers/chat.py +192 -0
- parrot/interfaces/__init__.py +6 -0
- parrot/interfaces/database.py +27 -0
- parrot/interfaces/http.py +805 -0
- parrot/interfaces/images/__init__.py +0 -0
- parrot/interfaces/images/plugins/__init__.py +18 -0
- parrot/interfaces/images/plugins/abstract.py +58 -0
- parrot/interfaces/images/plugins/exif.py +709 -0
- parrot/interfaces/images/plugins/hash.py +52 -0
- parrot/interfaces/images/plugins/vision.py +104 -0
- parrot/interfaces/images/plugins/yolo.py +66 -0
- parrot/interfaces/images/plugins/zerodetect.py +197 -0
- parrot/llms/__init__.py +1 -0
- parrot/llms/abstract.py +69 -0
- parrot/llms/anthropic.py +58 -0
- parrot/llms/gemma.py +15 -0
- parrot/llms/google.py +44 -0
- parrot/llms/groq.py +67 -0
- parrot/llms/hf.py +45 -0
- parrot/llms/openai.py +61 -0
- parrot/llms/pipes.py +114 -0
- parrot/llms/vertex.py +89 -0
- parrot/loaders/__init__.py +9 -0
- parrot/loaders/abstract.py +628 -0
- parrot/loaders/files/__init__.py +0 -0
- parrot/loaders/files/abstract.py +39 -0
- parrot/loaders/files/text.py +63 -0
- parrot/loaders/txt.py +26 -0
- parrot/manager.py +333 -0
- parrot/models.py +504 -0
- parrot/py.typed +0 -0
- parrot/stores/__init__.py +11 -0
- parrot/stores/abstract.py +248 -0
- parrot/stores/chroma.py +188 -0
- parrot/stores/duck.py +162 -0
- parrot/stores/embeddings/__init__.py +10 -0
- parrot/stores/embeddings/abstract.py +46 -0
- parrot/stores/embeddings/base.py +52 -0
- parrot/stores/embeddings/bge.py +20 -0
- parrot/stores/embeddings/fastembed.py +17 -0
- parrot/stores/embeddings/google.py +18 -0
- parrot/stores/embeddings/huggingface.py +20 -0
- parrot/stores/embeddings/ollama.py +14 -0
- parrot/stores/embeddings/openai.py +26 -0
- parrot/stores/embeddings/transformers.py +21 -0
- parrot/stores/embeddings/vertexai.py +17 -0
- parrot/stores/empty.py +10 -0
- parrot/stores/faiss.py +160 -0
- parrot/stores/milvus.py +397 -0
- parrot/stores/postgres.py +653 -0
- parrot/stores/qdrant.py +170 -0
- parrot/tools/__init__.py +23 -0
- parrot/tools/abstract.py +68 -0
- parrot/tools/asknews.py +33 -0
- parrot/tools/basic.py +51 -0
- parrot/tools/bby.py +359 -0
- parrot/tools/bing.py +13 -0
- parrot/tools/docx.py +343 -0
- parrot/tools/duck.py +62 -0
- parrot/tools/execute.py +56 -0
- parrot/tools/gamma.py +28 -0
- parrot/tools/google.py +170 -0
- parrot/tools/gvoice.py +301 -0
- parrot/tools/results.py +278 -0
- parrot/tools/stack.py +27 -0
- parrot/tools/weather.py +70 -0
- parrot/tools/wikipedia.py +58 -0
- parrot/tools/zipcode.py +198 -0
- parrot/utils/__init__.py +2 -0
- parrot/utils/parsers/__init__.py +5 -0
- parrot/utils/parsers/toml.cpython-39-x86_64-linux-gnu.so +0 -0
- parrot/utils/toml.py +11 -0
- parrot/utils/types.cpython-39-x86_64-linux-gnu.so +0 -0
- parrot/utils/uv.py +11 -0
- parrot/version.py +10 -0
- resources/users/__init__.py +5 -0
- resources/users/handlers.py +13 -0
- resources/users/models.py +205 -0
parrot/tools/gvoice.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Optional
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
from xml.sax.saxutils import escape
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
import aiofiles
|
|
9
|
+
# Use v1 for wider feature set including SSML
|
|
10
|
+
from google.cloud import texttospeech_v1 as texttospeech
|
|
11
|
+
from google.oauth2 import service_account
|
|
12
|
+
from pydantic import BaseModel, Field
|
|
13
|
+
from langchain.tools import BaseTool
|
|
14
|
+
from navconfig import BASE_DIR
|
|
15
|
+
from parrot.conf import GOOGLE_TTS_SERVICE
|
|
16
|
+
|
|
17
|
+
class PodcastInput(BaseModel):
    """Input for podcast generator tool."""
    # Single required field: the script that will be narrated.
    text: str = Field(
        ...,
        description="The text content to convert to speech",
    )
|
|
20
|
+
|
|
21
|
+
class GoogleVoiceTool(BaseTool):
    """Generate a podcast-style audio file from Text using Google Cloud Text-to-Speech.

    The input text (plain or Markdown) is converted to SSML, synthesized with
    the configured voice and written below the output directory. Results and
    failures are both reported as dictionaries (see ``_generate_podcast``).
    """
    name: str = "generate_podcast_style_audio_file"
    description: str = (
        "Generates a podcast-style audio file from a given text script using Google Cloud Text-to-Speech."
        " Use this tool if the user requests a podcast, an audio summary, or a narrative of the findings "
        " First, ensure you have a clear and concise text summary of the information to be narrated. You might need to generate this summary based on your analysis or previous steps."
        " Provide the text *as-is* without enclosing on backticks or backquotes."
    )
    voice_model: str = "en-US-Neural2-F"  # "en-US-Studio-O"
    voice_gender: str = "FEMALE"
    language_code: str = "en-US"
    output_format: str = "OGG_OPUS"  # OGG format is more podcast-friendly
    # Path to the service-account JSON file; resolved in __init__.
    _key_service: Optional[str]
    # Optional override of the directory where audio files are written.
    _output_dir: Optional[Path] = None

    # Add a proper args_schema for tool-calling compatibility
    args_schema: dict = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The text content to convert to speech"
            }
        },
        "required": ["query"]
    }

    def __init__(
        self,
        voice_model: str = "en-US-Neural2-F",
        output_format: str = "OGG_OPUS",
        output_dir: str = None,
        name: str = "podcast_generator_tool",
        **kwargs
    ):
        """Initialize the GoogleVoiceTool.

        Args:
            voice_model: Voice to use. A value different from the default
                overrides the voice derived from ``voice_gender``.
            output_format: Audio format name (e.g. "OGG_OPUS", "MP3").
            output_dir: Directory for generated files; when omitted the
                default ``BASE_DIR/static/audio/podcasts`` is used.
            name: Tool name. Only applied when it differs from this
                parameter's default, so existing callers keep the
                class-level tool name they have always received.
            **kwargs: Forwarded to ``BaseTool`` (e.g. ``voice_gender``).
        """
        super().__init__(**kwargs)

        # These arguments used to be accepted and silently discarded.
        if name != "podcast_generator_tool":
            self.name = name
        self.output_format = output_format
        self._output_dir = Path(output_dir) if output_dir else None

        # Using the config from conf.py, but with additional verification
        self._key_service = GOOGLE_TTS_SERVICE

        # If not found in the config, try a default location
        if self._key_service is None:
            default_path = BASE_DIR / "env" / "google" / "tts-service.json"
            if os.path.exists(default_path):
                self._key_service = str(default_path)
                print(f"Using default credentials path: {self._key_service}")
            else:
                print(f"Warning: No credentials found in config or at default path {default_path}")
        else:
            print(f"Using credentials from config: {self._key_service}")

        if self.voice_gender == 'FEMALE':
            self.voice_model = "en-US-Neural2-F"
        elif self.voice_gender == 'MALE':
            self.voice_model = "en-US-Neural2-M"
        else:
            self.voice_model = "en-US-Neural2-G"

        # An explicitly requested (non-default) voice wins over the
        # gender-derived one.
        if voice_model != "en-US-Neural2-F":
            self.voice_model = voice_model

    def is_markdown(self, text: str) -> bool:
        """Determine if the text appears to be Markdown formatted.

        Returns False for non-strings, empty strings and whitespace-only
        strings; otherwise True when the first visible character is a
        Markdown marker or any common Markdown pattern is found.
        """
        if not text or not isinstance(text, str):
            return False

        stripped = text.strip()
        if not stripped:
            # Whitespace-only input has no first character to inspect.
            return False

        # The problematic characters and the range are listed separately
        # inside the character class.
        if re.search(r"^[#*_>`\[\d-]", stripped[0]):  # Check if first char is a Markdown marker
            return True

        # Common Markdown patterns: (regex, flags)
        patterns = (
            (r"#{1,6}\s+", 0),                      # Headers
            (r"\*\*.*?\*\*", 0),                    # Bold
            (r"_.*?_", 0),                          # Italic
            (r"`.*?`", 0),                          # Code
            (r"\[.*?\]\(.*?\)", 0),                 # Links
            (r"^\s*[\*\-\+]\s+", re.MULTILINE),     # Unordered lists
            (r"^\s*\d+\.\s+", re.MULTILINE),        # Ordered lists
            (r"```.*?```", re.DOTALL),              # Code blocks
        )
        return any(re.search(pattern, text, flags) for pattern, flags in patterns)

    def text_to_ssml(self, text: str) -> str:
        """Wrap XML-escaped plain text in a single SSML paragraph."""
        return f"<speak><p>{escape(text)}</p></speak>"

    def markdown_to_ssml(self, markdown_text: str) -> str:
        """Converts Markdown text to SSML, handling code blocks and ellipses.

        A leading "```text" fence is treated as a wrapper around the whole
        script: it is removed together with its matching bare closing fence.
        Other fenced blocks are read slowly, headings are emphasized, a line
        consisting of "..." becomes a pause, and every other non-empty line
        becomes a paragraph. A code block left open at the end of the input
        is closed so the result stays well-formed.
        """
        text_wrapper_open = False
        if markdown_text.startswith("```text"):
            markdown_text = markdown_text[len("```text"):].strip()
            text_wrapper_open = True

        parts = ["<speak>"]
        in_code_block = False

        for raw_line in markdown_text.split('\n'):
            line = raw_line.strip()

            if line.startswith("```"):
                if text_wrapper_open and not in_code_block and line == "```":
                    # Closing fence of the stripped "```text" wrapper: it must
                    # not open a code block.
                    text_wrapper_open = False
                    continue
                in_code_block = not in_code_block
                # NOTE(review): <code> may not be an SSML element supported by
                # Google TTS -- confirm against the SSML reference.
                if in_code_block:
                    parts.append('<prosody rate="x-slow"><p><code>')
                else:
                    parts.append('</code></p></prosody>')
                continue

            if in_code_block:
                parts.append(escape(line) + '<break time="100ms"/>')  # Add slight pauses within code
                continue

            if line == "...":
                parts.append('<break time="500ms"/>')  # Keep the pause for ellipses
                continue

            # Handle Markdown headings
            heading_match = re.match(r"^(#+)\s+(.*)", line)
            if heading_match:
                heading_text = heading_match.group(2).strip()
                parts.append(f'<p><emphasis level="strong">{escape(heading_text)}</emphasis></p>')
                continue

            if line:
                parts.append(f'<p>{escape(line)}</p>')

        if in_code_block:
            # Unterminated fence: close the elements opened for it.
            parts.append('</code></p></prosody>')

        parts.append("</speak>")
        return "".join(parts)

    def _report_json_query(self, query: str) -> None:
        """Print diagnostics when the query happens to be a JSON document.

        Purely informational: the parsed content is not used for synthesis.
        """
        import json
        try:
            parsed = json.loads(query)
        except json.JSONDecodeError:
            print("Query is plain text, not JSON")
            return
        # Valid JSON may be a number, string, list or null; only objects can
        # carry an "output_file" entry.
        output_file = parsed.get("output_file") if isinstance(parsed, dict) else None
        if isinstance(output_file, str):
            print(f"Output file specified in query: {output_file}")
            print(f"Output directory exists: {os.path.isdir(os.path.dirname(output_file))}")

    def _select_encoding(self, timestamp: str):
        """Return ``(audio_encoding, output_filename)`` for ``self.output_format``.

        Unknown format names fall back to OGG_OPUS with an ``.ogg`` file.
        """
        extensions = {
            "OGG_OPUS": "ogg",
            "MP3": "mp3",
            "LINEAR16": "wav",
            "WEBM_OPUS": "webm",
            "FLAC": "flac",
            "OGG_VORBIS": "ogg",
        }
        extension = extensions.get(self.output_format)
        if extension is None:
            return texttospeech.AudioEncoding.OGG_OPUS, f"podcast_{timestamp}.ogg"
        # Looked up lazily by name so that only the requested member is
        # touched. NOTE(review): some of these names may not exist on
        # texttospeech_v1.AudioEncoding -- confirm against the API reference.
        encoding = getattr(texttospeech.AudioEncoding, self.output_format)
        return encoding, f"podcast_{timestamp}.{extension}"

    async def _generate_podcast(self, query: str) -> dict:
        """Main method to generate a podcast from query.

        Args:
            query: Plain or Markdown text to narrate.

        Returns:
            On success a dictionary with "file_path", "output_format",
            "language_code", "voice_model", "timestamp" and "filename";
            on any failure a dictionary with a single "error" message.
        """
        try:
            if isinstance(query, str):
                self._report_json_query(query)

            print("1. Converting Markdown to SSML...")
            if self.is_markdown(query):
                ssml_text = self.markdown_to_ssml(query)
            else:
                ssml_text = self.text_to_ssml(query)
            print(f"Generated SSML:\n{ssml_text}\n")  # Uncomment for debugging
            print(
                f"2. Initializing Text-to-Speech client (Voice: {self.voice_model})..."
            )
            # The key path may be None when no credentials were discovered.
            if not self._key_service or not os.path.exists(self._key_service):
                raise FileNotFoundError(
                    f"Service account file not found: {self._key_service}"
                )
            credentials = service_account.Credentials.from_service_account_file(
                self._key_service
            )
            # Initialize the Text-to-Speech client with the service account credentials
            # Use the v1 API for wider feature set including SSML
            client = texttospeech.TextToSpeechClient(credentials=credentials)
            synthesis_input = texttospeech.SynthesisInput(ssml=ssml_text)
            # Select the voice parameters
            voice = texttospeech.VoiceSelectionParams(
                language_code=self.language_code,
                name=self.voice_model
            )
            # Generate a filename based on the current timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            encoding, output_filename = self._select_encoding(timestamp)

            audio_config = texttospeech.AudioConfig(
                audio_encoding=encoding,
                speaking_rate=1.0,
                pitch=0.0
            )
            print("3. Synthesizing speech...")
            # synthesize_speech is a blocking call; run it on a worker thread
            # so the event loop stays responsive.
            response = await asyncio.get_running_loop().run_in_executor(
                None,
                lambda: client.synthesize_speech(
                    input=synthesis_input,
                    voice=voice,
                    audio_config=audio_config
                ),
            )
            print("4. Speech synthesized successfully.")
            # Get the absolute path for the output file
            if self._output_dir:
                output_dir = Path(self._output_dir)
            else:
                output_dir = BASE_DIR.joinpath('static', 'audio', 'podcasts')
            # Create the directory if it doesn't exist
            output_dir.mkdir(parents=True, exist_ok=True)
            output_filepath = output_dir.joinpath(output_filename)
            print(f"5. Saving audio content to: {output_filepath}")
            async with aiofiles.open(output_filepath, 'wb') as audio_file:
                await audio_file.write(response.audio_content)
            print("6. Audio content saved successfully.")
            return {
                "file_path": output_filepath,
                "output_format": self.output_format,
                "language_code": self.language_code,
                "voice_model": self.voice_model,
                "timestamp": timestamp,
                "filename": output_filename
            }
        except Exception as e:
            import traceback
            print(f"Error in _generate_podcast: {e}")
            print(traceback.format_exc())
            return {"error": str(e)}

    async def _arun(self, query: str) -> dict:
        """
        Generates a podcast-style audio file from text using
        Google Cloud Text-to-Speech.

        Args:
            query: The input text, plain or in Markdown format. Voice,
                language and output format come from the tool's attributes.

        Returns:
            A dictionary containing the path to the saved audio file
            under the key "file_path", or an error message under "error".
        """
        try:
            return await self._generate_podcast(query)
        except Exception as e:
            import traceback
            print(f"Error in GoogleVoiceTool._arun: {e}")
            print(traceback.format_exc())
            return {"error": str(e)}

    def _run(self, query: str) -> dict:
        """
        Synchronous method to generate a podcast-style audio file from Markdown text.
        This method is not recommended for production use due to blocking I/O.

        Without a running event loop in the current thread the coroutine is
        executed with ``asyncio.run``. When a loop is already running here,
        the coroutine is executed on a helper thread with its own loop,
        because a running loop cannot be re-entered.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self._generate_podcast(query))

        from concurrent.futures import ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, self._generate_podcast(query)).result()
|
parrot/tools/results.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
from typing import Any, ClassVar, Dict, Optional, Union
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
import pandas as pd
|
|
4
|
+
import json
|
|
5
|
+
from langchain_core.tools import BaseTool
|
|
6
|
+
from langchain_core.callbacks.manager import CallbackManagerForToolRun
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
# Helper function to import datetime safely (for use in the tool)
|
|
10
|
+
def import_time():
    """Return the ``datetime`` class imported at module level.

    The class itself is returned, not an instance or a timestamp.
    """
    clock_type = datetime
    return clock_type
|
|
12
|
+
|
|
13
|
+
class ResultStoreTool(BaseTool):
    """Tool for storing and retrieving intermediate results during agent execution.

    Results live in a class-level dictionary, so every instance (and the
    class-level accessors below) sees the same data for the lifetime of the
    process.
    """
    name: str = "store_result"
    description: str = """
    Store an intermediate result for later use. Use this to save important analysis outputs,
    DataFrame snippets, calculations, or any other values you want to refer to in later steps.

    Args:
        key (str): A unique identifier for the stored result
        value (Any): The value to store (can be a string, number, dict, list, or DataFrame info)
        description (str, optional): A brief description of what this value represents

    Returns:
        str: Confirmation message indicating the value was stored
    """

    # Storage for results, shared across all instances. Declared as ClassVar
    # so pydantic keeps it as a plain class attribute instead of turning the
    # underscore-prefixed name into a per-instance private attribute.
    _storage: ClassVar[Dict[str, Dict[str, Any]]] = {}

    def _run(
        self,
        key: str,
        value: Any,
        description: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Store a result with the given key.

        DataFrames are reduced to shape, columns and their first 10 rows;
        values that cannot be JSON-serialized are stored as their string
        representation. Returns a confirmation or an error message.
        """
        try:
            # Handle DataFrame serialization
            if isinstance(value, pd.DataFrame):
                # Store a serializable representation of the DataFrame
                stored_value = {
                    "type": "pandas_dataframe",
                    "shape": value.shape,
                    "columns": value.columns.tolist(),
                    "data": value.head(10).to_dict(orient="records")  # Store first 10 rows
                }
            else:
                # Try JSON serialization to check if value is serializable
                try:
                    json.dumps(value)
                    stored_value = value
                except (TypeError, ValueError, OverflowError):
                    # ValueError covers circular references.
                    # If not JSON serializable, convert to string representation
                    stored_value = {
                        "type": "non_serializable",
                        "string_repr": str(value),
                        "python_type": str(type(value))
                    }

            # Store the value with metadata. import_time() yields the datetime
            # class, so .now() is required to obtain an actual timestamp.
            self._storage[key] = {
                "value": stored_value,
                "description": description,
                "timestamp": import_time().now().strftime("%Y-%m-%d %H:%M:%S")
            }

            return f"Successfully stored result '{key}'"

        except Exception as e:
            return f"Error storing result: {str(e)}"

    @classmethod
    def get_result(cls, key: str) -> Union[Any, None]:
        """Retrieve a stored result, or None when the key is unknown."""
        if key in cls._storage:
            return cls._storage[key]["value"]
        return None

    @classmethod
    def list_results(cls) -> Dict[str, Dict[str, Any]]:
        """List all stored results with their metadata."""
        return {
            k: {
                "description": v.get("description", "No description provided"),
                "timestamp": v.get("timestamp", "Unknown"),
                "type": type(v["value"]).__name__
            }
            for k, v in cls._storage.items()
        }

    @classmethod
    def clear_results(cls) -> None:
        """Clear all stored results."""
        cls._storage.clear()

    @classmethod
    def delete_result(cls, key: str) -> bool:
        """Delete a specific stored result; return whether it existed."""
        if key in cls._storage:
            del cls._storage[key]
            return True
        return False
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class GetResultTool(BaseTool):
    """Tool that looks up a value previously saved through ResultStoreTool."""
    name: str = "get_result"
    description: str = """
    Retrieve a previously stored result by its key.

    Args:
        key (str): The unique identifier of the stored result

    Returns:
        Any: The stored value, or an error message if the key doesn't exist
    """

    def _run(
        self,
        key: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Any:
        """Return the value stored under *key*, or an error message listing the known keys."""
        stored = ResultStoreTool.get_result(key)
        if stored is not None:
            return stored

        # Nothing usable under that key: tell the caller what is available.
        known_keys = list(ResultStoreTool._storage.keys())
        return f"Error: No result found with key '{key}'. Available keys: {known_keys}"
|
|
133
|
+
|
|
134
|
+
class ListResultsTool(BaseTool):
    """Tool that reports the metadata of everything held by ResultStoreTool."""
    name: str = "list_results"
    description: str = """
    List all currently stored results with their metadata.

    Returns:
        Dict: A dictionary mapping result keys to their metadata
    """

    def _run(
        self,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Dict[str, Dict[str, Any]]:
        """Return the metadata mapping, or a plain message when nothing is stored."""
        # An empty mapping is falsy, so the message is used in that case.
        return ResultStoreTool.list_results() or "No results have been stored yet."
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class DataFrameStoreTool(BaseTool):
    """Tool specifically for storing and retrieving pandas DataFrames.

    DataFrames are looked up by variable name in the ``df_locals`` mapping
    supplied at construction and copied into a class-level store shared by
    all instances.
    """
    name: str = "store_dataframe"
    description: str = """
    Store a pandas DataFrame for later use or reference.

    Args:
        key (str): A unique identifier for the stored DataFrame
        df_variable (str): The variable name of the DataFrame to store
        description (str, optional): A brief description of what this DataFrame contains

    Returns:
        str: Confirmation message indicating the DataFrame was stored
    """

    # Storage for DataFrames, shared across all instances. ClassVar keeps it
    # a plain class attribute rather than a pydantic private attribute.
    _df_storage: ClassVar[Dict[str, Dict[str, Any]]] = {}

    # Declared as a field because pydantic models reject assignment to
    # attributes that are not declared.
    df_locals: Optional[Dict[str, Any]] = None

    def __init__(self, df_locals: Dict[str, Any], **kwargs):
        """Initialize with access to the locals dictionary.

        Args:
            df_locals: Mapping of variable names to objects; it is used by
                reference so stored DataFrames become visible to its owner.
            **kwargs: Forwarded to ``BaseTool``.
        """
        super().__init__(**kwargs)
        # Assigned after construction (not passed through validation) so the
        # caller's dictionary object itself is kept, not a copy.
        self.df_locals = df_locals

    def _run(
        self,
        key: str,
        df_variable: str,
        description: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Store the DataFrame named *df_variable* under *key*.

        Returns a confirmation message, or an error message when the
        variable is missing, is not a DataFrame, or storing fails.
        """
        try:
            # Check if the variable exists in df_locals
            if df_variable not in self.df_locals:
                return f"Error: DataFrame '{df_variable}' not found in available variables."

            df = self.df_locals[df_variable]

            # Verify it's actually a DataFrame
            if not isinstance(df, pd.DataFrame):
                return f"Error: '{df_variable}' is not a pandas DataFrame, it's a {type(df).__name__}."

            # Store the actual DataFrame (not just a representation).
            # import_time() yields the datetime class, hence the .now() call.
            self._df_storage[key] = {
                "dataframe": df.copy(),  # Store a copy to avoid mutation
                "description": description,
                "timestamp": import_time().now().strftime("%Y-%m-%d %H:%M:%S"),
                "shape": df.shape,
                "columns": df.columns.tolist()
            }

            # Add a reference to the df_locals so it can be accessed in Python code
            self.df_locals[f"stored_df_{key}"] = df.copy()

            return f"Successfully stored DataFrame '{key}' with shape {df.shape}"

        except Exception as e:
            return f"Error storing DataFrame: {str(e)}"

    @classmethod
    def get_dataframe(cls, key: str) -> Union[pd.DataFrame, None]:
        """Retrieve a stored DataFrame, or None when the key is unknown."""
        if key in cls._df_storage:
            return cls._df_storage[key]["dataframe"]
        return None

    @classmethod
    def list_dataframes(cls) -> Dict[str, Dict[str, Any]]:
        """List all stored DataFrames with their metadata."""
        return {
            k: {
                "description": v.get("description", "No description provided"),
                "timestamp": v.get("timestamp", "Unknown"),
                "shape": v.get("shape", "Unknown"),
                "columns": v.get("columns", [])
            }
            for k, v in cls._df_storage.items()
        }
|
|
235
|
+
|
|
236
|
+
class GetDataFrameTool(BaseTool):
    """Tool for retrieving stored DataFrames.

    The DataFrame is fetched from ``DataFrameStoreTool`` and bound to a
    variable name inside the ``df_locals`` mapping supplied at construction.
    """
    name: str = "get_dataframe"
    description: str = """
    Retrieve a previously stored DataFrame by its key.

    Args:
        key (str): The unique identifier of the stored DataFrame
        target_variable (str, optional): If provided, store the retrieved DataFrame in this variable

    Returns:
        str: Information about the retrieved DataFrame
    """

    # Declared as a field because pydantic models reject assignment to
    # attributes that are not declared.
    df_locals: Optional[Dict[str, Any]] = None

    def __init__(self, df_locals: Dict[str, Any], **kwargs):
        """Initialize with access to the locals dictionary.

        Args:
            df_locals: Mapping that receives the retrieved DataFrames; it is
                used by reference.
            **kwargs: Forwarded to ``BaseTool``.
        """
        super().__init__(**kwargs)
        # Assigned after construction so the caller's dictionary object
        # itself is kept, not a validated copy.
        self.df_locals = df_locals

    def _run(
        self,
        key: str,
        target_variable: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Bind the DataFrame stored under *key* to a variable in ``df_locals``.

        Uses *target_variable* when given, otherwise ``retrieved_df_<key>``.
        Returns a message naming the variable, or an error message listing
        the available keys when *key* is unknown.
        """
        df = DataFrameStoreTool.get_dataframe(key)

        if df is None:
            available_keys = list(DataFrameStoreTool._df_storage.keys())
            return f"Error: No DataFrame found with key '{key}'. Available DataFrame keys: {available_keys}"

        # If a target variable is specified, store the DataFrame there
        if target_variable:
            self.df_locals[target_variable] = df
            return f"Retrieved DataFrame '{key}' and stored in variable '{target_variable}' with shape {df.shape}"

        # Otherwise, use a default variable name and report it, since the
        # caller has no other way to learn where the DataFrame went.
        default_var = f"retrieved_df_{key}"
        self.df_locals[default_var] = df
        return (
            f"Retrieved DataFrame '{key}' with shape {df.shape}. "
            f"Access it using variable '{default_var}'."
        )
|
parrot/tools/stack.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from langchain.tools import Tool
|
|
2
|
+
from langchain.tools import BaseTool
|
|
3
|
+
from langchain_community.utilities import StackExchangeAPIWrapper
|
|
4
|
+
|
|
5
|
+
class StackExchangeTool(BaseTool):
    """Tool that searches the Stack Exchange API."""
    name: str = "StackExchangeSearch"
    description: str = (
        "Search for questions and answers on Stack Exchange. "
        "Stack Exchange is a network of question-and-answer (Q&A) websites on topics in diverse fields, each site covering a specific topic."
        "Useful for when you need to answer general questions about different topics when user requested."
    )
    # Holds the StackExchangeAPIWrapper created in __init__ (the declared
    # type is kept as in the original field definition).
    search: Tool = None

    def __init__(self, **kwargs):
        """Create the tool and its API wrapper (title search, 5 results max)."""
        super().__init__(**kwargs)
        self.search = StackExchangeAPIWrapper(
            query_type='title',
            max_results=5
        )

    def _run(
        self,
        query: str,
    ) -> str:
        """Use the StackExchangeSearch tool.

        Args:
            query: Search text forwarded to the wrapper's ``run`` method.

        Returns:
            Whatever the wrapper's ``run`` method returns for the query.
        """
        # The annotation is ``str`` because the wrapper is queried with text;
        # the tool's inferred argument schema is derived from this signature.
        return self.search.run(query)
|
parrot/tools/weather.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
from langchain.tools import BaseTool
|
|
3
|
+
from langchain.tools import Tool
|
|
4
|
+
from langchain_community.utilities import OpenWeatherMapAPIWrapper
|
|
5
|
+
from navconfig import config
|
|
6
|
+
import orjson
|
|
7
|
+
|
|
8
|
+
class OpenWeatherMapTool(BaseTool):
    """Tool that searches the OpenWeatherMap API."""
    name: str = "OpenWeatherMap"
    description: str = (
        "A wrapper around OpenWeatherMap. "
        "Useful for when you need to answer general questions about "
        "weather, temperature, humidity, wind speed, or other weather-related information. "
    )
    # Holds the OpenWeatherMapAPIWrapper created in __init__ (the declared
    # type is kept as in the original field definition).
    search: Tool = None

    def __init__(self, **kwargs):
        """Create the tool and its API wrapper using the OPENWEATHER_APPID setting."""
        super().__init__(**kwargs)
        self.search = OpenWeatherMapAPIWrapper(
            openweathermap_api_key=config.get('OPENWEATHER_APPID')
        )

    def _run(
        self,
        query: str,
    ) -> str:
        """Use the OpenWeatherMap tool.

        Args:
            query: Location text forwarded to the wrapper's ``run`` method.

        Returns:
            Whatever the wrapper's ``run`` method returns for the query.
        """
        # The annotation is ``str`` because the wrapper is queried with text;
        # the tool's inferred argument schema is derived from this signature.
        return self.search.run(query)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class OpenWeather(BaseTool):
    """
    Tool to get weather information about a location.

    Queries the OpenWeatherMap ``data/2.5`` endpoints for either the current
    weather or a forecast, depending on ``request``.
    """
    name: str = 'openweather_tool'
    description: str = (
        "Get weather information about a location, use this tool to answer questions about weather or weather forecast."
        " Input should be the latitude and longitude of the location you want weather information about."
    )
    # NOTE(review): plain HTTP sends the API key unencrypted; consider HTTPS.
    base_url: str = 'http://api.openweathermap.org/'
    units: str = 'metric'
    days: int = 3
    appid: str = None
    request: str = 'weather'
    country: str = 'us'

    def __init__(self, request: str = 'weather', country: str = 'us', **kwargs):
        """Create the tool.

        Args:
            request: 'weather' for current conditions or 'forecast'.
            country: Stored on the instance; not used by ``_run``.
            **kwargs: Forwarded to ``BaseTool``.
        """
        super().__init__(**kwargs)
        self.request = request
        self.country = country
        self.appid = config.get('OPENWEATHER_APPID')

    def _run(self, query: dict) -> dict:
        """Fetch weather data for the coordinates in *query*.

        Args:
            query: A mapping, or a JSON document encoding one, with
                'latitude' and 'longitude' entries.

        Returns:
            The decoded JSON response, or ``{'error': ...}`` when the
            coordinates are missing, the query cannot be parsed, or the
            configured request type is unsupported.
        """
        if isinstance(query, dict):
            q = query
        else:
            try:
                q = orjson.loads(query)  # pylint: disable=no-member
            except (TypeError, ValueError):
                return {'error': 'Query must be a JSON object with latitude and longitude'}

        if not isinstance(q, dict) or 'latitude' not in q or 'longitude' not in q:
            return {'error': 'Latitude and longitude are required'}

        # Parameter order mirrors the original hand-built query strings.
        params = {
            'lat': q['latitude'],
            'lon': q['longitude'],
            'units': self.units,
        }
        if self.request == 'weather':
            endpoint = 'data/2.5/weather'
        elif self.request == 'forecast':
            endpoint = 'data/2.5/forecast'
            params['cnt'] = self.days
        else:
            # Previously this fell through with no URL defined.
            return {'error': f"Unsupported request type: {self.request}"}
        params['appid'] = self.appid

        # A timeout prevents the call from hanging indefinitely.
        response = requests.get(f"{self.base_url}{endpoint}", params=params, timeout=30)
        return response.json()

    async def _arun(self, query: dict) -> dict:
        """Asynchronous execution is not supported; always raises NotImplementedError."""
        raise NotImplementedError("Async method not implemented yet")
|