airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
@@ -1,3 +1,149 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
Airtrain: AI Agent Framework
|
2
3
|
|
3
|
-
|
4
|
+
This library provides a flexible framework for building AI agents
|
5
|
+
that can complete complex tasks using AI models, skills, and tools.
|
6
|
+
"""
|
7
|
+
|
8
|
+
__version__ = "0.1.4"
|
9
|
+
|
10
|
+
import sys
|
11
|
+
|
12
|
+
# Core imports
|
13
|
+
from .core import Skill, ProcessingError, InputSchema, OutputSchema, BaseCredentials
|
14
|
+
|
15
|
+
# Integration imports - Credentials
|
16
|
+
from .integrations import (
|
17
|
+
# OpenAI
|
18
|
+
OpenAICredentials,
|
19
|
+
OpenAIChatSkill,
|
20
|
+
OpenAICompletionSkill,
|
21
|
+
# Anthropic
|
22
|
+
AnthropicCredentials,
|
23
|
+
AnthropicChatSkill,
|
24
|
+
# Together.ai
|
25
|
+
TogetherAICredentials,
|
26
|
+
TogetherChatSkill,
|
27
|
+
# Fireworks
|
28
|
+
FireworksCredentials,
|
29
|
+
FireworksChatSkill,
|
30
|
+
# Google
|
31
|
+
GeminiCredentials,
|
32
|
+
GeminiChatSkill,
|
33
|
+
# Search
|
34
|
+
ExaCredentials,
|
35
|
+
ExaSearchSkill,
|
36
|
+
ExaSearchInputSchema,
|
37
|
+
ExaSearchOutputSchema,
|
38
|
+
)
|
39
|
+
|
40
|
+
# Integration imports - Skills
|
41
|
+
from .integrations.aws.skills import AWSBedrockSkill
|
42
|
+
from .integrations.google.skills import GoogleChatSkill
|
43
|
+
from .integrations.groq.skills import GroqChatSkill
|
44
|
+
from .integrations.ollama.skills import OllamaChatSkill
|
45
|
+
from .integrations.sambanova.skills import SambanovaChatSkill
|
46
|
+
from .integrations.cerebras.skills import CerebrasChatSkill
|
47
|
+
|
48
|
+
# Tool imports
|
49
|
+
from .tools import (
|
50
|
+
ToolFactory,
|
51
|
+
register_tool,
|
52
|
+
StatelessTool,
|
53
|
+
StatefulTool,
|
54
|
+
BaseTool,
|
55
|
+
ListDirectoryTool,
|
56
|
+
DirectoryTreeTool,
|
57
|
+
ApiCallTool,
|
58
|
+
ExecuteCommandTool,
|
59
|
+
FindFilesTool,
|
60
|
+
TerminalNavigationTool,
|
61
|
+
SearchTermTool,
|
62
|
+
RunPytestTool,
|
63
|
+
)
|
64
|
+
|
65
|
+
# Agent imports
|
66
|
+
from .agents import (
|
67
|
+
BaseAgent,
|
68
|
+
AgentFactory,
|
69
|
+
register_agent,
|
70
|
+
BaseMemory,
|
71
|
+
ShortTermMemory,
|
72
|
+
LongTermMemory,
|
73
|
+
SharedMemory,
|
74
|
+
)
|
75
|
+
|
76
|
+
# Telemetry import - must be imported after version is defined
|
77
|
+
from .telemetry import telemetry
|
78
|
+
from .telemetry import PackageImportTelemetryEvent
|
79
|
+
|
80
|
+
|
81
|
+
# Public API of the top-level ``airtrain`` package; controls what
# ``from airtrain import *`` brings into scope.
__all__ = [
    # Core
    "Skill",
    "ProcessingError",
    "InputSchema",
    "OutputSchema",
    "BaseCredentials",
    # OpenAI Integration
    "OpenAICredentials",
    "OpenAIChatSkill",
    "OpenAICompletionSkill",
    # Anthropic Integration
    "AnthropicCredentials",
    "AnthropicChatSkill",
    # Together Integration
    "TogetherAICredentials",
    "TogetherChatSkill",
    # Fireworks Integration
    "FireworksCredentials",
    "FireworksChatSkill",
    # Google Integration
    "GeminiCredentials",
    "GeminiChatSkill",
    # Search Integration
    "ExaCredentials",
    "ExaSearchSkill",
    "ExaSearchInputSchema",
    "ExaSearchOutputSchema",
    # Chat skills imported directly from their integration modules above;
    # listed here so star-imports and API tooling treat them as public.
    "AWSBedrockSkill",
    "GoogleChatSkill",
    "GroqChatSkill",
    "OllamaChatSkill",
    "SambanovaChatSkill",
    "CerebrasChatSkill",
    # Tools
    "ToolFactory",
    "register_tool",
    "StatelessTool",
    "StatefulTool",
    "BaseTool",
    "ListDirectoryTool",
    "DirectoryTreeTool",
    "ApiCallTool",
    "ExecuteCommandTool",
    "FindFilesTool",
    "TerminalNavigationTool",
    "SearchTermTool",
    "RunPytestTool",
    # Agents
    "BaseAgent",
    "AgentFactory",
    "register_agent",
    "BaseMemory",
    "ShortTermMemory",
    "LongTermMemory",
    "SharedMemory",
    # Telemetry - not directly exposed to users
    # but initialized at import time
]
|
134
|
+
|
135
|
+
# Report a single "package imported" telemetry event. This is strictly
# best-effort: any failure (telemetry disabled, no network, schema error)
# must never prevent `import airtrain` from succeeding.
try:
    telemetry.capture(
        PackageImportTelemetryEvent(
            version=__version__,
            # "major.minor.micro", e.g. "3.13.1"
            python_version=".".join(str(part) for part in sys.version_info[:3]),
        )
    )
except Exception:
    # Deliberately ignored; see comment above.
    pass
|
airtrain/__main__.py
ADDED
Binary file
|
@@ -0,0 +1,45 @@
|
|
1
|
+
"""
|
2
|
+
Agents package for AirTrain.
|
3
|
+
|
4
|
+
This package provides a registry of agents that can be used to build AI systems.
|
5
|
+
"""
|
6
|
+
|
7
|
+
# Import registry components
|
8
|
+
from .registry import (
|
9
|
+
BaseAgent,
|
10
|
+
AgentFactory,
|
11
|
+
register_agent,
|
12
|
+
AgentRegistry
|
13
|
+
)
|
14
|
+
|
15
|
+
# Import memory components
|
16
|
+
from .memory import (
|
17
|
+
BaseMemory,
|
18
|
+
ShortTermMemory,
|
19
|
+
LongTermMemory,
|
20
|
+
SharedMemory,
|
21
|
+
AgentMemoryManager
|
22
|
+
)
|
23
|
+
|
24
|
+
# Import agent implementations
|
25
|
+
from .groq_agent import GroqAgent
|
26
|
+
|
27
|
+
# Names re-exported by ``airtrain.agents`` (and pulled in by a star-import).
__all__ = [
    # Base class for agents.
    "BaseAgent",
    # Registration and construction helpers.
    "AgentFactory", "register_agent", "AgentRegistry",
    # Memory implementations and the memory manager.
    "BaseMemory", "ShortTermMemory", "LongTermMemory", "SharedMemory",
    "AgentMemoryManager",
    # Concrete agent implementations.
    "GroqAgent",
]
|
@@ -0,0 +1,348 @@
|
|
1
|
+
"""
|
2
|
+
Example Agent implementation for AirTrain.
|
3
|
+
|
4
|
+
This module provides a simple example agent that demonstrates the use of
|
5
|
+
the AirTrain agent framework with memory and tool integration.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from typing import List, Any, Optional
|
9
|
+
|
10
|
+
from airtrain.agents.registry import BaseAgent, register_agent
|
11
|
+
from airtrain.agents.memory import SharedMemory
|
12
|
+
from airtrain.tools import ToolFactory, execute_tool_call
|
13
|
+
|
14
|
+
try:
|
15
|
+
from airtrain.integrations.groq.skills import GroqChatSkill, GroqInput
|
16
|
+
HAS_GROQ = True
|
17
|
+
except ImportError:
|
18
|
+
HAS_GROQ = False
|
19
|
+
|
20
|
+
try:
|
21
|
+
from airtrain.integrations.fireworks.skills import (
|
22
|
+
FireworksChatSkill,
|
23
|
+
FireworksInput
|
24
|
+
)
|
25
|
+
HAS_FIREWORKS = True
|
26
|
+
except ImportError:
|
27
|
+
HAS_FIREWORKS = False
|
28
|
+
|
29
|
+
|
30
|
+
@register_agent("conversation_agent")
class ConversationAgent(BaseAgent):
    """Agent specialized for conversation with memory management.

    The agent keeps two named memories, ``"dialog"`` (``memory_size``
    entries) and ``"reasoning"`` (5 entries). It sends each request to the
    Groq or Fireworks backend, whichever is installed and matches the
    agent's primary (first) model.
    """

    def __init__(
        self,
        name: str,
        models: Optional[List[str]] = None,
        tools: Optional[List[Any]] = None,
        memory_size: int = 10,
        temperature: float = 0.2,
        max_tokens: int = 1024
    ):
        """
        Initialize conversation agent.

        Args:
            name: Name of the agent
            models: List of model identifiers; the first one is used
            tools: List of tools for the agent
            memory_size: Size of the conversation memory
            temperature: Temperature for generation
            max_tokens: Maximum tokens for responses

        Raises:
            ImportError: if neither the Groq nor the Fireworks integration
                is importable.
        """
        super().__init__(name, models, tools)

        # Create specialized memories
        self.create_memory("dialog", memory_size)
        self.create_memory("reasoning", 5)  # Shorter context for reasoning

        self.temperature = temperature
        self.max_tokens = max_tokens

        # Initialize model backends
        self._initialize_backends()

    def _initialize_backends(self):
        """Initialize available LLM backends based on installed integrations."""
        self.backends = {}

        if HAS_GROQ:
            self.backends["groq"] = GroqChatSkill()

        if HAS_FIREWORKS:
            self.backends["fireworks"] = FireworksChatSkill()

        if not self.backends:
            raise ImportError(
                "No LLM backend available. Please install at least one of: "
                "airtrain-groq, airtrain-fireworks"
            )

    def _get_backend_for_model(self, model: str):
        """Get the appropriate backend for a model.

        Ids starting with ``llama-`` or ending in ``-groq`` need the Groq
        backend. Ids containing ``fireworks`` need the Fireworks backend.
        Any other id is served by the first available backend.

        Raises:
            ValueError: if the model needs a backend that is not available.
                Previously this returned ``None``, and the failure surfaced
                later as a misleading "Unsupported backend" error.
        """
        if model.startswith("llama-") or model.endswith("-groq"):
            required = "groq"
        elif "fireworks" in model:
            required = "fireworks"
        else:
            # Default to first available backend
            return next(iter(self.backends.values()))

        backend = self.backends.get(required)
        if backend is None:
            raise ValueError(
                f"Model {model!r} requires the {required!r} backend, "
                "which is not available"
            )
        return backend

    def _get_tool_definitions(self):
        """Get tool definitions for LLM function calling."""
        return [tool.to_dict() for tool in self.tools]

    def _build_input(self, backend, model: str, messages, **extra):
        """Build the backend-specific request object for one chat call.

        ``extra`` is forwarded as-is (e.g. ``tools=...``). Callers that do
        not pass it get the input schema's own default.

        Raises:
            ValueError: if ``backend`` is neither a Groq nor a Fireworks skill.
        """
        backend_name = str(backend.__class__.__name__).lower()
        if "groq" in backend_name:
            input_cls = GroqInput
        elif "fireworks" in backend_name:
            input_cls = FireworksInput
        else:
            raise ValueError(f"Unsupported backend for model: {model}")

        return input_cls(
            model=model,
            conversation_history=messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            **extra,
        )

    def _build_history(self, context):
        """Convert memory entries into chat messages, ensuring a system prompt.

        Entries without a ``role``, or with a role other than
        user/assistant/system, are skipped.
        """
        history = [
            {"role": message["role"], "content": message.get("content", "")}
            for message in context
            if "role" in message
            and message["role"] in ("user", "assistant", "system")
        ]

        # Add system message if none present
        if not any(msg["role"] == "system" for msg in history):
            history.insert(0, {
                "role": "system",
                "content": (
                    f"You are {self.name}, a helpful AI assistant. "
                    "Provide accurate and concise responses."
                )
            })
        return history

    def _answer_with_tool_results(self, backend, model: str,
                                  conversation_history, result):
        """Run the requested tool calls, then ask the model for a final answer.

        Each tool result is also recorded in the ``"reasoning"`` memory.
        The follow-up request is sent without tool definitions.
        """
        # NOTE(review): tool_call is assumed to be a dict with "id" and
        # "function": {"name": ...} keys (OpenAI-style); confirm against the
        # Groq/Fireworks skill outputs.
        tool_results = []
        for tool_call in result.tool_calls:
            tool_result = execute_tool_call(tool_call)
            tool_results.append((tool_call, tool_result))

            # Add to reasoning memory
            self.memory.add_to_memory("reasoning", {
                "role": "function",
                "name": tool_call.get("function", {}).get("name"),
                "content": str(tool_result)
            })

        followup_messages = conversation_history.copy()

        # Add the assistant's response that led to tool calls.
        # NOTE(review): OpenAI-compatible APIs usually expect this assistant
        # message to carry the original `tool_calls` before any "tool" role
        # messages; verify that the backend accepts the history as built here.
        if result.response:
            followup_messages.append({
                "role": "assistant",
                "content": result.response
            })

        followup_messages.extend(
            {
                "role": "tool",
                "tool_call_id": tool_call.get("id", "unknown"),
                "content": str(tool_result)
            }
            for tool_call, tool_result in tool_results
        )

        followup_result = backend.process(
            self._build_input(backend, model, followup_messages)
        )
        return followup_result.response

    def process(self, user_input: str, memory_name: str = "dialog") -> str:
        """
        Process user input and generate a response.

        Args:
            user_input: User input to process
            memory_name: Name of the memory to use as conversation context

        Returns:
            Agent's response

        Raises:
            ValueError: if no models are configured, or the primary model's
                backend is unavailable or unsupported.
        """
        if not self.models:
            raise ValueError("No models configured for agent")

        # 1. Add user input to memories
        self.memory.add_to_all({"role": "user", "content": user_input})

        # 2-3. Build conversation history from the chosen memory
        conversation_history = self._build_history(
            self.memory.get_context(memory_name)
        )

        # 4. Prepare tool definitions
        tool_defs = self._get_tool_definitions() if self.tools else None

        # 5. Call primary model
        primary_model = self.models[0]
        backend = self._get_backend_for_model(primary_model)
        input_data = self._build_input(
            backend, primary_model, conversation_history, tools=tool_defs
        )
        result = backend.process(input_data)

        # 6. Handle tool calls if any
        if getattr(result, "tool_calls", None):
            response = self._answer_with_tool_results(
                backend, primary_model, conversation_history, result
            )
        else:
            # No tool calls, just use the direct response
            response = result.response

        # 7. Add response to memory
        self.memory.add_to_all({"role": "assistant", "content": response})

        # 8. Return final response
        return response
|
234
|
+
|
235
|
+
|
236
|
+
def create_agent_team(shared_memory_name: str = "team_knowledge"):
    """
    Build one ConversationAgent per installed backend, all sharing one memory.

    A Groq agent (with the "calculator" tool, if registered) is created when
    the Groq integration is importable. A Fireworks agent (with the stateful
    "conversation_memory" tool, if registered) is created when the Fireworks
    integration is importable.

    Args:
        shared_memory_name: Name for the shared memory

    Returns:
        Tuple of agent instances (Groq first, then Fireworks); may be empty
        when neither integration is installed.
    """
    shared_memory = SharedMemory(shared_memory_name)

    def _lookup_tool(*lookup_args):
        # An unregistered tool is tolerated: the agent simply gets no tool.
        try:
            return ToolFactory.get_tool(*lookup_args)
        except ValueError:
            return None

    calculator_tool = _lookup_tool("calculator")
    memory_tool = _lookup_tool("conversation_memory", "stateful")

    # (agent name, model id, optional tool) for each available backend.
    team_specs = []
    if HAS_GROQ:
        team_specs.append(("GroqAgent", "llama-3.1-8b-instant", calculator_tool))
    if HAS_FIREWORKS:
        team_specs.append((
            "FireworksAgent",
            "accounts/fireworks/models/firefunction-v1",
            memory_tool,
        ))

    team = []
    for agent_name, model_id, tool in team_specs:
        member = ConversationAgent(
            name=agent_name,
            models=[model_id],
            tools=[tool] if tool else []
        )
        member.memory.add_shared_memory(shared_memory)
        team.append(member)

    return tuple(team)
|
294
|
+
|
295
|
+
|
296
|
+
# Example usage
|
297
|
+
if __name__ == "__main__":
    import os
    import sys

    import dotenv

    # Load environment variables for API keys
    dotenv.load_dotenv()

    has_groq_key = bool(os.getenv("GROQ_API_KEY"))
    if not (has_groq_key or os.getenv("FIREWORKS_API_KEY")):
        print("No API keys found. Set GROQ_API_KEY or FIREWORKS_API_KEY in .env file.")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S` or in frozen apps).
        sys.exit(1)

    # Create an agent
    try:
        groq_model = "llama-3.1-8b-instant"
        fw_model = "accounts/fireworks/models/firefunction-v1"
        # Use Groq only when it is installed AND has a key, or when it is the
        # only installed backend. Otherwise use Fireworks, so a user who only
        # set FIREWORKS_API_KEY is not routed to Groq.
        if HAS_GROQ and (has_groq_key or not HAS_FIREWORKS):
            model = groq_model
        else:
            model = fw_model

        agent = ConversationAgent(
            name="TestAgent",
            models=[model],
            memory_size=5
        )

        # Add calculator tool if available
        try:
            calculator = ToolFactory.get_tool("calculator")
            agent.add_tool(calculator)
            print(f"Added calculator tool to {agent.name}")
        except ValueError:
            pass  # calculator tool not registered; run without it

        # Test the agent
        print(f"\n=== Testing {agent.name} ===")

        # Process a few inputs
        sample_inputs = [
            "Hello, what can you do?",
            "Can you help me calculate 23.5 * 17?",
            "Thank you! Can you remember that result for me?",
            "What was the calculation result we discussed earlier?"
        ]

        for user_input in sample_inputs:
            print(f"\nUser: {user_input}")
            response = agent.process(user_input)
            print(f"{agent.name}: {response}")

    except ImportError as e:
        print(f"Error creating agent: {str(e)}")
        sys.exit(1)
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
        # Signal failure to the calling shell instead of exiting with 0.
        sys.exit(1)
|