airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/cli/main.py ADDED
@@ -0,0 +1,120 @@
1
+ import click
2
+ from typing import Optional
3
+ from airtrain.integrations.openai.skills import OpenAIChatSkill, OpenAIInput
4
+ from airtrain.integrations.anthropic.skills import AnthropicChatSkill, AnthropicInput
5
+ import os
6
+ from dotenv import load_dotenv
7
+ import sys
8
+ from .builder import build
9
+
# Load environment variables from a .env file (located by python-dotenv) at
# import time, so values such as provider API keys (presumably read by the chat
# skills — confirm) are in os.environ before initialize_chat() builds a skill.
load_dotenv()
12
+
13
+
def initialize_chat(provider: str = "openai"):
    """Construct the chat skill that backs the given provider.

    Args:
        provider: Provider name; only "openai" and "anthropic" are recognised.

    Returns:
        A new ``OpenAIChatSkill`` for "openai", otherwise a new
        ``AnthropicChatSkill``.

    Raises:
        ValueError: If *provider* is not a recognised name.
    """
    # Reject unknown providers up front so the happy path stays flat.
    if provider not in ("openai", "anthropic"):
        raise ValueError(f"Unsupported provider: {provider}")
    return OpenAIChatSkill() if provider == "openai" else AnthropicChatSkill()
22
+
23
+
# Root command group. Subcommands attach to it via @cli.command() (see `chat`)
# and cli.add_command(...) (see `build`). The docstring below is click's help
# text for the group, so it is kept verbatim.
@click.group()
def cli():
    """Airtrain CLI - Your AI Agent Building Assistant"""
28
+
29
+
# `chat` subcommand: an interactive REPL against the selected provider.
@cli.command()
@click.option(
    "--provider",
    type=click.Choice(["openai", "anthropic"]),
    default="openai",
    help="The AI provider to use",
)
@click.option(
    "--temperature",
    type=float,
    default=0.7,
    help="Temperature for response generation (0.0-1.0)",
)
@click.option(
    "--system-prompt",
    type=str,
    default="You are a helpful AI assistant that helps users build their own AI agents. Be helpful and provide clear explanations.",
    help="System prompt to guide the model",
)
@click.option(
    "--model",
    type=str,
    default=None,
    help="Model name to use (default: gpt-4o for openai, claude-3-opus-20240229 for anthropic)",
)
def chat(
    provider: str,
    temperature: float,
    system_prompt: str,
    model: Optional[str] = None,
):
    """Start an interactive chat session with Airtrain"""
    # The docstring above is click's help text, so implementation notes live in
    # comments.
    #
    # Only skill construction is fatal (exit code 1). A failed turn is reported
    # and the loop continues. Ctrl-C / Ctrl-D end the session like "exit".
    try:
        skill = initialize_chat(provider)
    except Exception as e:
        click.echo(f"Failed to initialize chat: {e}")
        sys.exit(1)

    # Fall back to the per-provider default model used before --model existed.
    if model is None:
        model = "gpt-4o" if provider == "openai" else "claude-3-opus-20240229"
    input_cls = OpenAIInput if provider == "openai" else AnthropicInput

    click.echo(f"\nWelcome to Airtrain! Using {provider.upper()} as the provider.")
    click.echo("Type 'exit' to end the conversation.")
    click.echo("Type 'clear' to clear the conversation history.\n")

    conversation_history = []

    while True:
        try:
            user_input = click.prompt("You", type=str)
        except click.Abort:
            # click.prompt raises Abort on Ctrl-C / EOF. Previously this reached
            # the outer handler and was reported as an initialization failure.
            click.echo("\nGoodbye! Have a great day!")
            break

        # Tolerate surrounding whitespace in the control commands.
        command = user_input.strip().lower()

        if command == "exit":
            click.echo("\nGoodbye! Have a great day!")
            break

        if command == "clear":
            conversation_history = []
            click.echo("\nConversation history cleared!")
            continue

        try:
            input_data = input_cls(
                user_input=user_input,
                system_prompt=system_prompt,
                conversation_history=conversation_history,
                model=model,
                temperature=temperature,
            )

            result = skill.process(input_data)

            # Record the turn only after a successful response, so a failed
            # request leaves the history unchanged.
            conversation_history.extend(
                [
                    {"role": "user", "content": user_input},
                    {"role": "assistant", "content": result.response},
                ]
            )

            click.echo(f"\nAirtrain: {result.response}\n")

        except Exception as e:
            click.echo(f"\nError: {e}\n")
108
+
109
+
# Register the `build` command (imported from .builder at the top of this
# module) on the same click group that hosts `chat`.
cli.add_command(build)
112
+
113
+
def main():
    """Run the Airtrain command-line interface.

    Invokes the ``cli`` click group, which parses the process arguments and
    dispatches to a subcommand (``chat`` or ``build``). NOTE(review): this is
    presumably the target of the console script in entry_points.txt — confirm.
    """
    cli()
117
+
118
+
# Allow running the module directly (e.g. `python -m airtrain.cli.main`).
if __name__ == "__main__":
    main()
airtrain/contrib/__init__.py ADDED
@@ -0,0 +1,29 @@
"""Airtrain contrib package for community contributions"""

# travel.agents defines the agents as InternetAgent, FoodAgent and
# PersonalizedAgent. This package exposes them under longer public names, so
# import them with aliases. Importing the long names directly raised
# ImportError because .travel.agents never defines them.
from .travel.agents import (
    TravelAgentBase,
    ClothingAgent,
    HikingAgent,
    InternetAgent as InternetConnectivityAgent,
    FoodAgent as FoodRecommendationAgent,
    PersonalizedAgent as PersonalizedRecommendationAgent,
)
from .travel.models import (
    ClothingRecommendation,
    HikingOption,
    InternetAvailability,
    FoodOption,
)

__all__ = [
    "TravelAgentBase",
    "ClothingAgent",
    "HikingAgent",
    "InternetConnectivityAgent",
    "FoodRecommendationAgent",
    "PersonalizedRecommendationAgent",
    "ClothingRecommendation",
    "HikingOption",
    "InternetAvailability",
    "FoodOption",
]
airtrain/contrib/travel/__init__.py ADDED
@@ -0,0 +1,35 @@
"""Travel related agents and models"""

from .agents import (
    TravelAgentBase,
    ClothingAgent,
    HikingAgent,
    InternetAgent,
    FoodAgent,
    PersonalizedAgent,
)
from .models import (
    ClothingRecommendation,
    HikingOption,
    InternetAvailability,
    FoodOption,
)

# __all__ advertises these longer names, but .agents defines the classes under
# the shorter names imported above. Without the aliases, `from ... import *`
# fails on the undefined names.
InternetConnectivityAgent = InternetAgent
FoodRecommendationAgent = FoodAgent
PersonalizedRecommendationAgent = PersonalizedAgent

__all__ = [
    "TravelAgentBase",
    "ClothingAgent",
    "HikingAgent",
    "InternetConnectivityAgent",
    "FoodRecommendationAgent",
    "PersonalizedRecommendationAgent",
    "ClothingRecommendation",
    "HikingOption",
    "InternetAvailability",
    "FoodOption",
]

# The verification helpers live in the .agentlib / .modellib subpackages.
# NOTE(review): those subpackages are not among the files shipped in the 0.1.4
# wheel, and an unconditional import made this whole package unimportable.
# Confirm whether they should be packaged. Until then, export them only when
# present.
try:
    from .agentlib.verification_agent import UserVerificationAgent
    from .modellib.verification import (
        UserTravelInfo,
        TravelCompanion,
        HealthCondition,
    )
except ModuleNotFoundError:
    pass
else:
    __all__ += [
        "UserVerificationAgent",
        "UserTravelInfo",
        "TravelCompanion",
        "HealthCondition",
    ]
+ ]
airtrain/contrib/travel/agents.py ADDED
@@ -0,0 +1,243 @@
1
+ from typing import Optional, List, Any
2
+ from airtrain.core.skills import Skill, ProcessingError
3
+ from airtrain.integrations.openai.skills import OpenAIParserSkill, OpenAIParserInput
4
+ from .models import (
5
+ ClothingRecommendation,
6
+ HikingOption,
7
+ InternetAvailability,
8
+ FoodOption,
9
+ PersonalizedRecommendation,
10
+ )
11
+
12
+
class TravelAgentBase(OpenAIParserSkill):
    """Shared parent of the travel agents in this module.

    Fixes the model to ``gpt-4o`` and the temperature to ``0.0``; subclasses
    read ``self.model`` / ``self.temperature`` when building their
    ``OpenAIParserInput`` requests.
    """

    def __init__(self, credentials=None):
        super().__init__(credentials)
        # Defaults consumed by the subclasses' request builders.
        self.model, self.temperature = "gpt-4o", 0.0
18
+
19
+
class ClothingAgent(TravelAgentBase):
    """Produces packing / clothing advice for a trip."""

    def get_recommendations(
        self, location: str, duration: int, activities: List[str], season: str
    ) -> ClothingRecommendation:
        """Ask the model what to pack for a trip.

        Args:
            location: Destination, interpolated into the prompt.
            duration: Trip length in days.
            activities: Planned activities, joined with ", " in the prompt.
            season: Season of travel.

        Returns:
            The ``ClothingRecommendation`` parsed from the model's reply.
        """
        activity_list = ", ".join(activities)
        request_text = f"""
        Provide clothing recommendations for a {duration}-day trip to {location} during {season}.
        Activities planned: {activity_list}.
        Include essential items, weather-specific clothing, and cultural considerations.
        """

        parser_input = OpenAIParserInput(
            user_input=request_text,
            system_prompt="You are a travel clothing expert. Provide detailed packing recommendations.",
            response_model=ClothingRecommendation,
            model=self.model,
            temperature=self.temperature,
        )
        return self.process(parser_input).parsed_response
42
+
43
+
class HikingAgent(TravelAgentBase):
    """Agent for hiking recommendations"""

    def get_hiking_options(
        self, location: str, difficulty: str, duration_preference: float
    ) -> List[HikingOption]:
        """Ask the model for hiking trails that match the given constraints.

        Args:
            location: Area to search for trails.
            difficulty: Desired difficulty label, passed verbatim into the prompt.
            duration_preference: Target hike length in hours.

        Returns:
            The list of ``HikingOption`` entries parsed from the model's reply.
        """
        # Function-scope import: pydantic is already a package dependency (see
        # .models), but this module does not import it at the top.
        from pydantic import BaseModel

        class HikingOptions(BaseModel):
            # Wrapper so the response model is a BaseModel like every other
            # agent's. A bare List[...] is not a model type. NOTE(review):
            # assumes OpenAIParserSkill forwards response_model to OpenAI
            # structured outputs, whose root schema must be an object — confirm.
            options: List[HikingOption]

        prompt = f"""
        Find hiking trails in {location} that match:
        - Difficulty level: {difficulty}
        - Preferred duration: around {duration_preference} hours
        Provide detailed trail information and safety tips.
        """

        input_data = OpenAIParserInput(
            user_input=prompt,
            system_prompt="You are a hiking expert. Recommend suitable trails with safety considerations.",
            response_model=HikingOptions,
            model=self.model,
            # Previously omitted; every other agent passes the shared
            # temperature set by TravelAgentBase.
            temperature=self.temperature,
        )

        result = self.process(input_data)
        # Unwrap so callers still receive List[HikingOption].
        return result.parsed_response.options
66
+
67
+
class InternetAgent(TravelAgentBase):
    """Reports internet connectivity options at a destination."""

    def get_connectivity_info(
        self,
        location: str,
        duration: int,
        work_requirements: Optional[List[str]] = None,
    ) -> InternetAvailability:
        """Ask the model how well-connected a destination is.

        Args:
            location: Destination, interpolated into the prompt.
            duration: Length of stay in days.
            work_requirements: Optional needs (e.g. video calls). When missing
                or empty, the corresponding prompt line is left blank.

        Returns:
            The ``InternetAvailability`` parsed from the model's reply.
        """
        requirements_line = (
            f'Work requirements: {", ".join(work_requirements)}'
            if work_requirements
            else ''
        )
        request_text = f"""
        Provide detailed internet connectivity information for {location} for a {duration}-day stay.
        {requirements_line}
        Include public WiFi spots, recommended providers, and connectivity tips.
        """

        parser_input = OpenAIParserInput(
            user_input=request_text,
            system_prompt="You are a connectivity expert. Provide detailed internet availability information.",
            response_model=InternetAvailability,
            model=self.model,
            temperature=self.temperature,
        )
        return self.process(parser_input).parsed_response
93
+
94
+
class FoodAgent(TravelAgentBase):
    """Suggests local dishes, restaurants and food-safety practices."""

    def get_food_recommendations(
        self,
        location: str,
        dietary_restrictions: Optional[List[str]] = None,
        preferences: Optional[List[str]] = None,
        budget_level: str = "medium",
    ) -> FoodOption:
        """Ask the model for food recommendations at a destination.

        Args:
            location: Destination, interpolated into the prompt.
            dietary_restrictions: Optional restrictions. When missing or empty,
                the corresponding prompt line is left blank.
            preferences: Optional cuisine preferences; same blank-line rule.
            budget_level: Free-form budget label, "medium" by default.

        Returns:
            The ``FoodOption`` parsed from the model's reply.
        """
        restrictions_line = ''
        if dietary_restrictions:
            restrictions_line = f'Dietary restrictions: {", ".join(dietary_restrictions)}'

        preferences_line = ''
        if preferences:
            preferences_line = f'Food preferences: {", ".join(preferences)}'

        request_text = f"""
        Provide food recommendations for {location}.
        {restrictions_line}
        {preferences_line}
        Budget level: {budget_level}
        Include local specialties, restaurant recommendations, and food safety tips.
        """

        parser_input = OpenAIParserInput(
            user_input=request_text,
            system_prompt="You are a culinary expert. Provide detailed food recommendations.",
            response_model=FoodOption,
            model=self.model,
            temperature=self.temperature,
        )
        return self.process(parser_input).parsed_response
123
+
124
+
class PersonalizedAgent(TravelAgentBase):
    """Builds tailored itineraries from a traveller's interests and style."""

    def get_personalized_recommendations(
        self,
        location: str,
        duration: int,
        interests: List[str],
        budget_level: str,
        travel_style: str,
        previous_destinations: Optional[List[str]] = None,
    ) -> PersonalizedRecommendation:
        """Ask the model for personalised recommendations and an itinerary.

        Args:
            location: Destination, interpolated into the prompt.
            duration: Trip length in days.
            interests: Traveller interests, joined with ", " in the prompt.
            budget_level: Free-form budget label.
            travel_style: Free-form travel style label.
            previous_destinations: Optional past trips. When missing or empty,
                the corresponding prompt line is left blank.

        Returns:
            The ``PersonalizedRecommendation`` parsed from the model's reply.
        """
        interest_list = ", ".join(interests)
        history_line = (
            f'Previous destinations: {", ".join(previous_destinations)}'
            if previous_destinations
            else ''
        )
        request_text = f"""
        Create personalized travel recommendations for {location} for {duration} days.
        Interests: {interest_list}
        Travel style: {travel_style}
        Budget level: {budget_level}
        {history_line}
        Include hidden gems, local events, and a custom itinerary.
        """

        parser_input = OpenAIParserInput(
            user_input=request_text,
            system_prompt="You are a personal travel consultant. Provide tailored recommendations.",
            response_model=PersonalizedRecommendation,
            model=self.model,
            temperature=self.temperature,
        )
        return self.process(parser_input).parsed_response
156
+
157
+
# Manual demo: runs every agent above against one destination and prints the
# results. NOTE(review): each call goes through OpenAIParserSkill, so this
# presumably needs OpenAI credentials and network access.
if __name__ == "__main__":
    # Example location and common parameters
    location = "Kyoto, Japan"
    duration = 7
    season = "spring"

    try:
        # Agents are built inside the try so a construction failure (e.g.
        # missing credentials) is reported like any other error.
        clothing_agent = ClothingAgent()
        hiking_agent = HikingAgent()
        internet_agent = InternetAgent()
        food_agent = FoodAgent()
        personalized_agent = PersonalizedAgent()

        # Get clothing recommendations
        clothing_result = clothing_agent.get_recommendations(
            location=location,
            duration=duration,
            activities=["hiking", "temple visits", "city walking", "photography"],
            season=season,
        )
        print("\n=== Clothing Recommendations ===")
        print(f"Essential items: {', '.join(clothing_result.essential_items)}")
        print(f"Weather specific: {', '.join(clothing_result.weather_specific)}")
        print(
            f"Cultural considerations: {', '.join(clothing_result.cultural_considerations)}"
        )

        # Get hiking options
        hiking_result = hiking_agent.get_hiking_options(
            location=location, difficulty="moderate", duration_preference=4.0
        )
        print("\n=== Hiking Options ===")
        for trail in hiking_result:
            print(f"\nTrail: {trail.trail_name}")
            print(f"Difficulty: {trail.difficulty}")
            print(f"Duration: {trail.duration_hours} hours")
            print(f"Distance: {trail.distance_km} km")

        # Get internet availability
        internet_result = internet_agent.get_connectivity_info(
            location=location,
            duration=duration,
            work_requirements=["video calls", "cloud storage access"],
        )
        print("\n=== Internet Availability ===")
        print(f"General availability: {internet_result.general_availability}")
        print(f"Average speed: {internet_result.average_speed_mbps} Mbps")
        print(
            f"Recommended providers: {', '.join(internet_result.recommended_providers)}"
        )

        # Get food recommendations
        food_result = food_agent.get_food_recommendations(
            location=location,
            dietary_restrictions=["vegetarian"],
            preferences=["traditional", "local specialties"],
            budget_level="medium",
        )
        print("\n=== Food Recommendations ===")
        print("Local specialties:", ", ".join(food_result.local_specialties))
        print("Must-try dishes:", ", ".join(food_result.must_try_dishes))
        for restaurant in food_result.recommended_restaurants[:3]:  # Show top 3
            # Entries are free-form Dict[str, str] produced by the model, so
            # the keys are not guaranteed. Don't let a KeyError abort the demo.
            print(
                f"Restaurant: {restaurant.get('name', 'N/A')} - "
                f"{restaurant.get('type', 'N/A')}"
            )

        # Get personalized recommendations
        personal_result = personalized_agent.get_personalized_recommendations(
            location=location,
            duration=duration,
            interests=["photography", "culture", "nature", "food"],
            budget_level="medium",
            travel_style="balanced",
            previous_destinations=["Tokyo", "Seoul"],
        )
        print("\n=== Personalized Recommendations ===")
        print("Hidden gems:", ", ".join(personal_result.hidden_gems))
        print("\nCustom Itinerary:")
        for day in personal_result.custom_itinerary:
            # Same caveat: itinerary entries are Dict[str, Any] from the model.
            print(f"Day {day.get('day', '?')}: {day.get('activities', '')}")

    except ProcessingError as e:
        print(f"Error processing travel recommendations: {str(e)}")
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
airtrain/contrib/travel/models.py ADDED
@@ -0,0 +1,59 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import BaseModel, Field, validator
3
+
4
+
class ClothingRecommendation(BaseModel):
    """Model for clothing recommendations"""

    # The class docstring doubles as the JSON-schema description pydantic
    # generates (and presumably sends as the response schema), so it is kept
    # verbatim. All fields are required string lists.
    essential_items: list[str]
    weather_specific: list[str]
    activity_specific: list[str]
    cultural_considerations: list[str]
    packing_tips: list[str]
13
+
14
+
class HikingOption(BaseModel):
    """Model for hiking recommendations"""

    # Docstring is kept verbatim: pydantic uses it as the schema description.
    trail_name: str
    difficulty: str
    # Unit suffixes in the names (hours, km, m) are the only unit information.
    duration_hours: float
    distance_km: float
    elevation_gain_m: float
    best_season: list[str]
    required_gear: list[str]
    safety_tips: list[str]
    highlights: list[str]
27
+
28
+
class InternetAvailability(BaseModel):
    """Model for internet availability information"""

    # Docstring is kept verbatim: pydantic uses it as the schema description.
    general_availability: str
    average_speed_mbps: float
    public_wifi_spots: list[str]
    recommended_providers: list[str]
    connectivity_tips: list[str]
    offline_alternatives: list[str]
38
+
39
+
class FoodOption(BaseModel):
    """Model for food recommendations"""

    # Docstring is kept verbatim: pydantic uses it as the schema description.
    local_specialties: list[str]
    # Free-form string maps: their keys are not constrained by this schema.
    recommended_restaurants: list[dict[str, str]]
    dietary_considerations: list[str]
    food_safety_tips: list[str]
    price_ranges: dict[str, str]
    must_try_dishes: list[str]
49
+
50
+
class PersonalizedRecommendation(BaseModel):
    """Model for personalized recommendations"""

    # Docstring is kept verbatim: pydantic uses it as the schema description.
    # Dict-valued fields are free-form; their keys are not fixed by the schema.
    activities: list[dict[str, str]]
    hidden_gems: list[str]
    local_events: list[dict[str, str]]
    custom_itinerary: list[dict[str, Any]]
    safety_tips: list[str]
    budget_recommendations: dict[str, str]
airtrain/core/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Core modules for Airtrain"""
2
+
3
+ from .skills import Skill, ProcessingError
4
+ from .schemas import InputSchema, OutputSchema
5
+ from .credentials import BaseCredentials
6
+
# Public API of airtrain.core: the skill base and its error type, the I/O
# schema bases, and the credentials base.
__all__ = [
    "Skill",
    "ProcessingError",
    "InputSchema",
    "OutputSchema",
    "BaseCredentials",
]
airtrain/core/credentials.py ADDED
@@ -0,0 +1,171 @@
1
+ from typing import Dict, List, Optional, Set, Union
2
+ import os
3
+ import json
4
+ from pathlib import Path
5
+ from abc import ABC, abstractmethod
6
+ import dotenv
7
+ from pydantic import BaseModel, Field, SecretStr
8
+ import yaml
9
+
10
+
class CredentialError(Exception):
    """Root of the credential exception hierarchy.

    Catch this to handle any credential-related failure in one place.
    """
15
+
16
+
class CredentialNotFoundError(CredentialError):
    """Signals that a credential the caller requires could not be located."""
21
+
22
+
class CredentialValidationError(CredentialError):
    """Signals that loaded credentials did not pass validation."""
27
+
28
+
def _credentials_file_suffix(path: Path) -> str:
    """Return the credentials format suffix of *path* ("" if it has none).

    pathlib treats a leading dot as part of the stem, so ``Path(".env").suffix``
    is ``""``. Bare dotfiles named after a supported format (``.env``,
    ``.json``, ``.yaml``, ``.yml``) are therefore mapped to their own name.
    """
    if path.suffix:
        return path.suffix
    if path.name in {".env", ".json", ".yaml", ".yml"}:
        return path.name
    return ""


class BaseCredentials(BaseModel):
    """Base class for all credential configurations"""

    # Underscore-prefixed annotations are pydantic private attributes, not
    # fields, so they are excluded from model_dump() and validation.
    _loaded: bool = False
    # Field names validate_credentials() requires to be populated (empty by
    # default; presumably overridden by subclasses — confirm).
    _required_credentials: Set[str] = set()

    def load_to_env(self) -> None:
        """Export each populated field to ``os.environ[FIELD_NAME.upper()]``.

        ``SecretStr`` values are unwrapped. Fields whose value is ``None`` are
        skipped rather than exported as the literal string ``"None"``.
        """
        for field_name, field_value in self:
            if field_value is None:
                continue
            if isinstance(field_value, SecretStr):
                value = field_value.get_secret_value()
            else:
                value = str(field_value)
            os.environ[field_name.upper()] = value
        self._loaded = True

    @classmethod
    def from_env(cls) -> "BaseCredentials":
        """Build an instance from ``FIELD_NAME.upper()`` environment variables.

        Unset or empty variables are omitted, so the model's own defaults and
        validation decide what happens for those fields.
        """
        field_values = {}
        for field_name in cls.model_fields:
            env_key = field_name.upper()
            if env_value := os.getenv(env_key):
                field_values[field_name] = env_value
        return cls(**field_values)

    @classmethod
    def from_file(cls, file_path: str | Path) -> "BaseCredentials":
        """Load credentials from a file (supports .env, .json, .yaml).

        Args:
            file_path: Path to load credentials from. Can be a string or Path object.
                Supported formats: .env, .json, .yaml/.yml. For a path without an
                extension, ``<path>.env``, ``.json``, ``.yaml`` and ``.yml`` are
                probed in that order. If none exists, the path itself is parsed
                as a .env file.

        Returns:
            BaseCredentials: Initialized credentials object

        Raises:
            FileNotFoundError: If the credentials file does not exist
            ValueError: If the file format is not supported
        """
        file_path = Path(file_path)
        suffix = _credentials_file_suffix(file_path)

        if not suffix:
            for ext in (".env", ".json", ".yaml", ".yml"):
                candidate = file_path.with_suffix(ext)
                if candidate.exists():
                    file_path, suffix = candidate, ext
                    break
            else:
                # No sibling with a known extension: read the given file itself
                # as dotenv. Previously the nonexistent "<path>.env" was
                # "loaded" silently.
                suffix = ".env"

        if not file_path.exists():
            raise FileNotFoundError(f"Credentials file not found: {file_path}")

        if suffix == ".env":
            # NOTE: load_dotenv() does not override variables already present in
            # os.environ, so pre-existing values take precedence over the file.
            dotenv.load_dotenv(file_path)
            return cls.from_env()

        if suffix == ".json":
            with open(file_path, encoding="utf-8") as f:
                data = json.load(f)
            return cls(**data)

        if suffix in {".yaml", ".yml"}:
            with open(file_path, encoding="utf-8") as f:
                # safe_load returns None for an empty document.
                data = yaml.safe_load(f) or {}
            return cls(**data)

        raise ValueError(f"Unsupported file format: {suffix}")

    def save_to_file(self, file_path: str | Path) -> None:
        """Save credentials to a file.

        Args:
            file_path: Path to save credentials to. Can be a string or Path object.
                Supported formats: .env, .json, .yaml/.yml. A path without an
                extension is saved as ``<path>.env``.

        Raises:
            ValueError: If the file format is not supported.
        """
        file_path = Path(file_path)

        # model_dump() keeps SecretStr objects; unwrap them so the file holds
        # the real value.
        data = {
            key: value.get_secret_value() if isinstance(value, SecretStr) else value
            for key, value in self.model_dump(exclude={"_loaded"}).items()
        }

        suffix = _credentials_file_suffix(file_path)
        if not suffix:
            file_path = file_path.with_suffix(".env")
            suffix = ".env"

        if suffix == ".env":
            with open(file_path, "w", encoding="utf-8") as f:
                for key, value in data.items():
                    # Unset fields are omitted: "KEY=None" would be read back
                    # as the string "None".
                    if value is None:
                        continue
                    f.write(f"{key.upper()}={value}\n")

        elif suffix == ".json":
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)

        elif suffix in {".yaml", ".yml"}:
            with open(file_path, "w", encoding="utf-8") as f:
                yaml.dump(data, f)

        else:
            raise ValueError(f"Unsupported file format: {suffix}")

    async def validate_credentials(self) -> bool:
        """Check that every name in ``_required_credentials`` has a value.

        A credential is missing when it is ``None``, an empty string, or a
        ``SecretStr`` wrapping an empty string.

        Returns:
            True when all required credentials are present.

        Raises:
            CredentialValidationError: Naming the missing credentials, sorted
                so the message is deterministic.
        """
        missing = []
        for field_name in sorted(self._required_credentials):
            value = getattr(self, field_name, None)
            if isinstance(value, SecretStr):
                value = value.get_secret_value()
            if value is None or value == "":
                missing.append(field_name)

        if missing:
            raise CredentialValidationError(
                f"Missing required credentials: {', '.join(missing)}"
            )
        return True

    def clear_from_env(self) -> None:
        """Remove every ``FIELD_NAME.upper()`` variable of this model from os.environ.

        Variables are removed whether or not this instance exported them.
        """
        # Read model_fields from the class: instance access is deprecated in
        # recent pydantic 2.x releases.
        for field_name in type(self).model_fields:
            os.environ.pop(field_name.upper(), None)
        self._loaded = False