airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,39 @@
1
+ from datetime import datetime, timedelta
2
+ from typing import Optional
3
+ from pydantic import Field, SecretStr, validator
4
+ from openai import OpenAI
5
+
6
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
7
+
8
+
class OpenAICredentials(BaseCredentials):
    """OpenAI API credentials with format validation and a live API check."""

    # Secret API key; format is checked by ``validate_api_key_format``.
    openai_api_key: SecretStr = Field(..., description="OpenAI API key")
    # Optional organization ID: "org-" followed by 24 alphanumeric characters.
    openai_organization_id: Optional[str] = Field(
        None, description="OpenAI organization ID", pattern="^org-[A-Za-z0-9]{24}$"
    )

    _required_credentials = {"openai_api_key"}

    @validator("openai_api_key")
    def validate_api_key_format(cls, v: SecretStr) -> SecretStr:
        """Reject keys that do not start with ``sk-`` or are shorter than 40 chars.

        Raises:
            ValueError: If the key has the wrong prefix or is too short.
        """
        key = v.get_secret_value()
        if not key.startswith("sk-"):
            raise ValueError("OpenAI API key must start with 'sk-'")
        if len(key) < 40:
            raise ValueError("OpenAI API key appears to be too short")
        return v

    async def validate_credentials(self) -> bool:
        """Validate the credentials by listing models through the OpenAI API.

        Returns:
            True if the API accepted the credentials.

        Raises:
            CredentialValidationError: If creating the client or the API
                request fails for any reason.
        """
        # The async client is required here: the synchronous client's
        # ``models.list()`` returns a plain page object that cannot be awaited.
        from openai import AsyncOpenAI

        try:
            # ``async with`` closes the underlying HTTP client after the check.
            async with AsyncOpenAI(
                api_key=self.openai_api_key.get_secret_value(),
                organization=self.openai_organization_id,
            ) as client:
                await client.models.list()
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid OpenAI credentials: {str(e)}"
            ) from e
        return True
@@ -0,0 +1,112 @@
1
+ from typing import Optional, List, Dict, Any
2
+ from pydantic import Field
3
+
4
+ from airtrain.core.skills import Skill, ProcessingError
5
+ from airtrain.core.schemas import InputSchema, OutputSchema
6
+ from .credentials import OpenAICredentials
7
+ from .models_config import OPENAI_MODELS, OpenAIModelConfig
8
+
9
+
class OpenAIModel:
    """Plain-object view of a single entry from the local OpenAI model table."""

    def __init__(self, model_id: str, config: OpenAIModelConfig):
        """Copy identifying and pricing fields out of ``config``.

        Args:
            model_id: The model's key in the model table.
            config: The matching ``OpenAIModelConfig`` entry.
        """
        self.id = model_id
        self.display_name, self.base_model = config.display_name, config.base_model
        self.input_price = config.input_price
        self.cached_input_price = config.cached_input_price
        self.output_price = config.output_price

    def dict(self, exclude_none=False):
        """Return the model as a dict, converting every price to ``float``.

        The ``cached_input_price`` key is included (possibly as ``None``)
        unless the price is ``None`` and ``exclude_none`` is true.
        """
        data = {
            "id": self.id,
            "display_name": self.display_name,
            "base_model": self.base_model,
            "input_price": float(self.input_price),
            "output_price": float(self.output_price),
        }
        cached = self.cached_input_price
        if cached is None:
            if exclude_none:
                return data
            data["cached_input_price"] = None
        else:
            data["cached_input_price"] = float(cached)
        return data
36
+
37
+
class OpenAIListModelsInput(InputSchema):
    """Input for the OpenAI list-models skill: selects the model data source."""

    # True -> query the OpenAI API; False (default) -> read the bundled table.
    api_models_only: bool = Field(
        False,
        description="If True, fetch models from the API only. If False, use local config.",
    )
47
+
48
+
class OpenAIListModelsOutput(OutputSchema):
    """Output of the OpenAI list-models skill."""

    # One plain dict per model; defaults to a fresh empty list.
    models: List[Dict[str, Any]] = Field(
        description="List of OpenAI models",
        default_factory=list,
    )
56
+
57
+
class OpenAIListModelsSkill(Skill[OpenAIListModelsInput, OpenAIListModelsOutput]):
    """Skill that lists OpenAI models from the live API or the local config.

    With ``api_models_only=True`` the list is fetched from the OpenAI API,
    which requires credentials. Otherwise it is built from the bundled
    ``OPENAI_MODELS`` table.
    """

    input_schema = OpenAIListModelsInput
    output_schema = OpenAIListModelsOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill.

        Args:
            credentials: OpenAI credentials; only required when listing
                models from the API.
        """
        super().__init__()
        self.credentials = credentials

    def process(
        self, input_data: OpenAIListModelsInput
    ) -> OpenAIListModelsOutput:
        """Return the available OpenAI models.

        Args:
            input_data: ``api_models_only`` selects the API (True) or the
                local model table (False) as the source.

        Returns:
            An ``OpenAIListModelsOutput`` whose ``models`` is a list of dicts.

        Raises:
            ProcessingError: If API listing is requested without credentials,
                or if fetching or converting the models fails.
        """
        if input_data.api_models_only:
            # Checked outside the ``try`` so this message is not re-wrapped
            # as "Failed to list OpenAI models: ...".
            if not self.credentials:
                raise ProcessingError("OpenAI credentials required for API models")
            fetch_models = self._fetch_api_models
        else:
            fetch_models = self._load_local_models

        try:
            return OpenAIListModelsOutput(models=fetch_models())
        except Exception as e:
            raise ProcessingError(f"Failed to list OpenAI models: {str(e)}") from e

    def _fetch_api_models(self) -> List[Dict[str, Any]]:
        """Fetch models from the OpenAI API using ``self.credentials``."""
        from openai import OpenAI

        client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )
        response = client.models.list()
        # This code reads only id/created/owned_by from each model, so the id
        # stands in for display_name and base_model. No pricing is returned.
        return [
            {
                "id": model.id,
                "display_name": model.id,
                "base_model": model.id,
                "created": model.created,
                "owned_by": model.owned_by,
            }
            for model in response.data
        ]

    @staticmethod
    def _load_local_models() -> List[Dict[str, Any]]:
        """Build model dicts from the bundled ``OPENAI_MODELS`` table."""
        return [
            OpenAIModel(model_id, config).dict()
            for model_id, config in OPENAI_MODELS.items()
        ]
@@ -0,0 +1,224 @@
1
+ from typing import Dict, NamedTuple, Optional
2
+ from decimal import Decimal
3
+
4
+
# Immutable record describing one OpenAI model: its display name, the base
# model it belongs to, and its input / cached-input / output prices.
# cached_input_price is None when no cached-input rate is defined.
OpenAIModelConfig = NamedTuple(
    "OpenAIModelConfig",
    [
        ("display_name", str),
        ("base_model", str),
        ("input_price", Decimal),
        ("cached_input_price", Optional[Decimal]),
        ("output_price", Decimal),
    ],
)
+
12
+
# Local catalogue of OpenAI models keyed by API model ID.
# Each dated snapshot ID (e.g. "gpt-4o-2024-08-06") shares base_model and
# pricing with its undated alias.
# cached_input_price is None where no cached-input rate is listed.
# NOTE(review): the prices match OpenAI's published USD-per-1M-token rates
# (e.g. gpt-4o at 2.50 input / 10.00 output) — confirm the unit before use.
# NOTE(review): "GPT-4 Optimized" display names label the gpt-4o family;
# OpenAI's own product name is "GPT-4o".
OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
    # GPT-4.5 preview
    "gpt-4.5-preview": OpenAIModelConfig(
        display_name="GPT-4.5 Preview",
        base_model="gpt-4.5-preview",
        input_price=Decimal("75.00"),
        cached_input_price=Decimal("37.50"),
        output_price=Decimal("150.00"),
    ),
    "gpt-4.5-preview-2025-02-27": OpenAIModelConfig(
        display_name="GPT-4.5 Preview (2025-02-27)",
        base_model="gpt-4.5-preview",
        input_price=Decimal("75.00"),
        cached_input_price=Decimal("37.50"),
        output_price=Decimal("150.00"),
    ),
    # GPT-4o family
    "gpt-4o": OpenAIModelConfig(
        display_name="GPT-4 Optimized",
        base_model="gpt-4o",
        input_price=Decimal("2.50"),
        cached_input_price=Decimal("1.25"),
        output_price=Decimal("10.00"),
    ),
    "gpt-4o-2024-08-06": OpenAIModelConfig(
        display_name="GPT-4 Optimized (2024-08-06)",
        base_model="gpt-4o",
        input_price=Decimal("2.50"),
        cached_input_price=Decimal("1.25"),
        output_price=Decimal("10.00"),
    ),
    "gpt-4o-audio-preview": OpenAIModelConfig(
        display_name="GPT-4 Optimized Audio Preview",
        base_model="gpt-4o-audio-preview",
        input_price=Decimal("2.50"),
        cached_input_price=None,
        output_price=Decimal("10.00"),
    ),
    "gpt-4o-audio-preview-2024-12-17": OpenAIModelConfig(
        display_name="GPT-4 Optimized Audio Preview (2024-12-17)",
        base_model="gpt-4o-audio-preview",
        input_price=Decimal("2.50"),
        cached_input_price=None,
        output_price=Decimal("10.00"),
    ),
    "gpt-4o-realtime-preview": OpenAIModelConfig(
        display_name="GPT-4 Optimized Realtime Preview",
        base_model="gpt-4o-realtime-preview",
        input_price=Decimal("5.00"),
        cached_input_price=Decimal("2.50"),
        output_price=Decimal("20.00"),
    ),
    "gpt-4o-realtime-preview-2024-12-17": OpenAIModelConfig(
        display_name="GPT-4 Optimized Realtime Preview (2024-12-17)",
        base_model="gpt-4o-realtime-preview",
        input_price=Decimal("5.00"),
        cached_input_price=Decimal("2.50"),
        output_price=Decimal("20.00"),
    ),
    # GPT-4o mini family
    "gpt-4o-mini": OpenAIModelConfig(
        display_name="GPT-4 Optimized Mini",
        base_model="gpt-4o-mini",
        input_price=Decimal("0.15"),
        cached_input_price=Decimal("0.075"),
        output_price=Decimal("0.60"),
    ),
    "gpt-4o-mini-2024-07-18": OpenAIModelConfig(
        display_name="GPT-4 Optimized Mini (2024-07-18)",
        base_model="gpt-4o-mini",
        input_price=Decimal("0.15"),
        cached_input_price=Decimal("0.075"),
        output_price=Decimal("0.60"),
    ),
    "gpt-4o-mini-audio-preview": OpenAIModelConfig(
        display_name="GPT-4 Optimized Mini Audio Preview",
        base_model="gpt-4o-mini-audio-preview",
        input_price=Decimal("0.15"),
        cached_input_price=None,
        output_price=Decimal("0.60"),
    ),
    "gpt-4o-mini-audio-preview-2024-12-17": OpenAIModelConfig(
        display_name="GPT-4 Optimized Mini Audio Preview (2024-12-17)",
        base_model="gpt-4o-mini-audio-preview",
        input_price=Decimal("0.15"),
        cached_input_price=None,
        output_price=Decimal("0.60"),
    ),
    "gpt-4o-mini-realtime-preview": OpenAIModelConfig(
        display_name="GPT-4 Optimized Mini Realtime Preview",
        base_model="gpt-4o-mini-realtime-preview",
        input_price=Decimal("0.60"),
        cached_input_price=Decimal("0.30"),
        output_price=Decimal("2.40"),
    ),
    "gpt-4o-mini-realtime-preview-2024-12-17": OpenAIModelConfig(
        display_name="GPT-4 Optimized Mini Realtime Preview (2024-12-17)",
        base_model="gpt-4o-mini-realtime-preview",
        input_price=Decimal("0.60"),
        cached_input_price=Decimal("0.30"),
        output_price=Decimal("2.40"),
    ),
    # o-series reasoning models
    "o1": OpenAIModelConfig(
        display_name="O1",
        base_model="o1",
        input_price=Decimal("15.00"),
        cached_input_price=Decimal("7.50"),
        output_price=Decimal("60.00"),
    ),
    "o1-2024-12-17": OpenAIModelConfig(
        display_name="O1 (2024-12-17)",
        base_model="o1",
        input_price=Decimal("15.00"),
        cached_input_price=Decimal("7.50"),
        output_price=Decimal("60.00"),
    ),
    "o3-mini": OpenAIModelConfig(
        display_name="O3 Mini",
        base_model="o3-mini",
        input_price=Decimal("1.10"),
        cached_input_price=Decimal("0.55"),
        output_price=Decimal("4.40"),
    ),
    "o3-mini-2025-01-31": OpenAIModelConfig(
        display_name="O3 Mini (2025-01-31)",
        base_model="o3-mini",
        input_price=Decimal("1.10"),
        cached_input_price=Decimal("0.55"),
        output_price=Decimal("4.40"),
    ),
    "o1-mini": OpenAIModelConfig(
        display_name="O1 Mini",
        base_model="o1-mini",
        input_price=Decimal("1.10"),
        cached_input_price=Decimal("0.55"),
        output_price=Decimal("4.40"),
    ),
    "o1-mini-2024-09-12": OpenAIModelConfig(
        display_name="O1 Mini (2024-09-12)",
        base_model="o1-mini",
        input_price=Decimal("1.10"),
        cached_input_price=Decimal("0.55"),
        output_price=Decimal("4.40"),
    ),
    # Search preview models (no cached-input rate listed)
    "gpt-4o-mini-search-preview": OpenAIModelConfig(
        display_name="GPT-4 Optimized Mini Search Preview",
        base_model="gpt-4o-mini-search-preview",
        input_price=Decimal("0.15"),
        cached_input_price=None,
        output_price=Decimal("0.60"),
    ),
    "gpt-4o-mini-search-preview-2025-03-11": OpenAIModelConfig(
        display_name="GPT-4 Optimized Mini Search Preview (2025-03-11)",
        base_model="gpt-4o-mini-search-preview",
        input_price=Decimal("0.15"),
        cached_input_price=None,
        output_price=Decimal("0.60"),
    ),
    "gpt-4o-search-preview": OpenAIModelConfig(
        display_name="GPT-4 Optimized Search Preview",
        base_model="gpt-4o-search-preview",
        input_price=Decimal("2.50"),
        cached_input_price=None,
        output_price=Decimal("10.00"),
    ),
    "gpt-4o-search-preview-2025-03-11": OpenAIModelConfig(
        display_name="GPT-4 Optimized Search Preview (2025-03-11)",
        base_model="gpt-4o-search-preview",
        input_price=Decimal("2.50"),
        cached_input_price=None,
        output_price=Decimal("10.00"),
    ),
    # Computer-use preview (no cached-input rate listed)
    "computer-use-preview": OpenAIModelConfig(
        display_name="Computer Use Preview",
        base_model="computer-use-preview",
        input_price=Decimal("3.00"),
        cached_input_price=None,
        output_price=Decimal("12.00"),
    ),
    "computer-use-preview-2025-03-11": OpenAIModelConfig(
        display_name="Computer Use Preview (2025-03-11)",
        base_model="computer-use-preview",
        input_price=Decimal("3.00"),
        cached_input_price=None,
        output_price=Decimal("12.00"),
    ),
}
197
+
198
+
def get_model_config(model_id: str) -> OpenAIModelConfig:
    """Look up the configuration entry for ``model_id``.

    Raises:
        ValueError: If ``model_id`` is not a key of ``OPENAI_MODELS``.
    """
    # Entries are never None, so a None result means "unknown model".
    config = OPENAI_MODELS.get(model_id)
    if config is None:
        raise ValueError(f"Model {model_id} not found in OpenAI models")
    return config
204
+
205
+
def get_default_model() -> str:
    """Return the model ID to use when the caller does not pick one."""
    default_model_id = "gpt-4o"
    return default_model_id
209
+
210
+
def calculate_cost(
    model_id: str, input_tokens: int, output_tokens: int, use_cached: bool = False
) -> Decimal:
    """Calculate the cost of a request from its token counts.

    The table prices are per one million tokens (they match OpenAI's
    published per-1M-token rates, e.g. gpt-4o at 2.50 input / 10.00 output).
    The weighted token total is therefore divided by 1,000,000.

    Args:
        model_id: Key into ``OPENAI_MODELS``.
        input_tokens: Number of prompt tokens.
        output_tokens: Number of completion tokens.
        use_cached: Bill all input tokens at the cached-input rate when the
            model has one; otherwise the regular input rate is used.

    Returns:
        The cost as a ``Decimal``, in the table's currency (presumably USD).

    Raises:
        ValueError: If ``model_id`` is not a known model.
    """
    config = get_model_config(model_id)
    input_price = (
        config.cached_input_price
        if (use_cached and config.cached_input_price is not None)
        else config.input_price
    )
    # Divide by 1M, not 1K: dividing by 1000 overstated costs 1000x.
    return (
        input_price * Decimal(str(input_tokens))
        + config.output_price * Decimal(str(output_tokens))
    ) / Decimal("1000000")
+ ) / Decimal("1000")