airtrain 0.1.37__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/main.py +9 -0
- {airtrain-0.1.37.dist-info → airtrain-0.1.39.dist-info}/METADATA +1 -1
- {airtrain-0.1.37.dist-info → airtrain-0.1.39.dist-info}/RECORD +9 -7
- {airtrain-0.1.37.dist-info → airtrain-0.1.39.dist-info}/WHEEL +0 -0
- {airtrain-0.1.37.dist-info → airtrain-0.1.39.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.37.dist-info → airtrain-0.1.39.dist-info}/top_level.txt +0 -0
airtrain/builder/agent_builder.py
CHANGED
@@ -0,0 +1,122 @@
|
|
1
|
+
from typing import Dict, List, Optional
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
from airtrain.integrations.fireworks.skills import FireworksChatSkill, FireworksInput
|
4
|
+
from airtrain.core.skills import ProcessingError
|
5
|
+
import json
|
6
|
+
|
7
|
+
|
8
|
+
class AgentSpecification(BaseModel):
    """Structured description of an agent, as produced by an AgentBuilder session.

    Instances are parsed from the JSON the model emits at the end of the
    interview. Fields declared with ``default=...`` are mandatory; list fields
    default to an empty list and ``reasoning`` defaults to ``None``.
    """

    # --- mandatory descriptive fields ---
    name: str = Field(
        default=...,
        description="Name of the agent",
    )
    purpose: str = Field(
        default=...,
        description="Primary purpose of the agent",
    )
    input_type: str = Field(
        default=...,
        description="Type of input the agent accepts",
    )
    output_type: str = Field(
        default=...,
        description="Type of output the agent produces",
    )

    # --- optional list; a fresh empty list per instance ---
    required_skills: List[str] = Field(
        default_factory=list,
        description="Skills required by the agent",
    )

    conversation_style: str = Field(
        default=...,
        description="Style of conversation (formal, casual, technical, etc.)",
    )

    safety_constraints: List[str] = Field(
        default_factory=list,
        description="Safety constraints for the agent",
    )

    # Free-text justification; may be omitted by the model.
    reasoning: Optional[str] = Field(
        default=None,
        description="Reasoning behind agent design decisions",
    )
|
27
|
+
|
28
|
+
|
29
|
+
class AgentBuilder:
    """AI-powered, interactive agent builder.

    Drives a question/answer dialogue on the terminal using a
    ``FireworksChatSkill`` and, once the user types ``done``, asks the model
    to turn the collected answers into an ``AgentSpecification``.
    """

    # Minimum number of answered questions before "done" is accepted. Each
    # answered question adds two history entries (assistant question + user
    # answer), so the history must hold at least ``2 * min_exchanges`` items.
    min_exchanges: int = 3

    def __init__(
        self,
        model: str = "accounts/fireworks/models/deepseek-r1",
        temperature: float = 0.7,
    ):
        """Create a builder.

        Args:
            model: Fireworks model identifier used both for generating
                questions and for synthesising the final specification.
            temperature: Sampling temperature passed with every request.
        """
        self.skill = FireworksChatSkill()
        self.model = model
        self.temperature = temperature
        self.system_prompt = """You are an expert AI Agent architect. Your role is to help users build AI agents by:
1. Understanding their requirements through targeted questions
2. Designing appropriate agent architectures
3. Selecting optimal skills and models
4. Ensuring safety and ethical constraints
5. Providing clear reasoning for all decisions

Ask one question at a time. Wait for user response before proceeding.
Start by asking about the primary purpose of the agent they want to build.

Your responses must be in this format:
QUESTION: [Your question here]
CONTEXT: [Brief context about why this question is important]

When creating the final specification, output valid JSON matching the AgentSpecification schema."""

    @staticmethod
    def _strip_reasoning(text: str) -> str:
        """Drop a ``<think>...</think>`` reasoning trace that precedes the answer.

        Reasoning models such as DeepSeek-R1 may emit such a trace before
        their actual reply. When no closing tag is present, ``rpartition``
        returns the whole text, so ordinary replies pass through (stripped).
        """
        return text.rpartition("</think>")[2].strip()

    @classmethod
    def _extract_json(cls, text: str) -> str:
        """Pull the JSON object out of a free-form model reply.

        Order of preference: a ```` ```json ```` fenced block, then any fenced
        block, then the raw text. The candidate is finally trimmed to the span
        between its first ``{`` and last ``}``, which removes fence language
        tags (e.g. ``JSON``) and surrounding prose.
        """
        text = cls._strip_reasoning(text)
        if "```json" in text:
            text = text.split("```json", 1)[1].split("```", 1)[0]
        elif "```" in text:
            text = text.split("```")[1]
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start : end + 1]
        return text.strip()

    def _make_input(
        self, user_input: str, conversation_history: List[Dict[str, str]]
    ) -> "FireworksInput":
        """Build a request carrying this builder's prompt, model and temperature."""
        return FireworksInput(
            user_input=user_input,
            system_prompt=self.system_prompt,
            model=self.model,
            temperature=self.temperature,
            conversation_history=conversation_history,
        )

    def _get_next_question(self, conversation_history: List[Dict[str, str]]) -> str:
        """Ask the model for the next interview question.

        Raises:
            ProcessingError: if the model request fails.
        """
        input_data = self._make_input(
            "Based on the conversation so far, what's the next question to ask?",
            conversation_history,
        )
        try:
            result = self.skill.process(input_data)
        except Exception as e:
            raise ProcessingError(f"Failed to generate next question: {e}") from e
        return self._strip_reasoning(result.response or "")

    def _create_specification(
        self, conversation_history: List[Dict[str, str]]
    ) -> "AgentSpecification":
        """Ask the model for the final specification and validate it.

        Raises:
            ProcessingError: if the model request fails or its reply cannot be
                parsed into an ``AgentSpecification``.
        """
        input_data = self._make_input(
            "Based on our conversation, create a complete agent specification in valid JSON format.",
            conversation_history,
        )

        # The request itself is guarded too, so API failures surface as
        # ProcessingError and build_agent can recover instead of crashing.
        try:
            result = self.skill.process(input_data)
        except Exception as e:
            raise ProcessingError(
                f"Failed to generate agent specification: {e}"
            ) from e

        json_str = self._extract_json(result.response or "")
        try:
            return AgentSpecification.model_validate_json(json_str)
        except ValueError as e:  # pydantic's ValidationError subclasses ValueError
            raise ProcessingError(f"Failed to parse agent specification: {e}") from e

    def build_agent(self) -> "AgentSpecification":
        """Run the interactive interview on stdin/stdout and return the spec.

        The user answers model-generated questions until typing ``done``.
        Empty answers, and a premature ``done``, re-show the pending question
        without another model call. If specification synthesis fails, the
        interview continues with a new question.

        Raises:
            ProcessingError: if generating a question fails.
        """
        conversation_history: List[Dict[str, str]] = []

        print("\nWelcome to the AI Agent Builder!")
        print("I'll help you create a custom AI agent through a series of questions.\n")

        next_question: Optional[str] = None
        while True:
            if next_question is None:
                next_question = self._get_next_question(conversation_history)
            print(f"\n{next_question}")

            user_input = input("\nYour response (type 'done' when finished): ").strip()

            if not user_input:
                # Don't record empty answers; ask the same question again.
                continue

            if user_input.lower() == "done":
                if len(conversation_history) < 2 * self.min_exchanges:
                    print(
                        "\nPlease answer a few more questions to create a complete specification."
                    )
                    continue
                try:
                    return self._create_specification(conversation_history)
                except ProcessingError as e:
                    print(f"\nError creating specification: {e}")
                    print(
                        "Let's continue with a few more questions to gather complete information."
                    )
                    next_question = None
                    continue

            conversation_history.extend(
                [
                    {"role": "assistant", "content": next_question},
                    {"role": "user", "content": user_input},
                ]
            )
            next_question = None
|
airtrain/cli/main.py
CHANGED
@@ -109,3 +109,12 @@ def chat(provider: str, temperature: float, system_prompt: str):
|
|
109
109
|
|
110
110
|
# Register the `build` command on the already-defined `cli` group.
cli.add_command(build)


def main() -> None:
    """CLI entry point: invoke the ``cli`` command group."""
    cli()


if __name__ == "__main__":
    main()
|
airtrain-0.1.39.dist-info/RECORD
CHANGED
@@ -1,8 +1,10 @@
|
|
1
|
-
airtrain/__init__.py,sha256=
|
1
|
+
airtrain/__init__.py,sha256=n0L7jCq_aew4QXp1TUM8TzyXHSQwd60t1WdnXkbTgrk,2099
|
2
2
|
airtrain/__main__.py,sha256=EU8ffFmCdC1G-UcHHt0Oo3lB1PGqfC6kwzH39CnYSwU,72
|
3
|
+
airtrain/builder/__init__.py,sha256=D33sr0k_WAe6FAJkk8rUaivEzFaeVqLXkQgyFWEhfPU,110
|
4
|
+
airtrain/builder/agent_builder.py,sha256=3XnGUAcK_6lWoUDtL0TanliQZuh7u0unhNbnrz1z2-I,5018
|
3
5
|
airtrain/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
6
|
airtrain/cli/builder.py,sha256=cI0FZCfRrgXPmjt8lOnHZwCWKrOB2doaOn49kmxVxHs,669
|
5
|
-
airtrain/cli/main.py,sha256=
|
7
|
+
airtrain/cli/main.py,sha256=WGt0WXhfRl7D_UGNtCMRDWiBTBwbXcRbkEZOh9StXOo,3559
|
6
8
|
airtrain/contrib/__init__.py,sha256=pG-7mJ0pBMqp3Q86mIF9bo1PqoBOVSGlnEK1yY1U1ok,641
|
7
9
|
airtrain/contrib/travel/__init__.py,sha256=clmBodw4nkTA-DsgjVGcXfJGPaWxIpCZDtdO-8RzL0M,811
|
8
10
|
airtrain/contrib/travel/agents.py,sha256=tpQtZ0WUiXBuhvZtc2JlEam5TuR5l-Tndi14YyImDBM,8975
|
@@ -61,8 +63,8 @@ airtrain/integrations/together/rerank_skill.py,sha256=gjH24hLWCweWKPyyfKZMG3K_g9
|
|
61
63
|
airtrain/integrations/together/schemas.py,sha256=pBMrbX67oxPCr-sg4K8_Xqu1DWbaC4uLCloVSascROg,1210
|
62
64
|
airtrain/integrations/together/skills.py,sha256=8DwkexMJu1Gm6QmNDfNasYStQ31QsXBbFP99zR-YCf0,7598
|
63
65
|
airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1FtynIa2ZVcbl37kltfoHnK8zxs,1544
|
64
|
-
airtrain-0.1.
|
65
|
-
airtrain-0.1.
|
66
|
-
airtrain-0.1.
|
67
|
-
airtrain-0.1.
|
68
|
-
airtrain-0.1.
|
66
|
+
airtrain-0.1.39.dist-info/METADATA,sha256=b0g4vq6XCWe55sSnRlt6bcsK_ynvyj7SkDYTxEvrQU8,5375
|
67
|
+
airtrain-0.1.39.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
68
|
+
airtrain-0.1.39.dist-info/entry_points.txt,sha256=rrJ36IUsyq6n1dSfTWXqVAgpQLPRWDfCqwd6_3B-G0U,52
|
69
|
+
airtrain-0.1.39.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
70
|
+
airtrain-0.1.39.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|