airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,289 @@
|
|
1
|
+
"""
|
2
|
+
Groq powered agent with persistent memory.
|
3
|
+
|
4
|
+
This module provides a Groq-powered agent with persistent memory
|
5
|
+
that saves to the standard ~/.trmx path.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import List, Dict, Any, Optional
|
9
|
+
import os
|
10
|
+
from pathlib import Path
|
11
|
+
|
12
|
+
from airtrain.agents.registry import BaseAgent, register_agent
|
13
|
+
from airtrain.agents.memory import SharedMemory
|
14
|
+
from airtrain.tools import ToolFactory, execute_tool_call, BaseTool
|
15
|
+
|
16
|
+
# Groq integration
|
17
|
+
from airtrain.integrations.groq.skills import GroqChatSkill, GroqInput
|
18
|
+
from airtrain.integrations.groq.credentials import GroqCredentials
|
19
|
+
|
20
|
+
|
21
|
+
@register_agent("groq_agent")
class GroqAgent(BaseAgent):
    """Agent powered by a Groq chat model with persistent memory.

    Two named memories are created: ``"dialog"`` (user/assistant turns) and
    ``"reasoning"`` (tool invocations and their results). Memory is loaded
    on construction and persisted after every :meth:`process` call.
    """

    # Message roles forwarded from memory to the chat API; anything else
    # (e.g. the "function" entries kept in reasoning memory) is dropped.
    _CONVERSATION_ROLES = ("user", "assistant", "system", "tool")

    def __init__(
        self,
        name: str,
        models: Optional[List[str]] = None,
        tools: Optional[List[BaseTool]] = None,
        memory_size: int = 10,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        persist_path: Optional[str] = None,
        system_prompt: Optional[str] = None,
    ):
        """
        Initialize the Groq agent.

        Args:
            name: Name of the agent; used in the default system prompt and as
                the per-agent key when no ``persist_path`` is given.
            models: Groq model names; only the first entry is used for
                generation. Defaults to ``["llama-3.1-8b-instant"]``.
            tools: List of tools the model may call.
            memory_size: Size of the ``"dialog"`` short-term memory.
            temperature: Sampling temperature for generation.
            max_tokens: Maximum tokens in each response.
            persist_path: Path to persist memory (if None, the memory layer's
                standard per-agent location is used).
            system_prompt: Custom system prompt; a default built from
                ``name`` is used when omitted.
        """
        # Default to a fast Groq model if none provided
        if not models:
            models = ["llama-3.1-8b-instant"]

        super().__init__(name, models, tools)

        # Create specialized memories
        self.create_memory("dialog", memory_size)
        self.create_memory("reasoning", 5)

        # Configure generation parameters
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.persist_path = persist_path

        # Initialize Groq backend
        self.credentials = GroqCredentials.from_env()
        self.groq_skill = GroqChatSkill(self.credentials)

        # Set system prompt
        self.system_prompt = system_prompt or (
            f"You are {self.name}, a helpful AI assistant powered by Groq. "
            "Your responses are accurate, helpful, and concise. You have access "
            "to tools that help you accomplish tasks. When reasoning through a problem, "
            "take a step-by-step approach and use tools when appropriate."
        )

        # Load memory from persistent storage if available
        self._load_memory()

    def _get_tool_definitions(self) -> Optional[List[Dict[str, Any]]]:
        """Return tool definitions for the LLM, or ``None`` when there are no tools."""
        return [tool.to_dict() for tool in self.tools] if self.tools else None

    def process(self, user_input: str, memory_name: str = "dialog") -> str:
        """
        Process user input and generate a response.

        The user message is recorded in every memory. The selected memory is
        then replayed to the model as conversation history. If the model
        requests tool calls, they are executed and a second completion (sent
        without tool definitions) produces the final answer.

        Args:
            user_input: User input to process.
            memory_name: Name of the memory whose context is sent to the model.

        Returns:
            Agent's response text (``""`` if the model returned none).
        """
        self.memory.add_to_all({"role": "user", "content": user_input})

        conversation_history = self._build_conversation_history(memory_name)

        input_data = GroqInput(
            model=self.models[0],
            user_input="",  # Empty because we use conversation_history
            conversation_history=conversation_history,
            tools=self._get_tool_definitions(),
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        result = self.groq_skill.process(input_data)

        tool_calls = getattr(result, "tool_calls", None)
        if tool_calls:
            tool_results = self._execute_tool_calls(tool_calls)
            response = self._complete_with_tool_results(
                conversation_history, result, tool_results
            )
        else:
            # No tool calls, use direct response
            response = result.response

        # Honour the ``-> str`` contract; never store or return None.
        response = response or ""

        self.memory.add_to_all({"role": "assistant", "content": response})

        # Persist memory after processing
        self._persist_memory()

        return response

    def _build_conversation_history(self, memory_name: str) -> List[Dict[str, Any]]:
        """Return the system prompt followed by the chat messages held in *memory_name*."""
        history: List[Dict[str, Any]] = [
            {"role": "system", "content": self.system_prompt}
        ]
        for message in self.memory.get_context(memory_name):
            role = message.get("role")
            # Entries without a role, or with a role the LLM does not
            # understand, are skipped; only role/content are forwarded.
            if role in self._CONVERSATION_ROLES:
                history.append({"role": role, "content": message.get("content", "")})
        return history

    def _execute_tool_calls(self, tool_calls: List[Any]) -> List[tuple]:
        """Run each tool call and return ``(tool_call, result_or_error_dict)`` pairs.

        Every outcome, success or failure, is also appended to the
        ``"reasoning"`` memory.
        """
        results: List[tuple] = []
        for tool_call in tool_calls:
            # Resolved before the try so the error handlers always report the
            # name of *this* call and can never hit an unbound variable.
            func_name = self._tool_call_name(tool_call)
            try:
                tool_result = execute_tool_call(tool_call)
            except ValueError as e:
                # Handle case where tool doesn't exist
                tool_result = {
                    "error": str(e),
                    "status": "error",
                    "message": f"Tool '{func_name}' not found or not available.",
                }
            except Exception as e:
                # Handle other execution errors
                tool_result = {
                    "error": str(e),
                    "status": "error",
                    "message": f"Error executing tool '{func_name}': {str(e)}",
                }
            results.append((tool_call, tool_result))
            self._record_reasoning(func_name, tool_result)
        return results

    def _complete_with_tool_results(
        self,
        conversation_history: List[Dict[str, Any]],
        result: Any,
        tool_results: List[tuple],
    ) -> Optional[str]:
        """Send the tool outputs back to the model and return its final answer."""
        followup_history = list(conversation_history)

        # OpenAI-compatible chat APIs require each "tool" message to answer a
        # preceding assistant message that carries the matching ``tool_calls``,
        # even when that assistant turn has no text.
        # NOTE(review): assumes GroqChatSkill forwards conversation_history
        # entries (including the ``tool_calls`` key) unchanged — confirm.
        followup_history.append(
            {
                "role": "assistant",
                "content": result.response or "",
                "tool_calls": result.tool_calls,
            }
        )

        for tool_call, tool_result in tool_results:
            followup_history.append(
                {
                    "role": "tool",
                    "tool_call_id": self._tool_call_id(tool_call),
                    "content": str(tool_result),
                }
            )

        followup_input = GroqInput(
            model=self.models[0],
            user_input="",
            conversation_history=followup_history,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        return self.groq_skill.process(followup_input).response

    def _record_reasoning(self, func_name: str, tool_result: Any) -> None:
        """Append one tool outcome to the ``"reasoning"`` memory."""
        self.memory.add_to_memory(
            "reasoning",
            {"role": "function", "name": func_name, "content": str(tool_result)},
        )

    @staticmethod
    def _tool_call_name(tool_call: Any) -> str:
        """Best-effort function name of a tool call (dict or object), else ``"unknown_tool"``."""
        if isinstance(tool_call, dict):
            function = tool_call.get("function")
        else:
            function = getattr(tool_call, "function", None)
        if isinstance(function, dict):
            name = function.get("name")
        else:
            name = getattr(function, "name", None)
        return name or "unknown_tool"

    @staticmethod
    def _tool_call_id(tool_call: Any) -> str:
        """Best-effort id of a tool call (dict or object), else ``"unknown"``."""
        if isinstance(tool_call, dict):
            call_id = tool_call.get("id")
        else:
            call_id = getattr(tool_call, "id", None)
        return call_id or "unknown"

    def _persist_memory(self):
        """Persist memory to disk (best-effort).

        Uses ``persist_path`` when set, otherwise persists by agent name.
        Failures only print a warning, so an already generated reply is not
        lost because the disk write failed.
        """
        try:
            if self.persist_path:
                self.memory.persist(self.persist_path)
            else:
                self.memory.persist(agent_name=self.name)
        except Exception as e:
            print(f"Warning: Failed to persist memory for {self.name}: {str(e)}")

    def _load_memory(self):
        """Load previously persisted memory if available (best-effort).

        With ``persist_path`` set, that path is loaded if it exists.
        Otherwise memory is loaded by agent name when
        ``~/.trmx/agents/<name>`` exists. Any failure only prints a warning.
        """
        try:
            if self.persist_path:
                if os.path.exists(self.persist_path):
                    self.memory.load(self.persist_path)
            elif (Path.home() / ".trmx" / "agents" / self.name).exists():
                self.memory.load(agent_name=self.name)
        except Exception as e:
            print(f"Warning: Failed to load memory for {self.name}: {str(e)}")
|
247
|
+
|
248
|
+
|
249
|
+
# Example usage
if __name__ == "__main__":
    import sys

    import dotenv

    # Load environment variables (e.g. GROQ_API_KEY) from a local .env file
    dotenv.load_dotenv()

    # Fail fast with a clear message; GroqCredentials.from_env() presumably
    # needs this key when the agent is constructed below.
    if not os.getenv("GROQ_API_KEY"):
        print("Error: GROQ_API_KEY environment variable not set", file=sys.stderr)
        # sys.exit is always available; the bare ``exit`` builtin is injected
        # by the ``site`` module and is intended for the interactive prompt.
        sys.exit(1)

    # Create an agent
    agent = GroqAgent(name="GroqAssistant", memory_size=5)

    # Add calculator tool if available
    try:
        calculator = ToolFactory.get_tool("calculator")
    except ValueError:
        print("Calculator tool not available")
    else:
        agent.add_tool(calculator)
        print(f"Added calculator tool to {agent.name}")

    # Test the agent
    print(f"\n=== Testing {agent.name} ===")

    sample_inputs = [
        "Hello, what can you do?",
        "Can you help me calculate 23.5 * 17?",
        "Thank you! Can you remember that result for me?",
        "What was the calculation result we discussed earlier?",
    ]

    for user_input in sample_inputs:
        print(f"\nUser: {user_input}")
        response = agent.process(user_input)
        print(f"{agent.name}: {response}")
|