airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,465 @@
|
|
1
|
+
"""
|
2
|
+
Agent Registry System for AirTrain
|
3
|
+
|
4
|
+
This module provides a registry system for agents that can be used to build AI systems.
|
5
|
+
It includes:
|
6
|
+
- Base agent class
|
7
|
+
- Registration decorators
|
8
|
+
- Factory methods for agent creation
|
9
|
+
- Discovery utilities for finding available agents
|
10
|
+
"""
|
11
|
+
|
12
|
+
from abc import ABC, abstractmethod
|
13
|
+
import time
|
14
|
+
import uuid
|
15
|
+
from typing import List, Optional, Type, TypeVar, Union
|
16
|
+
import inspect
|
17
|
+
|
18
|
+
# Import tool registry components
|
19
|
+
from airtrain.tools import ToolFactory, BaseTool
|
20
|
+
from airtrain.telemetry import (
|
21
|
+
telemetry,
|
22
|
+
AgentRunTelemetryEvent,
|
23
|
+
AgentStepTelemetryEvent,
|
24
|
+
AgentEndTelemetryEvent,
|
25
|
+
ModelInvocationTelemetryEvent,
|
26
|
+
ErrorTelemetryEvent
|
27
|
+
)
|
28
|
+
from .memory import AgentMemoryManager
|
29
|
+
|
30
|
+
|
31
|
+
# TypeVar bound to BaseAgent, for helpers that are generic over agent subclasses.
A = TypeVar("A", bound="BaseAgent")

# Global name -> agent-class mapping. AgentRegistry.register (and the
# register_agent convenience wrapper) write into it; lookups read from it.
AGENT_REGISTRY: "dict[str, Type[BaseAgent]]" = dict()
|
36
|
+
|
37
|
+
|
38
|
+
class BaseAgent(ABC):
    """Abstract base class for all agents.

    An agent owns a name, a list of model identifiers, a list of tools and an
    ``AgentMemoryManager`` whose ``"default"`` short-term memory is created on
    construction. It also accumulates per-run statistics (steps, token counts,
    errors) and reports them through the telemetry helpers below.
    Subclasses must implement :meth:`process`.
    """

    # Registry key; assigned by the registration decorator.
    agent_name: Optional[str] = None

    def __init__(
        self,
        name: str,
        models: Optional[List[str]] = None,
        tools: Optional[List[BaseTool]] = None,
    ):
        """
        Initialize an agent.

        Args:
            name: Name of the agent instance
            models: List of model identifiers to use
            tools: List of tools the agent can use
        """
        self.name = name
        # Copy the incoming lists: add_tool()/register_tools() mutate
        # self.tools in place and must not modify a list owned by the caller.
        self.models = list(models) if models else []
        self.tools = list(tools) if tools else []
        self.memory = AgentMemoryManager()

        # Unique agent ID for telemetry; the uuid suffix disambiguates
        # instances that share a name.
        self.agent_id = f"{name}-{uuid.uuid4()}"

        # Per-run statistics; start_run() resets them.
        self.start_time: Optional[float] = None
        self.step_count = 0
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.errors: List[str] = []

        # Initialize default short-term memory
        self.memory.create_short_term_memory("default")

    def add_tool(self, tool: BaseTool) -> "BaseAgent":
        """
        Add a tool to the agent.

        Args:
            tool: Tool instance to add

        Returns:
            Self for method chaining
        """
        self.tools.append(tool)
        return self

    def register_tools(self, tools: List[BaseTool]) -> "BaseAgent":
        """
        Register multiple tools with the agent.

        Args:
            tools: List of tool instances to add

        Returns:
            Self for method chaining
        """
        self.tools.extend(tools)
        return self

    def create_memory(self, name: str, max_messages: int = 10):
        """
        Create a new short-term memory.

        Args:
            name: Name for the memory
            max_messages: Maximum messages before summarization

        Returns:
            The created memory instance
        """
        return self.memory.create_short_term_memory(name, max_messages)

    def reset_memory(self, name: str = "default") -> "BaseAgent":
        """
        Reset a specific short-term memory.

        Args:
            name: Name of the memory to reset

        Returns:
            Self for method chaining
        """
        self.memory.reset_short_term_memory(name)
        return self

    def start_run(self, task: str, model_name: str, model_provider: str) -> "BaseAgent":
        """
        Start an agent run and send telemetry.

        Resets the per-run statistics (start time, step count, token counters,
        error list) before emitting an ``AgentRunTelemetryEvent``.

        Args:
            task: Description of the task
            model_name: Name of the model being used
            model_provider: Provider of the model

        Returns:
            Self for method chaining
        """
        self.start_time = time.time()
        self.step_count = 0
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.errors = []

        # Send run event
        event = AgentRunTelemetryEvent(
            agent_id=self.agent_id,
            task=task,
            model_name=model_name,
            model_provider=model_provider,
            version=self._get_package_version(),
            source=self.__class__.__name__,
        )
        telemetry.capture(event)
        return self

    def record_step(
        self, actions: List[dict], step_error: Optional[List[str]] = None
    ) -> "BaseAgent":
        """
        Record an agent step and send telemetry.

        Args:
            actions: List of actions taken in this step
            step_error: Optional list of errors encountered

        Returns:
            Self for method chaining
        """
        self.step_count += 1
        # NOTE: this is the number of errors recorded via record_error() since
        # start_run() (or construction), not a count of consecutively failing
        # steps; step_error entries are not added to self.errors.
        consecutive_failures = len(self.errors)

        # Send step event
        event = AgentStepTelemetryEvent(
            agent_id=self.agent_id,
            step=self.step_count,
            step_error=step_error or [],
            consecutive_failures=consecutive_failures,
            actions=actions,
        )
        telemetry.capture(event)
        return self

    def record_model_usage(
        self,
        model_name: str,
        model_provider: str,
        tokens: int,
        prompt_tokens: int,
        completion_tokens: int,
        duration_seconds: float,
        error: Optional[str] = None,
    ) -> "BaseAgent":
        """
        Record model usage and send telemetry.

        Adds the token counts to the running per-run totals and emits a
        ``ModelInvocationTelemetryEvent`` for this single call.

        Args:
            model_name: Name of the model used
            model_provider: Provider of the model
            tokens: Total tokens used
            prompt_tokens: Tokens used in the prompt
            completion_tokens: Tokens used in the completion
            duration_seconds: Duration of the model call
            error: Optional error message

        Returns:
            Self for method chaining
        """
        self.total_tokens += tokens
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        # Send model usage event
        event = ModelInvocationTelemetryEvent(
            agent_id=self.agent_id,
            model_name=model_name,
            model_provider=model_provider,
            tokens=tokens,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            duration_seconds=duration_seconds,
            error=error,
        )
        telemetry.capture(event)
        return self

    def end_run(self, is_done: bool, success: Optional[bool] = None) -> "BaseAgent":
        """
        End an agent run and send telemetry.

        Args:
            is_done: Whether the agent completed its task
            success: Whether the agent was successful

        Returns:
            Self for method chaining
        """
        # Duration is 0 when start_run() was never called.
        duration = time.time() - self.start_time if self.start_time else 0

        # Send end event
        event = AgentEndTelemetryEvent(
            agent_id=self.agent_id,
            steps=self.step_count,
            is_done=is_done,
            success=success,
            total_tokens=self.total_tokens,
            prompt_tokens=self.prompt_tokens,
            completion_tokens=self.completion_tokens,
            total_duration_seconds=duration,
            # Snapshot the list so later record_error() calls cannot alter an
            # event that the telemetry service may still be holding.
            errors=list(self.errors),
        )
        telemetry.capture(event)
        return self

    def record_error(self, error_type: str, error_message: str, component: str) -> "BaseAgent":
        """
        Record an error and send telemetry.

        Args:
            error_type: Type of the error
            error_message: Error message
            component: Component where the error occurred

        Returns:
            Self for method chaining
        """
        self.errors.append(error_message)

        # Send error event
        event = ErrorTelemetryEvent(
            error_type=error_type,
            error_message=error_message,
            component=component,
            agent_id=self.agent_id,
        )
        telemetry.capture(event)
        return self

    def _get_package_version(self) -> str:
        """Get the package version for telemetry ("unknown" if unavailable)."""
        try:
            # Imported lazily; importing the package root at module load time
            # could create a circular import.
            from airtrain import __version__

            return __version__
        except ImportError:
            return "unknown"

    @abstractmethod
    def process(self, user_input: str, memory_name: str = "default") -> str:
        """
        Process user input using a specific memory context.

        Args:
            user_input: User input to process
            memory_name: Name of the memory to use for context

        Returns:
            Agent's response
        """
        pass
|
299
|
+
|
300
|
+
|
301
|
+
class AgentRegistry:
    """Registry for agent classes, stored in the module-level ``AGENT_REGISTRY``."""

    @classmethod
    def register(cls, name: Optional[str] = None):
        """
        Decorator to register an agent class.

        Args:
            name: Optional name for the agent class; defaults to the class's
                ``__name__``.

        Returns:
            Decorator function. It raises ``TypeError`` if the target is not a
            ``BaseAgent`` subclass or does not override ``process``, and
            ``ValueError`` if the name is already registered.
        """

        def decorator(agent_class: Type[BaseAgent]) -> Type[BaseAgent]:
            """
            Register an agent class with the registry.

            Args:
                agent_class: Agent class to register

            Returns:
                The registered agent class (unchanged apart from ``agent_name``)
            """
            class_name = getattr(agent_class, "__name__", repr(agent_class))

            # Validate agent class. Check isclass() first so that non-class
            # arguments get this message instead of issubclass()'s own error.
            if not (inspect.isclass(agent_class) and issubclass(agent_class, BaseAgent)):
                raise TypeError(
                    f"Agent class {class_name} must inherit from BaseAgent"
                )

            # Every BaseAgent subclass inherits the abstract ``process``, so an
            # existence/isfunction test alone always passes. Also reject the
            # case where ``process`` is still the abstract declaration.
            process = getattr(agent_class, "process", None)
            if not inspect.isfunction(process) or getattr(
                process, "__isabstractmethod__", False
            ):
                raise TypeError(
                    f"Agent class {class_name} must implement process method"
                )

            # Determine agent name
            agent_name = name or class_name

            # Check for name conflict
            if agent_name in AGENT_REGISTRY:
                raise ValueError(f"Agent '{agent_name}' already registered")

            # Register the agent class
            AGENT_REGISTRY[agent_name] = agent_class

            # Add metadata to the class
            agent_class.agent_name = agent_name

            return agent_class

        return decorator

    @classmethod
    def get_agent_class(cls, name: str) -> Type[BaseAgent]:
        """
        Get agent class by name.

        Args:
            name: Name of the agent class

        Returns:
            The agent class

        Raises:
            ValueError: If agent not found
        """
        if name not in AGENT_REGISTRY:
            raise ValueError(f"Agent '{name}' not found in registry")
        return AGENT_REGISTRY[name]

    @classmethod
    def list_agents(cls) -> List[str]:
        """
        List all registered agents.

        Returns:
            List of agent names, in registration order
        """
        return list(AGENT_REGISTRY.keys())
|
383
|
+
|
384
|
+
|
385
|
+
class AgentFactory:
    """Factory for creating agent instances."""

    @staticmethod
    def _resolve_tool(tool: Union[str, BaseTool]) -> BaseTool:
        """Turn a tool name into a tool instance; non-strings pass through unchanged."""
        if not isinstance(tool, str):
            # Assume it's a tool instance
            return tool

        # Assume it's a tool name: try stateless first, then stateful.
        try:
            return ToolFactory.get_tool(tool)
        except ValueError:
            pass
        try:
            return ToolFactory.get_tool(tool, "stateful")
        except ValueError as err:
            raise ValueError(f"Tool '{tool}' not found in registry") from err

    @staticmethod
    def create_agent(
        agent_type: str,
        name: Optional[str] = None,
        models: Optional[List[str]] = None,
        tools: Optional[List[Union[str, BaseTool]]] = None,
        **kwargs,
    ) -> BaseAgent:
        """
        Create an agent instance.

        Args:
            agent_type: Type of agent to create
            name: Name for the agent instance; when omitted a name of the form
                ``"<agent_type>_<random hex>"`` is generated
            models: List of model identifiers
            tools: List of tools or tool names
            **kwargs: Additional arguments for the agent constructor

        Returns:
            Agent instance

        Raises:
            ValueError: If the agent type or a named tool is not registered
        """
        # Get agent class
        agent_class = AgentRegistry.get_agent_class(agent_type)

        # Prepare name. A random suffix keeps default names distinct per
        # instance; id(agent_class) gave every instance of a type the same name.
        instance_name = name or f"{agent_type}_{uuid.uuid4().hex[:8]}"

        # Prepare tools
        tool_instances = [AgentFactory._resolve_tool(tool) for tool in tools or []]

        # Create agent instance
        return agent_class(
            name=instance_name,
            models=models,
            tools=tool_instances,
            **kwargs,
        )

    @staticmethod
    def list_available_agents() -> List[str]:
        """
        List all available agent types.

        Returns:
            List of agent type names
        """
        return AgentRegistry.list_agents()
|
452
|
+
|
453
|
+
|
454
|
+
def register_agent(name: Optional[str] = None):
    """Class decorator that adds an agent class to the global agent registry.

    Convenience alias for ``AgentRegistry.register``.

    Args:
        name: Registry key for the class; when omitted the class's
            ``__name__`` is used.

    Returns:
        A decorator that registers the decorated class and returns it.
    """
    decorator = AgentRegistry.register(name)
    return decorator
|
@@ -0,0 +1,122 @@
|
|
1
|
+
from typing import Dict, List, Optional
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
from airtrain.integrations.fireworks.skills import FireworksChatSkill, FireworksInput
|
4
|
+
from airtrain.core.skills import ProcessingError
|
5
|
+
import json
|
6
|
+
|
7
|
+
|
8
|
+
class AgentSpecification(BaseModel):
    """Structured description of an agent, as produced by the agent builder."""

    # Required descriptive fields (no default => must be supplied).
    name: str = Field(description="Name of the agent")
    purpose: str = Field(description="Primary purpose of the agent")
    input_type: str = Field(description="Type of input the agent accepts")
    output_type: str = Field(description="Type of output the agent produces")

    # Optional list; each instance gets its own fresh empty list.
    required_skills: List[str] = Field(
        default_factory=list,
        description="Skills required by the agent",
    )

    conversation_style: str = Field(
        description="Style of conversation (formal, casual, technical, etc.)",
    )

    safety_constraints: List[str] = Field(
        default_factory=list,
        description="Safety constraints for the agent",
    )

    reasoning: Optional[str] = Field(
        default=None,
        description="Reasoning behind agent design decisions",
    )
|
27
|
+
|
28
|
+
|
29
|
+
class AgentBuilder:
    """AI-powered agent builder.

    Drives an interactive loop: the model proposes the next question, the
    user answers on stdin, and once the user types ``done`` (with enough
    answers collected) the model is asked to emit the final
    ``AgentSpecification`` as JSON, which is then validated.
    """

    # Minimum answered questions before "done" is accepted. Each answer adds
    # two history entries (assistant question + user reply).
    MIN_ANSWERS = 3

    DEFAULT_MODEL = "accounts/fireworks/models/deepseek-r1"

    def __init__(self, model: str = DEFAULT_MODEL, temperature: float = 0.7):
        """
        Args:
            model: Fireworks model identifier used for every request.
            temperature: Sampling temperature used for every request.
        """
        self.skill = FireworksChatSkill()
        self.model = model
        self.temperature = temperature
        self.system_prompt = """You are an expert AI Agent architect. Your role is to help users build AI agents by:
1. Understanding their requirements through targeted questions
2. Designing appropriate agent architectures
3. Selecting optimal skills and models
4. Ensuring safety and ethical constraints
5. Providing clear reasoning for all decisions

Ask one question at a time. Wait for user response before proceeding.
Start by asking about the primary purpose of the agent they want to build.

Your responses must be in this format:
QUESTION: [Your question here]
CONTEXT: [Brief context about why this question is important]

When creating the final specification, output valid JSON matching the AgentSpecification schema."""

    def _ask(self, user_input: str, conversation_history: List[Dict[str, str]]) -> str:
        """Send one prompt (with system prompt and history) and return the raw reply text."""
        input_data = FireworksInput(
            user_input=user_input,
            system_prompt=self.system_prompt,
            model=self.model,
            temperature=self.temperature,
            conversation_history=conversation_history,
        )
        return self.skill.process(input_data).response

    @staticmethod
    def _extract_json(text: str) -> str:
        """Pull the JSON payload out of a model reply.

        Handles, in order: a leading ``<think>...</think>`` section (reasoning
        models may emit one), a fenced ```json block, any fenced block, and
        finally the outermost ``{...}`` span inside surrounding prose. Falls
        back to the stripped text when none of these apply.
        """
        _, sep, rest = text.partition("</think>")
        if sep:
            text = rest
        if "```json" in text:
            return text.split("```json", 1)[1].split("```", 1)[0].strip()
        if "```" in text:
            return text.split("```", 2)[1].strip()
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            return text[start : end + 1]
        return text.strip()

    def _get_next_question(self, conversation_history: List[Dict[str, str]]) -> str:
        """Ask the model for the next question; raises ProcessingError on failure."""
        try:
            return self._ask(
                "Based on the conversation so far, what's the next question to ask?",
                conversation_history,
            )
        except Exception as e:
            raise ProcessingError(f"Failed to generate next question: {str(e)}") from e

    def _create_specification(
        self, conversation_history: List[Dict[str, str]]
    ) -> "AgentSpecification":
        """Ask the model for the final spec and validate it.

        Raises:
            ProcessingError: if the request fails or the reply is not a valid
                AgentSpecification (both cases are recoverable in build_agent).
        """
        # The request itself is wrapped too, so API failures reach
        # build_agent's ProcessingError handler instead of escaping it.
        try:
            response = self._ask(
                "Based on our conversation, create a complete agent specification in valid JSON format.",
                conversation_history,
            )
        except Exception as e:
            raise ProcessingError(
                f"Failed to generate agent specification: {str(e)}"
            ) from e

        try:
            return AgentSpecification.model_validate_json(self._extract_json(response))
        except Exception as e:
            raise ProcessingError(f"Failed to parse agent specification: {str(e)}") from e

    def build_agent(self) -> "AgentSpecification":
        """Run the interactive question loop and return the validated specification."""
        conversation_history: List[Dict[str, str]] = []

        print("\nWelcome to the AI Agent Builder!")
        print("I'll help you create a custom AI agent through a series of questions.\n")

        while True:
            next_question = self._get_next_question(conversation_history)
            print(f"\n{next_question}")

            user_input = input("\nYour response (type 'done' when finished): ").strip()

            if user_input.lower() == "done":
                if len(conversation_history) < 2 * self.MIN_ANSWERS:
                    print(
                        "\nPlease answer a few more questions to create a complete specification."
                    )
                    continue
                try:
                    return self._create_specification(conversation_history)
                except ProcessingError as e:
                    print(f"\nError creating specification: {str(e)}")
                    print(
                        "Let's continue with a few more questions to gather complete information."
                    )
                    continue

            conversation_history.extend(
                [
                    {"role": "assistant", "content": next_question},
                    {"role": "user", "content": user_input},
                ]
            )
|
airtrain/cli/__init__.py
ADDED
File without changes
|
airtrain/cli/builder.py
ADDED
@@ -0,0 +1,23 @@
|
|
1
|
+
import click
|
2
|
+
from airtrain.builder.agent_builder import AgentBuilder
|
3
|
+
import json
|
4
|
+
|
5
|
+
|
6
|
+
@click.command()
def build():
    """Build a custom AI agent through an interactive process"""
    # NOTE: click uses the docstring above as this command's --help text.
    try:
        builder = AgentBuilder()
        specification = builder.build_agent()
    except Exception as e:
        click.echo(f"\nError building agent: {str(e)}", err=True)
        # click ignores a command's return value in standalone mode, so a
        # plain `return 1` would exit with status 0; exit non-zero explicitly.
        raise SystemExit(1) from e

    # Display the final specification
    click.echo("\n=== Agent Specification ===")
    click.echo(json.dumps(specification.model_dump(), indent=2))

    click.echo(
        "\nAgent specification complete! You can now use this specification to initialize your agent."
    )
|