airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. airtrain/__init__.py +146 -6
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  19. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  21. airtrain/core/credentials.py +62 -44
  22. airtrain/core/skills.py +102 -0
  23. airtrain/integrations/__init__.py +74 -0
  24. airtrain/integrations/anthropic/__init__.py +33 -0
  25. airtrain/integrations/anthropic/credentials.py +32 -0
  26. airtrain/integrations/anthropic/list_models.py +110 -0
  27. airtrain/integrations/anthropic/models_config.py +100 -0
  28. airtrain/integrations/anthropic/skills.py +155 -0
  29. airtrain/integrations/aws/__init__.py +6 -0
  30. airtrain/integrations/aws/credentials.py +36 -0
  31. airtrain/integrations/aws/skills.py +98 -0
  32. airtrain/integrations/cerebras/__init__.py +6 -0
  33. airtrain/integrations/cerebras/credentials.py +19 -0
  34. airtrain/integrations/cerebras/skills.py +127 -0
  35. airtrain/integrations/combined/__init__.py +21 -0
  36. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  37. airtrain/integrations/combined/list_models_factory.py +210 -0
  38. airtrain/integrations/fireworks/__init__.py +21 -0
  39. airtrain/integrations/fireworks/completion_skills.py +147 -0
  40. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  41. airtrain/integrations/fireworks/credentials.py +26 -0
  42. airtrain/integrations/fireworks/list_models.py +128 -0
  43. airtrain/integrations/fireworks/models.py +139 -0
  44. airtrain/integrations/fireworks/requests_skills.py +207 -0
  45. airtrain/integrations/fireworks/skills.py +181 -0
  46. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  47. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  48. airtrain/integrations/fireworks/structured_skills.py +102 -0
  49. airtrain/integrations/google/__init__.py +7 -0
  50. airtrain/integrations/google/credentials.py +58 -0
  51. airtrain/integrations/google/skills.py +122 -0
  52. airtrain/integrations/groq/__init__.py +23 -0
  53. airtrain/integrations/groq/credentials.py +24 -0
  54. airtrain/integrations/groq/models_config.py +162 -0
  55. airtrain/integrations/groq/skills.py +201 -0
  56. airtrain/integrations/ollama/__init__.py +6 -0
  57. airtrain/integrations/ollama/credentials.py +26 -0
  58. airtrain/integrations/ollama/skills.py +41 -0
  59. airtrain/integrations/openai/__init__.py +37 -0
  60. airtrain/integrations/openai/chinese_assistant.py +42 -0
  61. airtrain/integrations/openai/credentials.py +39 -0
  62. airtrain/integrations/openai/list_models.py +112 -0
  63. airtrain/integrations/openai/models_config.py +224 -0
  64. airtrain/integrations/openai/skills.py +342 -0
  65. airtrain/integrations/perplexity/__init__.py +49 -0
  66. airtrain/integrations/perplexity/credentials.py +43 -0
  67. airtrain/integrations/perplexity/list_models.py +112 -0
  68. airtrain/integrations/perplexity/models_config.py +128 -0
  69. airtrain/integrations/perplexity/skills.py +279 -0
  70. airtrain/integrations/sambanova/__init__.py +6 -0
  71. airtrain/integrations/sambanova/credentials.py +20 -0
  72. airtrain/integrations/sambanova/skills.py +129 -0
  73. airtrain/integrations/search/__init__.py +21 -0
  74. airtrain/integrations/search/exa/__init__.py +23 -0
  75. airtrain/integrations/search/exa/credentials.py +30 -0
  76. airtrain/integrations/search/exa/schemas.py +114 -0
  77. airtrain/integrations/search/exa/skills.py +115 -0
  78. airtrain/integrations/together/__init__.py +33 -0
  79. airtrain/integrations/together/audio_models_config.py +34 -0
  80. airtrain/integrations/together/credentials.py +22 -0
  81. airtrain/integrations/together/embedding_models_config.py +92 -0
  82. airtrain/integrations/together/image_models_config.py +69 -0
  83. airtrain/integrations/together/image_skill.py +143 -0
  84. airtrain/integrations/together/list_models.py +76 -0
  85. airtrain/integrations/together/models.py +95 -0
  86. airtrain/integrations/together/models_config.py +399 -0
  87. airtrain/integrations/together/rerank_models_config.py +43 -0
  88. airtrain/integrations/together/rerank_skill.py +49 -0
  89. airtrain/integrations/together/schemas.py +33 -0
  90. airtrain/integrations/together/skills.py +305 -0
  91. airtrain/integrations/together/vision_models_config.py +49 -0
  92. airtrain/telemetry/__init__.py +38 -0
  93. airtrain/telemetry/service.py +167 -0
  94. airtrain/telemetry/views.py +237 -0
  95. airtrain/tools/__init__.py +45 -0
  96. airtrain/tools/command.py +398 -0
  97. airtrain/tools/filesystem.py +166 -0
  98. airtrain/tools/network.py +111 -0
  99. airtrain/tools/registry.py +320 -0
  100. airtrain/tools/search.py +450 -0
  101. airtrain/tools/testing.py +135 -0
  102. airtrain-0.1.4.dist-info/METADATA +222 -0
  103. airtrain-0.1.4.dist-info/RECORD +108 -0
  104. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  105. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  106. airtrain-0.1.3.dist-info/METADATA +0 -106
  107. airtrain-0.1.3.dist-info/RECORD +0 -9
  108. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/cli/main.py ADDED
@@ -0,0 +1,120 @@
1
+ import click
2
+ from typing import Optional
3
+ from airtrain.integrations.openai.skills import OpenAIChatSkill, OpenAIInput
4
+ from airtrain.integrations.anthropic.skills import AnthropicChatSkill, AnthropicInput
5
+ import os
6
+ from dotenv import load_dotenv
7
+ import sys
8
+ from .builder import build
9
+
10
# Populate environment variables from a .env file (python-dotenv) as soon as
# this module is imported, i.e. before any chat skill is constructed.
load_dotenv()
12
+
13
+
14
def initialize_chat(provider: str = "openai"):
    """Create the chat skill that backs the given provider.

    Args:
        provider: Either ``"openai"`` or ``"anthropic"``.

    Returns:
        A new ``OpenAIChatSkill`` or ``AnthropicChatSkill`` instance.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider not in ("openai", "anthropic"):
        raise ValueError(f"Unsupported provider: {provider}")
    return OpenAIChatSkill() if provider == "openai" else AnthropicChatSkill()
22
+
23
+
24
@click.group()
def cli():
    """Airtrain CLI - Your AI Agent Building Assistant"""
    # The docstring above is click's help text for the group. Sub-commands are
    # attached via @cli.command() and cli.add_command(...), so no body is needed.
28
+
29
+
30
# Model used for each provider when --model is not given; these are the values
# the chat command hard-coded before the option existed.
_DEFAULT_CHAT_MODELS = {
    "openai": "gpt-4o",
    "anthropic": "claude-3-opus-20240229",
}


def _build_chat_input(
    provider: str,
    user_input: str,
    system_prompt: str,
    conversation_history: list,
    model: str,
    temperature: float,
):
    """Build the provider-specific input object for a single chat turn."""
    input_cls = OpenAIInput if provider == "openai" else AnthropicInput
    return input_cls(
        user_input=user_input,
        system_prompt=system_prompt,
        conversation_history=conversation_history,
        model=model,
        temperature=temperature,
    )


@cli.command()
@click.option(
    "--provider",
    type=click.Choice(["openai", "anthropic"]),
    default="openai",
    help="The AI provider to use",
)
@click.option(
    "--temperature",
    type=float,
    default=0.7,
    help="Temperature for response generation (0.0-1.0)",
)
@click.option(
    "--system-prompt",
    type=str,
    default="You are a helpful AI assistant that helps users build their own AI agents. Be helpful and provide clear explanations.",
    help="System prompt to guide the model",
)
@click.option(
    "--model",
    type=str,
    default=None,
    help="Model to use (defaults to gpt-4o for openai, claude-3-opus-20240229 for anthropic)",
)
def chat(
    provider: str,
    temperature: float,
    system_prompt: str,
    model: Optional[str] = None,
):
    """Start an interactive chat session with Airtrain"""
    # The docstring above is click's --help text, so implementation notes live
    # in comments.
    #
    # Only skill construction is covered by the "Failed to initialize" handler,
    # so problems later in the session are not misreported as initialization
    # failures (with exit status 1).
    try:
        skill = initialize_chat(provider)
    except Exception as e:
        click.echo(f"Failed to initialize chat: {str(e)}")
        sys.exit(1)

    # --provider is restricted by click.Choice, so the lookup cannot miss.
    model_name = model or _DEFAULT_CHAT_MODELS[provider]

    click.echo(f"\nWelcome to Airtrain! Using {provider.upper()} as the provider.")
    click.echo("Type 'exit' to end the conversation.")
    click.echo("Type 'clear' to clear the conversation history.\n")

    conversation_history = []

    while True:
        try:
            user_input = click.prompt("You", type=str)
        except click.Abort:
            # click.prompt converts Ctrl-C / Ctrl-D into Abort; end the session
            # the same way as typing "exit".
            click.echo("\nGoodbye! Have a great day!")
            break

        command = user_input.strip().lower()

        if command == "exit":
            click.echo("\nGoodbye! Have a great day!")
            break

        if command == "clear":
            conversation_history = []
            click.echo("\nConversation history cleared!")
            continue

        try:
            input_data = _build_chat_input(
                provider,
                user_input,
                system_prompt,
                conversation_history,
                model_name,
                temperature,
            )
            result = skill.process(input_data)

            # Update conversation history
            conversation_history.extend(
                [
                    {"role": "user", "content": user_input},
                    {"role": "assistant", "content": result.response},
                ]
            )

            click.echo(f"\nAirtrain: {result.response}\n")

        except Exception as e:
            # A failed turn is reported and the session continues; history is
            # only extended after a successful response.
            click.echo(f"\nError: {str(e)}\n")
108
+
109
+
110
# Attach the `build` command (imported from .builder) to the top-level group.
cli.add_command(build)


def main():
    """Run the Airtrain command-line interface by dispatching to ``cli``."""
    cli()


if __name__ == "__main__":
    main()
airtrain/contrib/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ """Airtrain contrib package for community contributions"""
2
+
3
# travel.agents defines three of these classes under shorter names
# (InternetAgent, FoodAgent, PersonalizedAgent). Importing the long names
# directly raised ImportError, so alias them to the names listed in __all__.
# NOTE: importing .travel.agents first executes travel/__init__.py, which must
# itself import cleanly for this package to load.
from .travel.agents import (
    TravelAgentBase,
    ClothingAgent,
    HikingAgent,
    InternetAgent as InternetConnectivityAgent,
    FoodAgent as FoodRecommendationAgent,
    PersonalizedAgent as PersonalizedRecommendationAgent,
)
from .travel.models import (
    ClothingRecommendation,
    HikingOption,
    InternetAvailability,
    FoodOption,
)

__all__ = [
    "TravelAgentBase",
    "ClothingAgent",
    "HikingAgent",
    "InternetConnectivityAgent",
    "FoodRecommendationAgent",
    "PersonalizedRecommendationAgent",
    "ClothingRecommendation",
    "HikingOption",
    "InternetAvailability",
    "FoodOption",
]
airtrain/contrib/travel/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ """Travel related agents and models"""
2
+
3
from .agents import (
    TravelAgentBase,
    ClothingAgent,
    HikingAgent,
    InternetAgent,
    FoodAgent,
    PersonalizedAgent,
)
from .models import (
    ClothingRecommendation,
    HikingOption,
    InternetAvailability,
    FoodOption,
)

# Long-form names advertised in __all__ (and imported by airtrain.contrib).
# .agents only defines the short names, so without these aliases
# `from airtrain.contrib.travel import *` failed with AttributeError.
InternetConnectivityAgent = InternetAgent
FoodRecommendationAgent = FoodAgent
PersonalizedRecommendationAgent = PersonalizedAgent

__all__ = [
    "TravelAgentBase",
    "ClothingAgent",
    "HikingAgent",
    "InternetAgent",
    "FoodAgent",
    "PersonalizedAgent",
    "InternetConnectivityAgent",
    "FoodRecommendationAgent",
    "PersonalizedRecommendationAgent",
    "ClothingRecommendation",
    "HikingOption",
    "InternetAvailability",
    "FoodOption",
]

# The verification agent/models live in submodules that are not shipped in
# every release. Importing them unconditionally made the whole travel package
# unimportable when they are absent, so only export them when available.
try:
    from .agentlib.verification_agent import UserVerificationAgent
    from .modellib.verification import (
        UserTravelInfo,
        TravelCompanion,
        HealthCondition,
    )
except ImportError:
    pass
else:
    __all__ += [
        "UserVerificationAgent",
        "UserTravelInfo",
        "TravelCompanion",
        "HealthCondition",
    ]
airtrain/contrib/travel/agents.py ADDED
@@ -0,0 +1,243 @@
1
+ from typing import Optional, List, Any
2
+ from airtrain.core.skills import Skill, ProcessingError
3
+ from airtrain.integrations.openai.skills import OpenAIParserSkill, OpenAIParserInput
4
+ from .models import (
5
+ ClothingRecommendation,
6
+ HikingOption,
7
+ InternetAvailability,
8
+ FoodOption,
9
+ PersonalizedRecommendation,
10
+ )
11
+
12
+
13
class TravelAgentBase(OpenAIParserSkill):
    """Common base for the travel agents: an OpenAI parser skill with a
    fixed model and temperature stored on the instance.

    Subclasses use ``self.model`` and ``self.temperature`` when building their
    ``OpenAIParserInput`` requests.
    """

    def __init__(
        self,
        credentials=None,
        *,
        model: str = "gpt-4o",
        temperature: float = 0.0,
    ):
        """Initialize the agent.

        Args:
            credentials: Passed unchanged to ``OpenAIParserSkill.__init__``.
            model: Model name to request. Defaults to ``"gpt-4o"``, the value
                previously hard-coded here.
            temperature: Sampling temperature to request. Defaults to ``0.0``,
                the value previously hard-coded here.
        """
        super().__init__(credentials)
        self.model = model
        self.temperature = temperature
18
+
19
+
20
class ClothingAgent(TravelAgentBase):
    """Agent for clothing recommendations"""

    def get_recommendations(
        self, location: str, duration: int, activities: List[str], season: str
    ) -> ClothingRecommendation:
        """Ask the model what to pack for a trip.

        Args:
            location: Trip destination.
            duration: Trip length in days.
            activities: Planned activities, joined with ", " into the prompt.
            season: Season of travel.

        Returns:
            ``parsed_response`` of the processed request, which was requested as
            a ``ClothingRecommendation``.
        """
        activity_list = ", ".join(activities)
        request_text = f"""
        Provide clothing recommendations for a {duration}-day trip to {location} during {season}.
        Activities planned: {activity_list}.
        Include essential items, weather-specific clothing, and cultural considerations.
        """

        parser_input = OpenAIParserInput(
            user_input=request_text,
            system_prompt="You are a travel clothing expert. Provide detailed packing recommendations.",
            response_model=ClothingRecommendation,
            model=self.model,
            temperature=self.temperature,
        )
        return self.process(parser_input).parsed_response
42
+
43
+
44
class HikingAgent(TravelAgentBase):
    """Agent for hiking recommendations"""

    def get_hiking_options(
        self, location: str, difficulty: str, duration_preference: float
    ) -> List[HikingOption]:
        """Ask the model for hiking trails that match the given preferences.

        Args:
            location: Area to search for trails.
            difficulty: Desired difficulty level, inserted verbatim into the prompt.
            duration_preference: Preferred hike length in hours.

        Returns:
            ``parsed_response`` of the processed request, which was requested
            as ``List[HikingOption]``.
        """
        prompt = f"""
        Find hiking trails in {location} that match:
        - Difficulty level: {difficulty}
        - Preferred duration: around {duration_preference} hours
        Provide detailed trail information and safety tips.
        """

        input_data = OpenAIParserInput(
            user_input=prompt,
            system_prompt="You are a hiking expert. Recommend suitable trails with safety considerations.",
            # NOTE(review): unlike the other agents this passes a List[...] type
            # rather than a pydantic model class; confirm that OpenAIParserInput
            # and the underlying parse call accept it.
            response_model=List[HikingOption],
            model=self.model,
            # Pass the agent's temperature like every other travel agent does.
            # Omitting it meant this agent ignored TravelAgentBase's temperature
            # and fell back to OpenAIParserInput's own default.
            temperature=self.temperature,
        )

        result = self.process(input_data)
        return result.parsed_response
66
+
67
+
68
class InternetAgent(TravelAgentBase):
    """Agent for internet availability information"""

    def get_connectivity_info(
        self,
        location: str,
        duration: int,
        work_requirements: Optional[List[str]] = None,
    ) -> InternetAvailability:
        """Ask the model about internet connectivity at a destination.

        Args:
            location: Destination to describe.
            duration: Length of stay in days.
            work_requirements: Optional needs (e.g. video calls); when empty or
                None the corresponding prompt line is left blank.

        Returns:
            ``parsed_response`` of the processed request, which was requested as
            an ``InternetAvailability``.
        """
        requirements_line = (
            f'Work requirements: {", ".join(work_requirements)}'
            if work_requirements
            else ""
        )
        request_text = f"""
        Provide detailed internet connectivity information for {location} for a {duration}-day stay.
        {requirements_line}
        Include public WiFi spots, recommended providers, and connectivity tips.
        """

        parser_input = OpenAIParserInput(
            user_input=request_text,
            system_prompt="You are a connectivity expert. Provide detailed internet availability information.",
            response_model=InternetAvailability,
            model=self.model,
            temperature=self.temperature,
        )
        return self.process(parser_input).parsed_response
93
+
94
+
95
class FoodAgent(TravelAgentBase):
    """Agent for food recommendations"""

    def get_food_recommendations(
        self,
        location: str,
        dietary_restrictions: Optional[List[str]] = None,
        preferences: Optional[List[str]] = None,
        budget_level: str = "medium",
    ) -> FoodOption:
        """Ask the model for food recommendations at a destination.

        Args:
            location: Destination to describe.
            dietary_restrictions: Optional restrictions; omitted from the
                prompt (blank line) when empty or None.
            preferences: Optional food preferences; omitted likewise.
            budget_level: Budget descriptor inserted verbatim into the prompt.

        Returns:
            ``parsed_response`` of the processed request, which was requested as
            a ``FoodOption``.
        """
        restrictions_line = (
            f'Dietary restrictions: {", ".join(dietary_restrictions)}'
            if dietary_restrictions
            else ""
        )
        preferences_line = (
            f'Food preferences: {", ".join(preferences)}' if preferences else ""
        )
        request_text = f"""
        Provide food recommendations for {location}.
        {restrictions_line}
        {preferences_line}
        Budget level: {budget_level}
        Include local specialties, restaurant recommendations, and food safety tips.
        """

        parser_input = OpenAIParserInput(
            user_input=request_text,
            system_prompt="You are a culinary expert. Provide detailed food recommendations.",
            response_model=FoodOption,
            model=self.model,
            temperature=self.temperature,
        )
        return self.process(parser_input).parsed_response
123
+
124
+
125
class PersonalizedAgent(TravelAgentBase):
    """Agent for personalized travel recommendations"""

    def get_personalized_recommendations(
        self,
        location: str,
        duration: int,
        interests: List[str],
        budget_level: str,
        travel_style: str,
        previous_destinations: Optional[List[str]] = None,
    ) -> PersonalizedRecommendation:
        """Ask the model for a tailored itinerary and suggestions.

        Args:
            location: Trip destination.
            duration: Trip length in days.
            interests: Traveler interests, joined with ", " into the prompt.
            budget_level: Budget descriptor inserted verbatim.
            travel_style: Travel style descriptor inserted verbatim.
            previous_destinations: Optional places already visited; the prompt
                line is left blank when empty or None.

        Returns:
            ``parsed_response`` of the processed request, which was requested as
            a ``PersonalizedRecommendation``.
        """
        interests_text = ", ".join(interests)
        history_line = (
            f'Previous destinations: {", ".join(previous_destinations)}'
            if previous_destinations
            else ""
        )
        request_text = f"""
        Create personalized travel recommendations for {location} for {duration} days.
        Interests: {interests_text}
        Travel style: {travel_style}
        Budget level: {budget_level}
        {history_line}
        Include hidden gems, local events, and a custom itinerary.
        """

        parser_input = OpenAIParserInput(
            user_input=request_text,
            system_prompt="You are a personal travel consultant. Provide tailored recommendations.",
            response_model=PersonalizedRecommendation,
            model=self.model,
            temperature=self.temperature,
        )
        return self.process(parser_input).parsed_response
156
+
157
+
158
+ # Example usage: run each travel agent against a sample destination.
159
+
160
if __name__ == "__main__":
    # Example run: query every travel agent for the same destination.
    clothing_agent = ClothingAgent()
    hiking_agent = HikingAgent()
    internet_agent = InternetAgent()
    food_agent = FoodAgent()
    personalized_agent = PersonalizedAgent()

    # Example location and common parameters
    location = "Kyoto, Japan"
    duration = 7
    season = "spring"

    try:
        # Get clothing recommendations
        clothing_result = clothing_agent.get_recommendations(
            location=location,
            duration=duration,
            activities=["hiking", "temple visits", "city walking", "photography"],
            season=season,
        )
        print("\n=== Clothing Recommendations ===")
        print(f"Essential items: {', '.join(clothing_result.essential_items)}")
        print(f"Weather specific: {', '.join(clothing_result.weather_specific)}")
        print(
            f"Cultural considerations: {', '.join(clothing_result.cultural_considerations)}"
        )

        # Get hiking options
        hiking_result = hiking_agent.get_hiking_options(
            location=location, difficulty="moderate", duration_preference=4.0
        )
        print("\n=== Hiking Options ===")
        for trail in hiking_result:
            print(f"\nTrail: {trail.trail_name}")
            print(f"Difficulty: {trail.difficulty}")
            print(f"Duration: {trail.duration_hours} hours")
            print(f"Distance: {trail.distance_km} km")

        # Get internet availability
        internet_result = internet_agent.get_connectivity_info(
            location=location,
            duration=duration,
            work_requirements=["video calls", "cloud storage access"],
        )
        print("\n=== Internet Availability ===")
        print(f"General availability: {internet_result.general_availability}")
        print(f"Average speed: {internet_result.average_speed_mbps} Mbps")
        print(
            f"Recommended providers: {', '.join(internet_result.recommended_providers)}"
        )

        # Get food recommendations
        food_result = food_agent.get_food_recommendations(
            location=location,
            dietary_restrictions=["vegetarian"],
            preferences=["traditional", "local specialties"],
            budget_level="medium",
        )
        print("\n=== Food Recommendations ===")
        print("Local specialties:", ", ".join(food_result.local_specialties))
        print("Must-try dishes:", ", ".join(food_result.must_try_dishes))
        # Restaurant and itinerary entries are free-form dicts
        # (Dict[str, str] / Dict[str, Any]) produced by the model, so the keys
        # printed here are not guaranteed. Use .get() so one incomplete entry
        # does not raise KeyError and abort the rest of the demo.
        for restaurant in food_result.recommended_restaurants[:3]:  # Show top 3
            print(
                f"Restaurant: {restaurant.get('name', 'Unknown')} - "
                f"{restaurant.get('type', 'Unknown')}"
            )

        # Get personalized recommendations
        personal_result = personalized_agent.get_personalized_recommendations(
            location=location,
            duration=duration,
            interests=["photography", "culture", "nature", "food"],
            budget_level="medium",
            travel_style="balanced",
            previous_destinations=["Tokyo", "Seoul"],
        )
        print("\n=== Personalized Recommendations ===")
        print("Hidden gems:", ", ".join(personal_result.hidden_gems))
        print("\nCustom Itinerary:")
        for day in personal_result.custom_itinerary:
            print(f"Day {day.get('day', '?')}: {day.get('activities', '')}")

    except ProcessingError as e:
        print(f"Error processing travel recommendations: {str(e)}")
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
airtrain/contrib/travel/models.py ADDED
@@ -0,0 +1,59 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import BaseModel, Field, validator
3
+
4
+
5
class ClothingRecommendation(BaseModel):
    """Model for clothing recommendations"""

    # Response schema requested by ClothingAgent.get_recommendations; every
    # field is a required list of free-text items.
    essential_items: list[str]
    weather_specific: list[str]
    activity_specific: list[str]
    cultural_considerations: list[str]
    packing_tips: list[str]
13
+
14
+
15
class HikingOption(BaseModel):
    """Model for hiking recommendations"""

    # One trail; HikingAgent.get_hiking_options requests a list of these.
    trail_name: str
    difficulty: str
    # Units are encoded in the field names (hours, kilometres, metres).
    duration_hours: float
    distance_km: float
    elevation_gain_m: float
    best_season: list[str]
    required_gear: list[str]
    safety_tips: list[str]
    highlights: list[str]
27
+
28
+
29
class InternetAvailability(BaseModel):
    """Model for internet availability information"""

    # Response schema requested by InternetAgent.get_connectivity_info.
    general_availability: str
    average_speed_mbps: float
    public_wifi_spots: list[str]
    recommended_providers: list[str]
    connectivity_tips: list[str]
    offline_alternatives: list[str]
38
+
39
+
40
class FoodOption(BaseModel):
    """Model for food recommendations"""

    # Response schema requested by FoodAgent.get_food_recommendations.
    local_specialties: list[str]
    # Free-form string->string dicts; no particular keys are enforced here.
    recommended_restaurants: list[dict[str, str]]
    dietary_considerations: list[str]
    food_safety_tips: list[str]
    price_ranges: dict[str, str]
    must_try_dishes: list[str]
49
+
50
+
51
class PersonalizedRecommendation(BaseModel):
    """Model for personalized recommendations"""

    # Response schema requested by
    # PersonalizedAgent.get_personalized_recommendations. Dict fields are
    # free-form; no particular keys are enforced here.
    activities: list[dict[str, str]]
    hidden_gems: list[str]
    local_events: list[dict[str, str]]
    custom_itinerary: list[dict[str, Any]]
    safety_tips: list[str]
    budget_recommendations: dict[str, str]
airtrain/core/credentials.py CHANGED
@@ -1,11 +1,11 @@
1
- from typing import Dict, List, Optional, Set
1
+ from typing import Dict, List, Optional, Set, Union
2
2
  import os
3
3
  import json
4
4
  from pathlib import Path
5
5
  from abc import ABC, abstractmethod
6
6
  import dotenv
7
7
  from pydantic import BaseModel, Field, SecretStr
8
- import yaml # type: ignore
8
+ import yaml
9
9
 
10
10
 
11
11
  class CredentialError(Exception):
@@ -53,30 +53,70 @@ class BaseCredentials(BaseModel):
53
53
  return cls(**field_values)
54
54
 
55
55
  @classmethod
56
- def from_file(cls, file_path: Path) -> "BaseCredentials":
57
- """Load credentials from a file (supports .env, .json, .yaml)"""
56
+ def from_file(cls, file_path: str | Path) -> "BaseCredentials":
57
+ """Load credentials from a file (supports .env, .json, .yaml).
58
+
59
+ Args:
60
+ file_path: Path to load credentials from. Can be a string or Path object.
61
+ Supported formats: .env, .json, .yaml/.yml
62
+
63
+ Returns:
64
+ BaseCredentials: Initialized credentials object
65
+
66
+ Raises:
67
+ FileNotFoundError: If the credentials file does not exist
68
+ ValueError: If the file format is not supported
69
+ """
70
+ # Convert to Path object if string
71
+ if isinstance(file_path, str):
72
+ file_path = Path(file_path)
73
+
58
74
  if not file_path.exists():
59
75
  raise FileNotFoundError(f"Credentials file not found: {file_path}")
60
76
 
61
- if file_path.suffix == ".env":
77
+ # Get file extension, default to .env if none provided
78
+ suffix = file_path.suffix
79
+ if not suffix:
80
+ # Try to find a file with the same name but different extension
81
+ for ext in [".env", ".json", ".yaml", ".yml"]:
82
+ potential_path = file_path.with_suffix(ext)
83
+ if potential_path.exists():
84
+ file_path = potential_path
85
+ suffix = ext
86
+ break
87
+ # If no file was found, default to .env
88
+ if not suffix:
89
+ file_path = file_path.with_suffix(".env")
90
+ suffix = ".env"
91
+
92
+ if suffix == ".env":
62
93
  dotenv.load_dotenv(file_path)
63
94
  return cls.from_env()
64
95
 
65
- elif file_path.suffix == ".json":
96
+ elif suffix == ".json":
66
97
  with open(file_path) as f:
67
98
  data = json.load(f)
68
99
  return cls(**data)
69
100
 
70
- elif file_path.suffix in {".yaml", ".yml"}:
101
+ elif suffix in {".yaml", ".yml"}:
71
102
  with open(file_path) as f:
72
103
  data = yaml.safe_load(f)
73
104
  return cls(**data)
74
105
 
75
106
  else:
76
- raise ValueError(f"Unsupported file format: {file_path.suffix}")
107
+ raise ValueError(f"Unsupported file format: {suffix}")
108
+
109
+ def save_to_file(self, file_path: str | Path) -> None:
110
+ """Save credentials to a file.
111
+
112
+ Args:
113
+ file_path: Path to save credentials to. Can be a string or Path object.
114
+ Supported formats: .env, .json, .yaml/.yml
115
+ """
116
+ # Convert to Path object if string
117
+ if isinstance(file_path, str):
118
+ file_path = Path(file_path)
77
119
 
78
- def save_to_file(self, file_path: Path) -> None:
79
- """Save credentials to a file"""
80
120
  data = self.model_dump(exclude={"_loaded"})
81
121
 
82
122
  # Convert SecretStr to plain strings for saving
@@ -84,23 +124,29 @@ class BaseCredentials(BaseModel):
84
124
  if isinstance(value, SecretStr):
85
125
  data[key] = value.get_secret_value()
86
126
 
87
- if file_path.suffix == ".env":
127
+ # Get file extension, default to .env if none provided
128
+ suffix = file_path.suffix
129
+ if not suffix:
130
+ file_path = file_path.with_suffix(".env")
131
+ suffix = ".env"
132
+
133
+ if suffix == ".env":
88
134
  with open(file_path, "w") as f:
89
135
  for key, value in data.items():
90
136
  f.write(f"{key.upper()}={value}\n")
91
137
 
92
- elif file_path.suffix == ".json":
138
+ elif suffix == ".json":
93
139
  with open(file_path, "w") as f:
94
140
  json.dump(data, f, indent=2)
95
141
 
96
- elif file_path.suffix in {".yaml", ".yml"}:
142
+ elif suffix in {".yaml", ".yml"}:
97
143
  with open(file_path, "w") as f:
98
144
  yaml.dump(data, f)
99
145
 
100
146
  else:
101
- raise ValueError(f"Unsupported file format: {file_path.suffix}")
147
+ raise ValueError(f"Unsupported file format: {suffix}")
102
148
 
103
- def validate_credentials(self) -> None:
149
+ async def validate_credentials(self) -> bool:
104
150
  """Validate that all required credentials are present"""
105
151
  missing = []
106
152
  for field_name in self._required_credentials:
@@ -114,6 +160,7 @@ class BaseCredentials(BaseModel):
114
160
  raise CredentialValidationError(
115
161
  f"Missing required credentials: {', '.join(missing)}"
116
162
  )
163
+ return True
117
164
 
118
165
  def clear_from_env(self) -> None:
119
166
  """Remove credentials from environment variables"""
@@ -122,32 +169,3 @@ class BaseCredentials(BaseModel):
122
169
  if env_key in os.environ:
123
170
  del os.environ[env_key]
124
171
  self._loaded = False
125
-
126
-
127
- class OpenAICredentials(BaseCredentials):
128
- """OpenAI API credentials"""
129
-
130
- api_key: SecretStr = Field(..., description="OpenAI API key")
131
- organization_id: Optional[str] = Field(None, description="OpenAI organization ID")
132
-
133
- _required_credentials = {"api_key"}
134
-
135
-
136
- class AWSCredentials(BaseCredentials):
137
- """AWS credentials"""
138
-
139
- aws_access_key_id: SecretStr
140
- aws_secret_access_key: SecretStr
141
- aws_region: str = "us-east-1"
142
- aws_session_token: Optional[SecretStr] = None
143
-
144
- _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}
145
-
146
-
147
- class GoogleCloudCredentials(BaseCredentials):
148
- """Google Cloud credentials"""
149
-
150
- project_id: str
151
- service_account_key: SecretStr
152
-
153
- _required_credentials = {"project_id", "service_account_key"}