airtrain 0.1.51__py3-none-any.whl → 0.1.57__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +42 -2
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/core/skills.py +102 -0
- airtrain/integrations/combined/list_models_factory.py +9 -3
- airtrain/integrations/groq/__init__.py +18 -1
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +93 -17
- airtrain/integrations/together/__init__.py +15 -1
- airtrain/integrations/together/models_config.py +123 -1
- airtrain/integrations/together/skills.py +117 -20
- airtrain/telemetry/__init__.py +34 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +173 -0
- airtrain/tools/__init__.py +41 -0
- airtrain/tools/command.py +211 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/METADATA +37 -1
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/RECORD +27 -13
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/WHEEL +1 -1
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.51.dist-info → airtrain-0.1.57.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
"""Airtrain - A platform for building and deploying AI agents with structured skills"""
|
2
2
|
|
3
|
-
__version__ = "0.1.
|
3
|
+
__version__ = "0.1.57"
|
4
4
|
|
5
5
|
# Core imports
|
6
6
|
from .core.skills import Skill, ProcessingError
|
@@ -29,6 +29,30 @@ from .integrations.ollama.skills import OllamaChatSkill
|
|
29
29
|
from .integrations.sambanova.skills import SambanovaChatSkill
|
30
30
|
from .integrations.cerebras.skills import CerebrasChatSkill
|
31
31
|
|
32
|
+
# Tool imports
|
33
|
+
from .tools import (
|
34
|
+
StatefulTool,
|
35
|
+
StatelessTool,
|
36
|
+
register_tool,
|
37
|
+
ToolFactory,
|
38
|
+
execute_tool_call
|
39
|
+
)
|
40
|
+
|
41
|
+
# Agent imports
|
42
|
+
from .agents import (
|
43
|
+
BaseAgent,
|
44
|
+
AgentFactory,
|
45
|
+
register_agent,
|
46
|
+
BaseMemory,
|
47
|
+
ShortTermMemory,
|
48
|
+
LongTermMemory,
|
49
|
+
SharedMemory
|
50
|
+
)
|
51
|
+
|
52
|
+
# Telemetry import
|
53
|
+
from .telemetry import telemetry
|
54
|
+
|
55
|
+
|
32
56
|
__all__ = [
|
33
57
|
# Core
|
34
58
|
"Skill",
|
@@ -36,7 +60,7 @@ __all__ = [
|
|
36
60
|
"InputSchema",
|
37
61
|
"OutputSchema",
|
38
62
|
"BaseCredentials",
|
39
|
-
# Credentials
|
63
|
+
# Credentials
|
40
64
|
"OpenAICredentials",
|
41
65
|
"AWSCredentials",
|
42
66
|
"GoogleCloudCredentials",
|
@@ -57,4 +81,20 @@ __all__ = [
|
|
57
81
|
"OllamaChatSkill",
|
58
82
|
"SambanovaChatSkill",
|
59
83
|
"CerebrasChatSkill",
|
84
|
+
# Tools
|
85
|
+
"StatefulTool",
|
86
|
+
"StatelessTool",
|
87
|
+
"register_tool",
|
88
|
+
"ToolFactory",
|
89
|
+
"execute_tool_call",
|
90
|
+
# Agents
|
91
|
+
"BaseAgent",
|
92
|
+
"AgentFactory",
|
93
|
+
"register_agent",
|
94
|
+
"BaseMemory",
|
95
|
+
"ShortTermMemory",
|
96
|
+
"LongTermMemory",
|
97
|
+
"SharedMemory",
|
98
|
+
# Telemetry - not directly exposed to users
|
99
|
+
# but initialized at import time
|
60
100
|
]
|
@@ -0,0 +1,45 @@
|
|
1
|
+
"""
|
2
|
+
Agents package for AirTrain.
|
3
|
+
|
4
|
+
This package provides a registry of agents that can be used to build AI systems.
|
5
|
+
"""
|
6
|
+
|
7
|
+
# Import registry components
|
8
|
+
from .registry import (
|
9
|
+
BaseAgent,
|
10
|
+
AgentFactory,
|
11
|
+
register_agent,
|
12
|
+
AgentRegistry
|
13
|
+
)
|
14
|
+
|
15
|
+
# Import memory components
|
16
|
+
from .memory import (
|
17
|
+
BaseMemory,
|
18
|
+
ShortTermMemory,
|
19
|
+
LongTermMemory,
|
20
|
+
SharedMemory,
|
21
|
+
AgentMemoryManager
|
22
|
+
)
|
23
|
+
|
24
|
+
# Import agent implementations
|
25
|
+
from .groq_agent import GroqAgent
|
26
|
+
|
27
|
+
# Public names re-exported by ``from airtrain.agents import *``,
# listed group by group in the same order as the imports of this module.
__all__ = [
    "BaseAgent",                                          # base class
    "AgentFactory", "register_agent", "AgentRegistry",    # registry
    "BaseMemory", "ShortTermMemory", "LongTermMemory",    # memory
    "SharedMemory", "AgentMemoryManager",
    "GroqAgent",                                          # implementations
]
|
@@ -0,0 +1,348 @@
|
|
1
|
+
"""
|
2
|
+
Example Agent implementation for AirTrain.
|
3
|
+
|
4
|
+
This module provides a simple example agent that demonstrates the use of
|
5
|
+
the AirTrain agent framework with memory and tool integration.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import List, Any, Optional
|
9
|
+
|
10
|
+
from airtrain.agents.registry import BaseAgent, register_agent
|
11
|
+
from airtrain.agents.memory import SharedMemory
|
12
|
+
from airtrain.tools import ToolFactory, execute_tool_call
|
13
|
+
|
14
|
+
try:
|
15
|
+
from airtrain.integrations.groq.skills import GroqChatSkill, GroqInput
|
16
|
+
HAS_GROQ = True
|
17
|
+
except ImportError:
|
18
|
+
HAS_GROQ = False
|
19
|
+
|
20
|
+
try:
|
21
|
+
from airtrain.integrations.fireworks.skills import (
|
22
|
+
FireworksChatSkill,
|
23
|
+
FireworksInput
|
24
|
+
)
|
25
|
+
HAS_FIREWORKS = True
|
26
|
+
except ImportError:
|
27
|
+
HAS_FIREWORKS = False
|
28
|
+
|
29
|
+
|
30
|
+
@register_agent("conversation_agent")
class ConversationAgent(BaseAgent):
    """Agent specialized for conversation with memory management.

    On construction the agent creates two memories, ``"dialog"`` (sized by
    ``memory_size``) and ``"reasoning"`` (size 5), and instantiates every
    installed LLM backend (Groq and/or Fireworks). ``process`` sends each
    request to the backend selected for the agent's first configured model.
    """

    def __init__(
        self,
        name: str,
        models: Optional[List[str]] = None,
        tools: Optional[List[Any]] = None,
        memory_size: int = 10,
        temperature: float = 0.2,
        max_tokens: int = 1024
    ):
        """
        Initialize conversation agent.

        Args:
            name: Name of the agent
            models: List of model identifiers; the first one is used by
                ``process``
            tools: List of tools for the agent
            memory_size: Size of the "dialog" memory
            temperature: Temperature for generation
            max_tokens: Maximum tokens for responses

        Raises:
            ImportError: If neither the Groq nor the Fireworks integration
                could be imported.
        """
        super().__init__(name, models, tools)

        # Create specialized memories
        self.create_memory("dialog", memory_size)
        self.create_memory("reasoning", 5)  # Shorter context for reasoning

        self.temperature = temperature
        self.max_tokens = max_tokens

        # Initialize model backends
        self._initialize_backends()

    def _initialize_backends(self):
        """Initialize available LLM backends based on installed integrations.

        ``self.backends`` maps a backend key ("groq", "fireworks") to a chat
        skill instance. Groq is inserted first, so it is the default backend
        when both are installed.
        """
        self.backends = {}

        if HAS_GROQ:
            self.backends["groq"] = GroqChatSkill()

        if HAS_FIREWORKS:
            self.backends["fireworks"] = FireworksChatSkill()

        if not self.backends:
            raise ImportError(
                "No LLM backend available. Please install at least one of: "
                "airtrain-groq, airtrain-fireworks"
            )

    def _backend_key_for_model(self, model: str) -> str:
        """Return the ``self.backends`` key that should serve *model*.

        Names starting with "llama-" or ending with "-groq" map to "groq".
        Names containing "fireworks" map to "fireworks". Anything else maps
        to the first installed backend. The returned key may be absent from
        ``self.backends`` when the matching integration is not installed.
        """
        if model.startswith("llama-") or model.endswith("-groq"):
            return "groq"
        if "fireworks" in model:
            return "fireworks"
        # Default to first available backend (dicts keep insertion order).
        return next(iter(self.backends))

    def _get_backend_for_model(self, model: str):
        """Get the appropriate backend for a model (None if not installed)."""
        return self.backends.get(self._backend_key_for_model(model))

    def _get_tool_definitions(self):
        """Get tool definitions for LLM function calling."""
        return [tool.to_dict() for tool in self.tools]

    def _build_input(self, backend_key: str, model: str, messages, **extra: Any):
        """Build the backend-specific input object for one completion call.

        Args:
            backend_key: "groq" or "fireworks"
            model: Model identifier passed to the backend
            messages: Chat messages sent as ``conversation_history``
            **extra: Additional input fields (e.g. ``tools``)

        Raises:
            ValueError: If *backend_key* is not a supported backend.
        """
        if backend_key == "groq":
            input_cls = GroqInput
        elif backend_key == "fireworks":
            input_cls = FireworksInput
        else:
            raise ValueError(f"Unsupported backend for model: {model}")
        return input_cls(
            model=model,
            conversation_history=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            **extra
        )

    def _build_conversation_history(self, memory_name: str) -> List[dict]:
        """Convert memory *memory_name* into chat messages, ensuring a system prompt."""
        # Only user/assistant/system entries are forwarded to the LLM; other
        # entries (e.g. "function" results in the reasoning memory) are skipped.
        history = [
            {"role": message["role"], "content": message.get("content", "")}
            for message in self.memory.get_context(memory_name)
            if message.get("role") in ("user", "assistant", "system")
        ]

        # Add system message if none present
        if not any(msg["role"] == "system" for msg in history):
            history.insert(0, {
                "role": "system",
                "content": (
                    f"You are {self.name}, a helpful AI assistant. "
                    "Provide accurate and concise responses."
                )
            })
        return history

    def _complete_with_tool_results(
        self, backend, backend_key: str, model: str, conversation_history, result
    ) -> str:
        """Execute *result*'s tool calls and return the model's follow-up answer."""
        tool_results = []
        for tool_call in result.tool_calls:
            tool_result = execute_tool_call(tool_call)
            tool_results.append((tool_call, tool_result))

            # Add to reasoning memory
            self.memory.add_to_memory("reasoning", {
                "role": "function",
                "name": tool_call.get("function", {}).get("name"),
                "content": str(tool_result)
            })

        followup_messages = list(conversation_history)

        # OpenAI-style chat APIs only accept a "tool" message as the answer to
        # a preceding assistant message that carries the matching tool_calls.
        # Emit that turn even when the model produced no text (the common
        # case), otherwise the follow-up request is rejected.
        # NOTE(review): assumes result.tool_calls are OpenAI-format dicts
        # (id / type / function) — this code already reads "id" and
        # "function" from them.
        followup_messages.append({
            "role": "assistant",
            "content": result.response or "",
            "tool_calls": list(result.tool_calls),
        })

        # Add tool results
        for tool_call, tool_result in tool_results:
            followup_messages.append({
                "role": "tool",
                "tool_call_id": tool_call.get("id", "unknown"),
                "content": str(tool_result)
            })

        # Tools are deliberately not offered again, so the follow-up call
        # produces a final answer instead of another round of tool calls.
        followup_result = backend.process(
            self._build_input(backend_key, model, followup_messages)
        )
        return followup_result.response

    def process(self, user_input: str, memory_name: str = "dialog") -> str:
        """
        Process user input and generate a response.

        The user turn and the final assistant answer are appended to every
        memory. When the model requests tool calls, the tools are executed,
        their results are recorded in the "reasoning" memory, and a second
        completion produces the answer.

        Args:
            user_input: User input to process
            memory_name: Name of the memory to use as conversation context

        Returns:
            Agent's response

        Raises:
            ValueError: If no models are configured, or the backend required
                by the primary model is not installed.
        """
        if not self.models:
            raise ValueError("No models configured for agent")

        # 1. Add user input to memories
        self.memory.add_to_all({"role": "user", "content": user_input})

        # 2-3. Get context from memory and convert it to chat messages
        conversation_history = self._build_conversation_history(memory_name)

        # 4. Prepare tool definitions
        tool_defs = self._get_tool_definitions() if self.tools else None

        # 5. Call primary model
        primary_model = self.models[0]
        backend_key = self._backend_key_for_model(primary_model)
        backend = self.backends.get(backend_key)
        if backend is None:
            raise ValueError(
                f"Unsupported backend for model: {primary_model} "
                f"(the '{backend_key}' integration is not installed)"
            )

        result = backend.process(
            self._build_input(
                backend_key, primary_model, conversation_history, tools=tool_defs
            )
        )

        # 6. Handle tool calls if any
        if getattr(result, "tool_calls", None):
            response = self._complete_with_tool_results(
                backend, backend_key, primary_model, conversation_history, result
            )
        else:
            # No tool calls, just use the direct response
            response = result.response

        # 7. Add response to memory
        self.memory.add_to_all({"role": "assistant", "content": response})

        # 8. Return final response
        return response
|
234
|
+
|
235
|
+
|
236
|
+
def create_agent_team(shared_memory_name: str = "team_knowledge"):
    """
    Build up to two conversation agents attached to one shared memory.

    A Groq-backed agent is created when the Groq integration imported. It
    gets the "calculator" tool if the tool factory provides one. A
    Fireworks-backed agent is created when the Fireworks integration
    imported. It gets the stateful "conversation_memory" tool if available.
    Every agent created is attached to the same SharedMemory instance.

    Args:
        shared_memory_name: Name for the shared memory

    Returns:
        Tuple of the created agents (Groq agent first); may be empty.
    """
    team_memory = SharedMemory(shared_memory_name)

    def _lookup_tool(*tool_args):
        # A ValueError from the factory means the tool is unavailable.
        try:
            return ToolFactory.get_tool(*tool_args)
        except ValueError:
            return None

    calculator = _lookup_tool("calculator")
    convo_memory_tool = _lookup_tool("conversation_memory", "stateful")

    # (agent name, model id, optional tool) for each installed backend.
    member_specs = []
    if HAS_GROQ:
        member_specs.append(("GroqAgent", "llama-3.1-8b-instant", calculator))
    if HAS_FIREWORKS:
        member_specs.append((
            "FireworksAgent",
            "accounts/fireworks/models/firefunction-v1",
            convo_memory_tool,
        ))

    team = []
    for agent_name, model_id, tool in member_specs:
        member = ConversationAgent(
            name=agent_name,
            models=[model_id],
            tools=[tool] if tool else []
        )
        member.memory.add_shared_memory(team_memory)
        team.append(member)

    return tuple(team)
|
294
|
+
|
295
|
+
|
296
|
+
# Example usage
|
297
|
+
if __name__ == "__main__":
    # Example usage: run a short scripted conversation against one backend.
    import os
    import sys

    # python-dotenv is optional. Without it, API keys must already be
    # exported in the process environment.
    try:
        import dotenv
    except ImportError:
        dotenv = None

    # Load environment variables for API keys
    if dotenv is not None:
        dotenv.load_dotenv()

    if not (os.getenv("GROQ_API_KEY") or os.getenv("FIREWORKS_API_KEY")):
        print("No API keys found. Set GROQ_API_KEY or FIREWORKS_API_KEY in .env file.")
        # sys.exit, unlike the interactive ``exit`` helper, does not depend
        # on the ``site`` module being loaded.
        sys.exit(1)

    # Create an agent
    try:
        groq_model = "llama-3.1-8b-instant"
        fw_model = "accounts/fireworks/models/firefunction-v1"
        # Prefer Groq when it is installed and either its key is set or
        # Fireworks is not installed. Otherwise use Fireworks, whose key is
        # then the one that was found.
        if HAS_GROQ and (os.getenv("GROQ_API_KEY") or not HAS_FIREWORKS):
            model = groq_model
        else:
            model = fw_model

        agent = ConversationAgent(
            name="TestAgent",
            models=[model],
            memory_size=5
        )

        # Add calculator tool if available
        try:
            calculator = ToolFactory.get_tool("calculator")
            agent.add_tool(calculator)
            print(f"Added calculator tool to {agent.name}")
        except ValueError:
            pass  # no calculator tool; continue without it

        # Test the agent
        print(f"\n=== Testing {agent.name} ===")

        sample_inputs = [
            "Hello, what can you do?",
            "Can you help me calculate 23.5 * 17?",
            "Thank you! Can you remember that result for me?",
            "What was the calculation result we discussed earlier?"
        ]

        for user_input in sample_inputs:
            print(f"\nUser: {user_input}")
            response = agent.process(user_input)
            print(f"{agent.name}: {response}")

    except ImportError as e:
        print(f"Error creating agent: {str(e)}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"Unexpected error: {str(e)}", file=sys.stderr)
        sys.exit(1)
|