airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,465 @@
1
+ """
2
+ Agent Registry System for AirTrain
3
+
4
+ This module provides a registry system for agents that can be used to build AI systems.
5
+ It includes:
6
+ - Base agent class
7
+ - Registration decorators
8
+ - Factory methods for agent creation
9
+ - Discovery utilities for finding available agents
10
+ """
11
+
12
+ from abc import ABC, abstractmethod
13
+ import time
14
+ import uuid
15
+ from typing import List, Optional, Type, TypeVar, Union
16
+ import inspect
17
+
18
+ # Import tool registry components
19
+ from airtrain.tools import ToolFactory, BaseTool
20
+ from airtrain.telemetry import (
21
+ telemetry,
22
+ AgentRunTelemetryEvent,
23
+ AgentStepTelemetryEvent,
24
+ AgentEndTelemetryEvent,
25
+ ModelInvocationTelemetryEvent,
26
+ ErrorTelemetryEvent
27
+ )
28
+ from .memory import AgentMemoryManager
29
+
30
+
31
# TypeVar bound to BaseAgent. It is not referenced elsewhere in this module;
# presumably kept for typing helpers/callers.
A = TypeVar("A", bound="BaseAgent")

# Module-wide mapping of registered agent name -> agent class.
# Written by AgentRegistry.register and read by AgentRegistry lookups.
AGENT_REGISTRY: "dict[str, Type[BaseAgent]]" = dict()
36
+
37
+
38
class BaseAgent(ABC):
    """Abstract base class for all agents.

    An agent owns:
    - a list of model identifiers;
    - a list of tools;
    - an ``AgentMemoryManager`` that starts with a ``"default"`` short-term
      memory;
    - per-run counters, which are reported through telemetry events.

    Subclasses must implement :meth:`process`.
    """

    # Assigned on the class by AgentRegistry.register; None until registered.
    agent_name: Optional[str] = None

    def __init__(
        self,
        name: str,
        models: Optional[List[str]] = None,
        tools: Optional[List[BaseTool]] = None,
    ):
        """
        Initialize an agent.

        Args:
            name: Name of the agent instance; also used as the prefix of ``agent_id``
            models: Model identifiers to use (copied; defaults to an empty list)
            tools: Tools the agent can use (copied; defaults to an empty list)
        """
        self.name = name
        # Copy the caller's sequences so add_tool()/register_tools() never
        # mutate a list the caller still owns.
        self.models = list(models) if models else []
        self.tools = list(tools) if tools else []
        self.memory = AgentMemoryManager()

        # Unique per-instance ID attached to every telemetry event.
        self.agent_id = f"{name}-{uuid.uuid4()}"
        # Per-run state; re-initialised by start_run().
        self.start_time: Optional[float] = None
        self.step_count = 0
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.errors: List[str] = []

        # Every agent starts with a short-term memory named "default".
        self.memory.create_short_term_memory("default")

    def add_tool(self, tool: BaseTool):
        """
        Add a tool to the agent.

        Args:
            tool: Tool instance to add

        Returns:
            Self for method chaining
        """
        self.tools.append(tool)
        return self

    def register_tools(self, tools: List[BaseTool]):
        """
        Register multiple tools with the agent.

        Args:
            tools: List of tool instances to add

        Returns:
            Self for method chaining
        """
        self.tools.extend(tools)
        return self

    def create_memory(self, name: str, max_messages: int = 10):
        """
        Create a new short-term memory.

        Args:
            name: Name for the memory
            max_messages: Maximum messages before summarization

        Returns:
            The memory object returned by the memory manager
        """
        return self.memory.create_short_term_memory(name, max_messages)

    def reset_memory(self, name: str = "default"):
        """
        Reset a specific short-term memory.

        Args:
            name: Name of the memory to reset

        Returns:
            Self for method chaining
        """
        self.memory.reset_short_term_memory(name)
        return self

    def start_run(self, task: str, model_name: str, model_provider: str):
        """
        Start an agent run: reset per-run counters and emit a run event.

        Args:
            task: Description of the task
            model_name: Name of the model being used
            model_provider: Provider of the model

        Returns:
            Self for method chaining
        """
        self.start_time = time.time()
        self.step_count = 0
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.errors = []

        event = AgentRunTelemetryEvent(
            agent_id=self.agent_id,
            task=task,
            model_name=model_name,
            model_provider=model_provider,
            version=self._get_package_version(),
            source=self.__class__.__name__,
        )
        telemetry.capture(event)
        return self

    def record_step(self, actions: List[dict], step_error: Optional[List[str]] = None):
        """
        Record an agent step and emit a step event.

        Args:
            actions: List of actions taken in this step
            step_error: Optional list of errors encountered in this step.
                These are only reported; they are not added to ``self.errors``.

        Returns:
            Self for method chaining
        """
        self.step_count += 1

        event = AgentStepTelemetryEvent(
            agent_id=self.agent_id,
            step=self.step_count,
            step_error=step_error or [],
            # NOTE(review): this is the total number of errors recorded via
            # record_error() in the current run, not a strict count of
            # *consecutive* failures.
            consecutive_failures=len(self.errors),
            actions=actions,
        )
        telemetry.capture(event)
        return self

    def record_model_usage(
        self,
        model_name: str,
        model_provider: str,
        tokens: int,
        prompt_tokens: int,
        completion_tokens: int,
        duration_seconds: float,
        error: Optional[str] = None,
    ):
        """
        Accumulate token counters for the run and emit a model-invocation event.

        Args:
            model_name: Name of the model used
            model_provider: Provider of the model
            tokens: Total tokens used
            prompt_tokens: Tokens used in the prompt
            completion_tokens: Tokens used in the completion
            duration_seconds: Duration of the model call
            error: Optional error message

        Returns:
            Self for method chaining
        """
        self.total_tokens += tokens
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        event = ModelInvocationTelemetryEvent(
            agent_id=self.agent_id,
            model_name=model_name,
            model_provider=model_provider,
            tokens=tokens,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            duration_seconds=duration_seconds,
            error=error,
        )
        telemetry.capture(event)
        return self

    def end_run(self, is_done: bool, success: Optional[bool] = None):
        """
        End an agent run and emit an end event with the run's totals.

        Args:
            is_done: Whether the agent completed its task
            success: Whether the agent was successful

        Returns:
            Self for method chaining
        """
        # 0 when start_run() was never called.
        duration = time.time() - self.start_time if self.start_time else 0

        event = AgentEndTelemetryEvent(
            agent_id=self.agent_id,
            steps=self.step_count,
            is_done=is_done,
            success=success,
            total_tokens=self.total_tokens,
            prompt_tokens=self.prompt_tokens,
            completion_tokens=self.completion_tokens,
            total_duration_seconds=duration,
            # Snapshot, so later record_error() calls cannot alter an event
            # that has already been handed to telemetry.
            errors=list(self.errors),
        )
        telemetry.capture(event)
        return self

    def record_error(self, error_type: str, error_message: str, component: str):
        """
        Record an error for the current run and emit an error event.

        Args:
            error_type: Type of the error
            error_message: Error message (appended to ``self.errors``)
            component: Component where the error occurred

        Returns:
            Self for method chaining
        """
        self.errors.append(error_message)

        event = ErrorTelemetryEvent(
            error_type=error_type,
            error_message=error_message,
            component=component,
            agent_id=self.agent_id,
        )
        telemetry.capture(event)
        return self

    def _get_package_version(self) -> str:
        """Return ``airtrain.__version__``, or "unknown" if it cannot be imported."""
        try:
            from airtrain import __version__
            return __version__
        except ImportError:
            return "unknown"

    @abstractmethod
    def process(self, user_input: str, memory_name: str = "default") -> str:
        """
        Process user input using a specific memory context.

        Args:
            user_input: User input to process
            memory_name: Name of the memory to use for context

        Returns:
            Agent's response
        """
        pass
299
+
300
+
301
class AgentRegistry:
    """Class-level API over the module-wide ``AGENT_REGISTRY`` mapping."""

    @classmethod
    def register(cls, name: Optional[str] = None):
        """
        Decorator to register an agent class.

        Args:
            name: Optional registry key; defaults to the class's ``__name__``

        Returns:
            A decorator. It validates and registers the class, sets the class's
            ``agent_name`` attribute, and returns the class.

        The returned decorator raises:
            TypeError: if the object is not a BaseAgent subclass, or still has
                the abstract ``process`` method unimplemented
            ValueError: if an agent is already registered under that name
        """

        def decorator(agent_class: Type[BaseAgent]) -> Type[BaseAgent]:
            """Validate *agent_class*, add it to AGENT_REGISTRY and return it."""
            class_name = getattr(agent_class, "__name__", repr(agent_class))

            # isclass guard: issubclass() raises a less helpful TypeError when
            # handed a non-class.
            if not inspect.isclass(agent_class) or not issubclass(agent_class, BaseAgent):
                raise TypeError(
                    f"Agent class {class_name} must inherit from BaseAgent"
                )

            # Every BaseAgent subclass *has* a ``process`` attribute, because it
            # is inherited from the abstract base. Presence alone proves
            # nothing, so also reject the abstract placeholder itself.
            process = getattr(agent_class, "process", None)
            if not callable(process) or getattr(process, "__isabstractmethod__", False):
                raise TypeError(
                    f"Agent class {class_name} must implement process method"
                )

            agent_name = name or agent_class.__name__
            if agent_name in AGENT_REGISTRY:
                raise ValueError(f"Agent '{agent_name}' already registered")

            AGENT_REGISTRY[agent_name] = agent_class
            # Expose the registry key on the class itself.
            agent_class.agent_name = agent_name
            return agent_class

        return decorator

    @classmethod
    def get_agent_class(cls, name: str) -> Type[BaseAgent]:
        """
        Get agent class by name.

        Args:
            name: Name of the agent class

        Returns:
            The agent class

        Raises:
            ValueError: If agent not found
        """
        try:
            return AGENT_REGISTRY[name]
        except KeyError:
            raise ValueError(f"Agent '{name}' not found in registry") from None

    @classmethod
    def list_agents(cls) -> List[str]:
        """
        List all registered agents.

        Returns:
            Registered agent names, in registration order
        """
        return list(AGENT_REGISTRY)
383
+
384
+
385
class AgentFactory:
    """Factory for creating agent instances from registered agent types."""

    @staticmethod
    def create_agent(
        agent_type: str,
        name: Optional[str] = None,
        models: Optional[List[str]] = None,
        tools: Optional[List[Union[str, BaseTool]]] = None,
        **kwargs,
    ) -> BaseAgent:
        """
        Create an agent instance.

        Args:
            agent_type: Registered name of the agent type to create
            name: Name for the agent instance. If omitted, a name of the form
                ``<agent_type>_<8 hex chars>`` is generated.
            models: List of model identifiers
            tools: Tool instances and/or registered tool names
            **kwargs: Additional arguments for the agent constructor

        Returns:
            Agent instance

        Raises:
            ValueError: If the agent type or a named tool is not registered
        """
        agent_class = AgentRegistry.get_agent_class(agent_type)

        # id(agent_class) is the same for every instance of a given type, so
        # it cannot tell instances apart; use a short random suffix instead.
        instance_name = name or f"{agent_type}_{uuid.uuid4().hex[:8]}"

        tool_instances = [AgentFactory._resolve_tool(tool) for tool in tools or []]

        return agent_class(
            name=instance_name,
            models=models,
            tools=tool_instances,
            **kwargs,
        )

    @staticmethod
    def _resolve_tool(tool: Union[str, BaseTool]) -> BaseTool:
        """Return *tool* unchanged, or look up a tool name in ToolFactory.

        The name is looked up with ToolFactory's default kind first, then as a
        "stateful" tool.
        """
        if not isinstance(tool, str):
            return tool
        try:
            return ToolFactory.get_tool(tool)
        except ValueError:
            pass
        try:
            return ToolFactory.get_tool(tool, "stateful")
        except ValueError:
            raise ValueError(f"Tool '{tool}' not found in registry") from None

    @staticmethod
    def list_available_agents() -> List[str]:
        """
        List all available agent types.

        Returns:
            List of agent type names
        """
        return AgentRegistry.list_agents()
452
+
453
+
454
def register_agent(name: Optional[str] = None):
    """Convenience shorthand for ``AgentRegistry.register``.

    Usage::

        @register_agent("my_agent")
        class MyAgent(BaseAgent):
            ...

    Args:
        name: Registry key for the class; when omitted, the class's
            ``__name__`` is used.

    Returns:
        The decorator produced by ``AgentRegistry.register(name)``.
    """
    registrar = AgentRegistry.register
    return registrar(name)
@@ -0,0 +1,3 @@
1
+ from .agent_builder import AgentBuilder, AgentSpecification
2
+
3
# Public API of the builder package.
__all__ = [
    "AgentBuilder",
    "AgentSpecification",
]
@@ -0,0 +1,122 @@
1
+ from typing import Dict, List, Optional
2
+ from pydantic import BaseModel, Field
3
+ from airtrain.integrations.fireworks.skills import FireworksChatSkill, FireworksInput
4
+ from airtrain.core.skills import ProcessingError
5
+ import json
6
+
7
+
8
class AgentSpecification(BaseModel):
    """Structured agent description produced at the end of an AgentBuilder session."""

    # --- identity and I/O (all required) ---
    name: str = Field(..., description="Name of the agent")
    purpose: str = Field(..., description="Primary purpose of the agent")
    input_type: str = Field(..., description="Type of input the agent accepts")
    output_type: str = Field(..., description="Type of output the agent produces")

    # --- capabilities (defaults to an empty list) ---
    required_skills: List[str] = Field(
        description="Skills required by the agent", default_factory=list
    )

    # --- behaviour (required) ---
    conversation_style: str = Field(
        ..., description="Style of conversation (formal, casual, technical, etc.)"
    )

    # --- guard-rails (defaults to an empty list) ---
    safety_constraints: List[str] = Field(
        description="Safety constraints for the agent", default_factory=list
    )

    # --- optional free-text rationale ---
    reasoning: Optional[str] = Field(
        default=None, description="Reasoning behind agent design decisions"
    )
27
+
28
+
29
class AgentBuilder:
    """AI-powered, interactive agent builder.

    A Fireworks chat model generates one question per round. The user's
    answers are accumulated as conversation history. When the user types
    ``done``, the model is asked to emit an :class:`AgentSpecification` as JSON.
    """

    DEFAULT_MODEL = "accounts/fireworks/models/deepseek-r1"
    # Minimum number of history entries (two per question/answer round)
    # required before "done" is accepted.
    MIN_HISTORY_MESSAGES = 6

    def __init__(self, model: str = DEFAULT_MODEL, temperature: float = 0.7):
        """
        Args:
            model: Fireworks model identifier used for every request
            temperature: Sampling temperature used for every request
        """
        self.skill = FireworksChatSkill()
        self.model = model
        self.temperature = temperature
        self.system_prompt = """You are an expert AI Agent architect. Your role is to help users build AI agents by:
1. Understanding their requirements through targeted questions
2. Designing appropriate agent architectures
3. Selecting optimal skills and models
4. Ensuring safety and ethical constraints
5. Providing clear reasoning for all decisions

Ask one question at a time. Wait for user response before proceeding.
Start by asking about the primary purpose of the agent they want to build.

Your responses must be in this format:
QUESTION: [Your question here]
CONTEXT: [Brief context about why this question is important]

When creating the final specification, output valid JSON matching the AgentSpecification schema."""

    def _build_input(
        self, user_input: str, conversation_history: List[Dict[str, str]]
    ) -> FireworksInput:
        """Build a FireworksInput using this builder's prompt, model and temperature."""
        return FireworksInput(
            user_input=user_input,
            system_prompt=self.system_prompt,
            model=self.model,
            temperature=self.temperature,
            conversation_history=conversation_history,
        )

    @staticmethod
    def _strip_reasoning(text: str) -> str:
        """Remove ``<think>...</think>`` sections and surrounding whitespace.

        Reasoning models such as deepseek-r1 may prefix their answer with such
        a section. When no tags are present, this only strips whitespace.
        """
        import re

        return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

    @staticmethod
    def _extract_json(text: str) -> str:
        """Best-effort extraction of a JSON document from a model response.

        Tried in order:
        1. the contents of a ```json fence;
        2. the first generic ``` fence;
        3. the outermost ``{...}`` span;
        4. otherwise, the text unchanged.
        """
        if "```json" in text:
            return text.split("```json", 1)[1].split("```", 1)[0].strip()
        if "```" in text:
            return text.split("```", 2)[1].strip()
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            return text[start : end + 1]
        return text

    def _get_next_question(self, conversation_history: List[Dict[str, str]]) -> str:
        """Ask the model for the next question, given the history so far.

        Raises:
            ProcessingError: if the model call fails
        """
        input_data = self._build_input(
            "Based on the conversation so far, what's the next question to ask?",
            conversation_history,
        )
        try:
            response = self.skill.process(input_data).response
        except Exception as e:
            raise ProcessingError(f"Failed to generate next question: {str(e)}") from e
        return self._strip_reasoning(response)

    def _create_specification(
        self, conversation_history: List[Dict[str, str]]
    ) -> AgentSpecification:
        """Ask the model for the final specification and parse it.

        Raises:
            ProcessingError: if the model call fails or the response does not
                validate as an AgentSpecification. Both failures are
                ProcessingError, so build_agent can recover from either.
        """
        input_data = self._build_input(
            "Based on our conversation, create a complete agent specification in valid JSON format.",
            conversation_history,
        )
        try:
            response = self.skill.process(input_data).response
        except Exception as e:
            raise ProcessingError(
                f"Failed to generate agent specification: {str(e)}"
            ) from e

        try:
            json_str = self._extract_json(self._strip_reasoning(response))
            return AgentSpecification.model_validate_json(json_str)
        except Exception as e:
            raise ProcessingError(f"Failed to parse agent specification: {str(e)}") from e

    def build_agent(self) -> AgentSpecification:
        """Run the interactive question loop and return the final specification.

        Each round prints a generated question and reads one answer from stdin.
        Typing ``done`` (case-insensitive) requests the final specification.
        The request is refused until MIN_HISTORY_MESSAGES history entries
        exist. If the specification cannot be produced, the loop continues.

        Raises:
            ProcessingError: if generating the next question fails
        """
        conversation_history: List[Dict[str, str]] = []

        print("\nWelcome to the AI Agent Builder!")
        print("I'll help you create a custom AI agent through a series of questions.\n")

        while True:
            next_question = self._get_next_question(conversation_history)
            print(f"\n{next_question}")

            user_input = input("\nYour response (type 'done' when finished): ").strip()

            if user_input.lower() == "done":
                if len(conversation_history) < self.MIN_HISTORY_MESSAGES:
                    print(
                        "\nPlease answer a few more questions to create a complete specification."
                    )
                    continue
                try:
                    return self._create_specification(conversation_history)
                except ProcessingError as e:
                    print(f"\nError creating specification: {str(e)}")
                    print(
                        "Let's continue with a few more questions to gather complete information."
                    )
                    continue

            conversation_history.extend(
                [
                    {"role": "assistant", "content": next_question},
                    {"role": "user", "content": user_input},
                ]
            )
File without changes
@@ -0,0 +1,23 @@
1
+ import click
2
+ from airtrain.builder.agent_builder import AgentBuilder
3
+ import json
4
+
5
+
6
@click.command()
def build():
    """Build a custom AI agent through an interactive process"""
    # (The docstring above is the command's --help text.)
    try:
        builder = AgentBuilder()
        specification = builder.build_agent()

        # Display the final specification.
        click.echo("\n=== Agent Specification ===")
        click.echo(json.dumps(specification.model_dump(), indent=2))

        click.echo(
            "\nAgent specification complete! You can now use this specification to initialize your agent."
        )
    except Exception as e:
        # Top-level CLI boundary: report the failure on stderr and exit non-zero.
        # Click ignores a command's return value in standalone mode, so the
        # previous ``return 1`` made failures exit with status 0.
        click.echo(f"\nError building agent: {str(e)}", err=True)
        raise SystemExit(1) from e