airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,289 @@
1
+ """
2
+ Groq powered agent with persistent memory.
3
+
4
+ This module provides a Groq-powered agent with persistent memory
5
+ that saves to the standard ~/.trmx path.
6
+ """
7
+
8
+ from typing import List, Dict, Any, Optional
9
+ import os
10
+ from pathlib import Path
11
+
12
+ from airtrain.agents.registry import BaseAgent, register_agent
13
+ from airtrain.agents.memory import SharedMemory
14
+ from airtrain.tools import ToolFactory, execute_tool_call, BaseTool
15
+
16
+ # Groq integration
17
+ from airtrain.integrations.groq.skills import GroqChatSkill, GroqInput
18
+ from airtrain.integrations.groq.credentials import GroqCredentials
19
+
20
+
21
@register_agent("groq_agent")
class GroqAgent(BaseAgent):
    """Agent powered by a Groq-hosted LLM with persistent memory.

    Each call to :meth:`process` records the exchange in the agent's
    memories, runs any tool calls the model requests (then asks the model
    again with the tool outputs), and persists memory to disk afterwards.
    """

    # Message roles forwarded from memory into the LLM conversation history;
    # anything else (e.g. the "function" entries in "reasoning") is skipped.
    _CONVERSATION_ROLES = ("user", "assistant", "system", "tool")

    def __init__(
        self,
        name: str,
        models: Optional[List[str]] = None,
        tools: Optional[List[BaseTool]] = None,
        memory_size: int = 10,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        persist_path: Optional[str] = None,
        system_prompt: Optional[str] = None
    ):
        """
        Initialize the Groq agent.

        Args:
            name: Name of the agent
            models: Groq model names; only the first entry is used for
                generation. Defaults to ``["llama-3.1-8b-instant"]``.
            tools: List of tools for the agent
            memory_size: Size of the "dialog" short-term memory
            temperature: Sampling temperature for generation
            max_tokens: Maximum tokens in response
            persist_path: Path to persist memory (if None, use the standard
                per-agent path)
            system_prompt: Custom system prompt (a default mentioning
                ``name`` is used when omitted)
        """
        # Fall back to a default Groq model when none is provided.
        if not models:
            models = ["llama-3.1-8b-instant"]

        super().__init__(name, models, tools)

        # Create specialized memories
        self.create_memory("dialog", memory_size)
        self.create_memory("reasoning", 5)

        # Configure generation parameters
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.persist_path = persist_path

        # Initialize Groq backend
        self.credentials = GroqCredentials.from_env()
        self.groq_skill = GroqChatSkill(self.credentials)

        # Set system prompt
        self.system_prompt = system_prompt or (
            f"You are {self.name}, a helpful AI assistant powered by Groq. "
            "Your responses are accurate, helpful, and concise. You have access "
            "to tools that help you accomplish tasks. When reasoning through a problem, "
            "take a step-by-step approach and use tools when appropriate."
        )

        # Load memory from persistent storage if available
        self._load_memory()

    def _get_tool_definitions(self):
        """Get tool definitions in a format suitable for LLM (None if no tools)."""
        return [tool.to_dict() for tool in self.tools] if self.tools else None

    def process(self, user_input: str, memory_name: str = "dialog") -> str:
        """
        Process user input and generate a response.

        The user message and the final assistant reply are recorded with
        ``memory.add_to_all``; tool executions (including failures) are
        recorded in the "reasoning" memory. Memory is persisted afterwards.

        Args:
            user_input: User input to process
            memory_name: Name of the memory whose context is sent to the LLM

        Returns:
            Agent's response
        """
        self.memory.add_to_all({"role": "user", "content": user_input})

        conversation_history = self._build_conversation_history(memory_name)

        result = self.groq_skill.process(
            GroqInput(
                model=self.models[0],
                user_input="",  # Empty because we use conversation_history
                conversation_history=conversation_history,
                tools=self._get_tool_definitions(),
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
        )

        if getattr(result, "tool_calls", None):
            tool_results = self._execute_tool_calls(result.tool_calls)
            followup_history = self._build_followup_history(
                conversation_history, result, tool_results
            )

            # Ask the model again, now that it can see the tool outputs.
            followup_result = self.groq_skill.process(
                GroqInput(
                    model=self.models[0],
                    user_input="",
                    conversation_history=followup_history,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
            )
            response = followup_result.response
        else:
            # No tool calls, use direct response
            response = result.response

        self.memory.add_to_all({"role": "assistant", "content": response})

        # Persist memory after processing
        self._persist_memory()

        return response

    def _build_conversation_history(self, memory_name: str) -> List[Dict[str, Any]]:
        """Return the system prompt followed by the role/content of usable memory messages."""
        history = [{"role": "system", "content": self.system_prompt}]
        for message in self.memory.get_context(memory_name):
            # Messages without a role, or with a role the LLM does not
            # understand, are dropped.
            role = message.get("role")
            if role in self._CONVERSATION_ROLES:
                history.append({"role": role, "content": message.get("content", "")})
        return history

    def _execute_tool_calls(self, tool_calls: List[Dict[str, Any]]) -> List[tuple]:
        """Run each requested tool and log its output to the "reasoning" memory.

        Tool calls are expected to be dicts shaped like
        ``{"id": ..., "function": {"name": ...}}``. Failures do not raise:
        the call is paired with an error dict instead.

        Returns:
            ``(tool_call, result)`` pairs in the order the calls were requested.
        """
        tool_results = []
        for tool_call in tool_calls:
            # Resolve the name *before* executing so the error messages below
            # never refer to an unbound name or one left over from a
            # previous iteration (e.g. when "function" is missing or None).
            function = tool_call.get("function") or {}
            func_name = function.get("name", "unknown_tool")

            try:
                tool_result = execute_tool_call(tool_call)
            except ValueError as e:
                # ValueError is treated as "tool not found / not available".
                tool_result = {
                    "error": str(e),
                    "status": "error",
                    "message": f"Tool '{func_name}' not found or not available."
                }
            except Exception as e:
                # Any other failure while executing the tool.
                tool_result = {
                    "error": str(e),
                    "status": "error",
                    "message": f"Error executing tool '{func_name}': {str(e)}"
                }

            tool_results.append((tool_call, tool_result))
            self.memory.add_to_memory("reasoning", {
                "role": "function",
                "name": func_name,
                "content": str(tool_result)
            })
        return tool_results

    @staticmethod
    def _build_followup_history(
        conversation_history: List[Dict[str, Any]],
        result: Any,
        tool_results: List[tuple],
    ) -> List[Dict[str, Any]]:
        """Copy the history and append the assistant tool-call turn plus one "tool" message per result."""
        followup_history = list(conversation_history)

        # OpenAI-compatible chat APIs (Groq included) require every "tool"
        # message to answer a preceding assistant message carrying the
        # matching `tool_calls`, so this turn is always added, even when
        # the model produced no text alongside its tool calls.
        followup_history.append({
            "role": "assistant",
            "content": result.response or "",
            "tool_calls": result.tool_calls,
        })

        for tool_call, tool_result in tool_results:
            followup_history.append({
                "role": "tool",
                "tool_call_id": tool_call.get("id", "unknown"),
                "content": str(tool_result)
            })
        return followup_history

    def _persist_memory(self):
        """Persist memory to ``persist_path`` if set, else under the agent's name."""
        if self.persist_path:
            self.memory.persist(self.persist_path)
        else:
            self.memory.persist(agent_name=self.name)

    def _load_memory(self):
        """Load memory from disk if available; failures only print a warning."""
        try:
            if self.persist_path:
                if os.path.exists(self.persist_path):
                    self.memory.load(self.persist_path)
            elif (Path.home() / ".trmx" / "agents" / self.name).exists():
                # Standard per-agent location: ~/.trmx/agents/<name>
                self.memory.load(agent_name=self.name)
        except Exception as e:
            # Best-effort: a corrupt or unreadable store must not prevent
            # the agent from starting.
            print(f"Warning: Failed to load memory for {self.name}: {str(e)}")
247
+
248
+
249
# Example usage
if __name__ == "__main__":
    import sys

    import dotenv

    # Load environment variables from a .env file, if present
    dotenv.load_dotenv()

    # The agent cannot talk to Groq without an API key; fail fast.
    if not os.getenv("GROQ_API_KEY"):
        print("Error: GROQ_API_KEY environment variable not set", file=sys.stderr)
        # sys.exit works even when the `site` module (which provides the
        # interactive `exit` helper) is not loaded, e.g. under `python -S`.
        sys.exit(1)

    # Create an agent
    agent = GroqAgent(
        name="GroqAssistant",
        memory_size=5
    )

    # Add calculator tool if available
    try:
        calculator = ToolFactory.get_tool("calculator")
        agent.add_tool(calculator)
        print(f"Added calculator tool to {agent.name}")
    except ValueError:
        print("Calculator tool not available")

    # Test the agent
    print(f"\n=== Testing {agent.name} ===")

    # Process a few inputs; later prompts rely on memory of earlier ones.
    sample_inputs = [
        "Hello, what can you do?",
        "Can you help me calculate 23.5 * 17?",
        "Thank you! Can you remember that result for me?",
        "What was the calculation result we discussed earlier?"
    ]

    for user_input in sample_inputs:
        print(f"\nUser: {user_input}")
        response = agent.process(user_input)
        print(f"{agent.name}: {response}")