airtrain 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
airtrain/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """Airtrain - A platform for building and deploying AI agents with structured skills"""
2
2
 
3
- __version__ = "0.1.12"
3
+ __version__ = "0.1.14"
4
4
 
5
5
  # Core imports
6
6
  from .core.skills import Skill, ProcessingError
@@ -4,9 +4,9 @@ from .agents import (
4
4
  TravelAgentBase,
5
5
  ClothingAgent,
6
6
  HikingAgent,
7
- InternetConnectivityAgent,
8
- FoodRecommendationAgent,
9
- PersonalizedRecommendationAgent,
7
+ InternetAgent,
8
+ FoodAgent,
9
+ PersonalizedAgent,
10
10
  )
11
11
  from .models import (
12
12
  ClothingRecommendation,
@@ -14,8 +14,8 @@ from .models import (
14
14
  InternetAvailability,
15
15
  FoodOption,
16
16
  )
17
- from .agents.verification_agent import UserVerificationAgent
18
- from .models.verification import UserTravelInfo, TravelCompanion, HealthCondition
17
+ from .agentlib.verification_agent import UserVerificationAgent
18
+ from .modellib.verification import UserTravelInfo, TravelCompanion, HealthCondition
19
19
 
20
20
  __all__ = [
21
21
  "TravelAgentBase",
@@ -0,0 +1,11 @@
1
+ """Fireworks AI integration module"""
2
+
3
+ from .credentials import FireworksCredentials
4
+ from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
5
+
6
# Public names re-exported by the Fireworks integration package.
__all__ = [
    "FireworksCredentials", "FireworksChatSkill",
    "FireworksInput", "FireworksOutput",
]
@@ -0,0 +1,18 @@
1
+ from pydantic import SecretStr, BaseModel
2
+ from typing import Optional
3
+ import os
4
+
5
+
6
class FireworksCredentials(BaseModel):
    """Holds the secret API key used to authenticate against Fireworks AI."""

    # Stored as SecretStr so the raw key is not shown in reprs/logs.
    fireworks_api_key: SecretStr

    @classmethod
    def from_env(cls) -> "FireworksCredentials":
        """Build credentials from the ``FIREWORKS_API_KEY`` environment variable.

        Raises:
            ValueError: if the variable is unset or set to an empty string.
        """
        env_key = os.getenv("FIREWORKS_API_KEY")
        if env_key:
            return cls(fireworks_api_key=env_key)
        raise ValueError("FIREWORKS_API_KEY environment variable not set")
@@ -0,0 +1,27 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import Field, BaseModel
3
+
4
+
5
class FireworksMessage(BaseModel):
    """A single chat message in the Fireworks request/response format."""

    # Message text.
    content: str
    # Required; validated against the regex below, so only "system",
    # "user" or "assistant" are accepted.
    role: str = Field(default=..., pattern="^(system|user|assistant)$")
10
+
11
+
12
class FireworksUsage(BaseModel):
    """Token accounting block of a Fireworks API response."""

    # Tokens consumed by the prompt (system + user messages).
    prompt_tokens: int
    # Tokens generated by the model.
    completion_tokens: int
    # Total as reported by the API (not recomputed here).
    total_tokens: int
18
+
19
+
20
class FireworksResponse(BaseModel):
    """The parts of a Fireworks chat-completions response that are validated."""

    # Response identifier returned by the API.
    id: str
    # Raw choice objects, kept as plain dicts; the chat skill reads
    # choices[0]["message"]["content"].
    choices: List[Dict[str, Any]]
    # Creation timestamp as an integer, exactly as the API sends it.
    created: int
    # Name of the model that served the request.
    model: str
    # Nested token-usage statistics.
    usage: FireworksUsage
@@ -0,0 +1,107 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import Field
3
+ import requests
4
+ from loguru import logger
5
+
6
+ from airtrain.core.skills import Skill, ProcessingError
7
+ from airtrain.core.schemas import InputSchema, OutputSchema
8
+ from .credentials import FireworksCredentials
9
+ from .models import FireworksMessage, FireworksResponse
10
+
11
+
12
class FireworksInput(InputSchema):
    """Input for a single-turn Fireworks AI chat completion.

    One system prompt and one user message are sent per request. The
    remaining fields are forwarded as request parameters.
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    # The Fireworks chat-completions API accepts temperatures in [0, 2].
    # The previous upper bound of 1 rejected valid requests. All values that
    # were valid before are still accepted.
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=2
    )
    # When left as None, no explicit limit is put in the request payload.
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum tokens in response"
    )
    # Forwarded verbatim in the request payload.
    context_length_exceeded_behavior: str = Field(
        default="truncate", description="Behavior when context length is exceeded"
    )
33
+
34
+
35
class FireworksOutput(OutputSchema):
    """Result of a Fireworks AI chat completion."""

    # Text content of the first returned choice.
    response: str = Field(..., description="Model's response text")
    # Model name echoed back by the API.
    used_model: str = Field(..., description="Model used for generation")
    # Token counts keyed by name; a fresh empty dict when not supplied.
    usage: Dict[str, int] = Field(description="Usage statistics", default_factory=dict)
41
+
42
+
43
class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
    """Skill that runs a single-turn chat completion against Fireworks AI.

    Each call posts the input's system prompt and user message to the
    ``/chat/completions`` endpoint. It returns the first choice's message
    content, the model name reported by the API, and the token usage.
    """

    input_schema = FireworksInput
    output_schema = FireworksOutput

    def __init__(
        self,
        credentials: Optional[FireworksCredentials] = None,
        timeout: Optional[float] = 300.0,
    ):
        """Initialize the skill.

        Args:
            credentials: API credentials. When omitted (or falsy), they are
                loaded with ``FireworksCredentials.from_env()``.
            timeout: Seconds forwarded to ``requests.post(timeout=...)``, so a
                stalled connection cannot block ``process`` forever. Pass
                ``None`` to wait indefinitely (the previous behaviour).

        Raises:
            ValueError: from ``FireworksCredentials.from_env()`` when no
                credentials are given and FIREWORKS_API_KEY is unset.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.base_url = "https://api.fireworks.ai/inference/v1"
        self.request_timeout = timeout

    def _build_payload(self, input_data: FireworksInput) -> Dict[str, Any]:
        """Translate validated input into the chat-completions JSON body."""
        payload: Dict[str, Any] = {
            "messages": [
                {"role": "system", "content": input_data.system_prompt},
                {"role": "user", "content": input_data.user_input},
            ],
            "model": input_data.model,
            "context_length_exceeded_behavior": input_data.context_length_exceeded_behavior,
            "temperature": input_data.temperature,
            "n": 1,
            "response_format": {"type": "text"},
            "stream": False,
        }
        # A falsy max_tokens (None or 0) is omitted from the request.
        if input_data.max_tokens:
            payload["max_tokens"] = input_data.max_tokens
        return payload

    def _build_headers(self) -> Dict[str, str]:
        """Return HTTP headers carrying the bearer token and JSON content type."""
        api_key = self.credentials.fireworks_api_key.get_secret_value()
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def process(self, input_data: FireworksInput) -> FireworksOutput:
        """Send *input_data* to Fireworks AI and return the model's reply.

        Raises:
            ProcessingError: wraps any failure, including network errors,
                timeouts, non-2xx status, and malformed or empty responses.
                The original exception is chained as ``__cause__``.
        """
        try:
            logger.info(f"Processing request with model {input_data.model}")

            response = requests.post(
                f"{self.base_url}/chat/completions",
                json=self._build_payload(input_data),
                headers=self._build_headers(),
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            response_data = FireworksResponse(**response.json())

            # Report an empty choices list explicitly rather than via IndexError.
            if not response_data.choices:
                raise ValueError("Fireworks AI response contained no choices")

            logger.success("Successfully processed Fireworks AI request")

            return FireworksOutput(
                response=response_data.choices[0]["message"]["content"],
                used_model=response_data.model,
                usage={
                    "prompt_tokens": response_data.usage.prompt_tokens,
                    "completion_tokens": response_data.usage.completion_tokens,
                    "total_tokens": response_data.usage.total_tokens,
                },
            )

        except Exception as e:
            logger.exception(f"Fireworks AI processing failed: {str(e)}")
            # Chain the cause so callers/debuggers can see the real failure.
            raise ProcessingError(f"Fireworks AI processing failed: {str(e)}") from e
@@ -0,0 +1,119 @@
1
+ from typing import Dict, NamedTuple, Optional
2
+ from decimal import Decimal
3
+
4
+
5
class OpenAIModelConfig(NamedTuple):
    """Immutable pricing/identity record for one OpenAI model ID."""

    # Human-readable name for UIs.
    display_name: str
    # Model family that this (possibly dated) model ID belongs to.
    base_model: str
    # Price for uncached input tokens (values appear to be OpenAI's USD
    # per-1M-token list prices).
    input_price: Decimal
    # Discounted cached-input price, or None when the model has none.
    cached_input_price: Optional[Decimal]
    # Price for generated output tokens.
    output_price: Decimal
11
+
12
+
13
# Known OpenAI model IDs and their pricing. Prices are USD per 1M tokens,
# matching OpenAI's published price list. Display names use OpenAI's own
# naming: "GPT-4o" ("o" for "omni"), not "GPT-4 Optimized".
OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
    "gpt-4o": OpenAIModelConfig(
        display_name="GPT-4o",
        base_model="gpt-4o",
        input_price=Decimal("2.50"),
        cached_input_price=Decimal("1.25"),
        output_price=Decimal("10.00"),
    ),
    "gpt-4o-2024-08-06": OpenAIModelConfig(
        display_name="GPT-4o (2024-08-06)",
        base_model="gpt-4o",
        input_price=Decimal("2.50"),
        cached_input_price=Decimal("1.25"),
        output_price=Decimal("10.00"),
    ),
    "gpt-4o-2024-05-13": OpenAIModelConfig(
        display_name="GPT-4o (2024-05-13)",
        base_model="gpt-4o",
        input_price=Decimal("5.00"),
        cached_input_price=None,
        output_price=Decimal("15.00"),
    ),
    "gpt-4o-audio-preview-2024-12-17": OpenAIModelConfig(
        display_name="GPT-4o Audio Preview",
        base_model="gpt-4o-audio-preview",
        input_price=Decimal("2.50"),
        cached_input_price=None,
        output_price=Decimal("10.00"),
    ),
    "gpt-4o-realtime-preview-2024-12-17": OpenAIModelConfig(
        display_name="GPT-4o Realtime Preview",
        base_model="gpt-4o-realtime-preview",
        input_price=Decimal("5.00"),
        cached_input_price=Decimal("2.50"),
        output_price=Decimal("20.00"),
    ),
    "gpt-4o-mini-2024-07-18": OpenAIModelConfig(
        display_name="GPT-4o Mini",
        base_model="gpt-4o-mini",
        input_price=Decimal("0.15"),
        cached_input_price=Decimal("0.075"),
        output_price=Decimal("0.60"),
    ),
    "gpt-4o-mini-audio-preview-2024-12-17": OpenAIModelConfig(
        display_name="GPT-4o Mini Audio Preview",
        base_model="gpt-4o-mini-audio-preview",
        input_price=Decimal("0.15"),
        cached_input_price=None,
        output_price=Decimal("0.60"),
    ),
    "gpt-4o-mini-realtime-preview-2024-12-17": OpenAIModelConfig(
        display_name="GPT-4o Mini Realtime Preview",
        base_model="gpt-4o-mini-realtime-preview",
        input_price=Decimal("0.60"),
        cached_input_price=Decimal("0.30"),
        output_price=Decimal("2.40"),
    ),
    "o1-2024-12-17": OpenAIModelConfig(
        display_name="O1",
        base_model="o1",
        input_price=Decimal("15.00"),
        cached_input_price=Decimal("7.50"),
        output_price=Decimal("60.00"),
    ),
    "o3-mini-2025-01-31": OpenAIModelConfig(
        display_name="O3 Mini",
        base_model="o3-mini",
        input_price=Decimal("1.10"),
        cached_input_price=Decimal("0.55"),
        output_price=Decimal("4.40"),
    ),
    "o1-mini-2024-09-12": OpenAIModelConfig(
        display_name="O1 Mini",
        base_model="o1-mini",
        input_price=Decimal("1.10"),
        cached_input_price=Decimal("0.55"),
        output_price=Decimal("4.40"),
    ),
}
92
+
93
+
94
def get_model_config(model_id: str) -> OpenAIModelConfig:
    """Look up the configuration for an OpenAI model ID.

    Raises:
        ValueError: if *model_id* is not a key of ``OPENAI_MODELS``.
    """
    found = OPENAI_MODELS.get(model_id)
    if found is None:
        raise ValueError(f"Model {model_id} not found in OpenAI models")
    return found
99
+
100
+
101
def get_default_model() -> str:
    """Return the OpenAI model ID to use when the caller does not pick one."""
    default_model_id = "gpt-4o"
    return default_model_id
104
+
105
+
106
def calculate_cost(
    model_id: str, input_tokens: int, output_tokens: int, use_cached: bool = False
) -> Decimal:
    """Return the USD cost of a request for *model_id*.

    The prices in ``OPENAI_MODELS`` are OpenAI's per-1M-token list prices
    (e.g. gpt-4o: $2.50 input / $10.00 output per 1M tokens). Token counts
    must therefore be scaled by 1,000,000. The previous divisor of 1,000
    overstated every cost by a factor of 1000.

    Args:
        model_id: Key into the OpenAI model table.
        input_tokens: Number of prompt tokens billed.
        output_tokens: Number of completion tokens billed.
        use_cached: Bill input tokens at the cached-input rate when the model
            has one; otherwise the regular input rate is used.

    Returns:
        The total cost as a ``Decimal``.

    Raises:
        ValueError: if *model_id* is unknown (raised by ``get_model_config``).
    """
    tokens_per_price_unit = Decimal("1000000")
    config = get_model_config(model_id)
    input_price = (
        config.cached_input_price
        if (use_cached and config.cached_input_price is not None)
        else config.input_price
    )
    return (
        input_price * Decimal(str(input_tokens))
        + config.output_price * Decimal(str(output_tokens))
    ) / tokens_per_price_unit
@@ -0,0 +1,34 @@
1
+ from typing import Dict, NamedTuple
2
+
3
+
4
class AudioModelConfig(NamedTuple):
    """Immutable descriptor of one Together AI audio model."""

    # Publisher of the model; used for case-insensitive grouping.
    organization: str
    # Human-readable name.
    display_name: str
8
+
9
# Together AI audio models known to this module.
# Columns: model id -> (organization, display_name)
TOGETHER_AUDIO_MODELS: Dict[str, AudioModelConfig] = {
    "Cartesia/Sonic": AudioModelConfig("Cartesia", "Cartesia/Sonic"),
}
14
+
15
+
16
def get_audio_model_config(model_id: str) -> AudioModelConfig:
    """Look up the configuration for a Together AI audio model.

    Raises:
        ValueError: if *model_id* is not a key of ``TOGETHER_AUDIO_MODELS``.
    """
    found = TOGETHER_AUDIO_MODELS.get(model_id)
    if found is None:
        raise ValueError(f"Model {model_id} not found in Together AI audio models")
    return found
21
+
22
+
23
def list_audio_models_by_organization(organization: str) -> Dict[str, AudioModelConfig]:
    """Return the audio models published by *organization*.

    The comparison lower-cases both sides, so it is case-insensitive. The
    result is a new dict in table order and is empty when nothing matches.
    """
    matches: Dict[str, AudioModelConfig] = {}
    for key, cfg in TOGETHER_AUDIO_MODELS.items():
        if cfg.organization.lower() == organization.lower():
            matches[key] = cfg
    return matches
30
+
31
+
32
def get_default_audio_model() -> str:
    """Return the audio model ID used when none is specified."""
    default_model_id = "Cartesia/Sonic"
    return default_model_id
@@ -6,14 +6,14 @@ import together
6
6
  class TogetherAICredentials(BaseCredentials):
7
7
  """Together AI credentials"""
8
8
 
9
- api_key: SecretStr = Field(..., description="Together AI API key")
9
+ together_api_key: SecretStr = Field(..., description="Together AI API key")
10
10
 
11
- _required_credentials = {"api_key"}
11
+ _required_credentials = {"together_api_key"}
12
12
 
13
13
  async def validate_credentials(self) -> bool:
14
14
  """Validate Together AI credentials"""
15
15
  try:
16
- together.api_key = self.api_key.get_secret_value()
16
+ together.api_key = self.together_api_key.get_secret_value()
17
17
  await together.Models.list()
18
18
  return True
19
19
  except Exception as e:
@@ -0,0 +1,92 @@
1
+ from typing import Dict, NamedTuple
2
+
3
+
4
class EmbeddingModelConfig(NamedTuple):
    """Immutable descriptor of one Together AI embedding model."""

    # Publisher of the model; used for case-insensitive grouping.
    organization: str
    # Human-readable name.
    display_name: str
    # Parameter count as a label string, e.g. "80M".
    model_size: str
    # Dimension of the produced embedding vectors.
    embedding_dimension: int
    # Maximum input length (presumably in tokens; values track the 2k/8k/32k
    # model names -- TODO confirm unit).
    context_window: int
10
+
11
+
12
# Together AI embedding models known to this module.
# Columns: model id -> (organization, display_name, model_size,
#                        embedding_dimension, context_window)
TOGETHER_EMBEDDING_MODELS: Dict[str, EmbeddingModelConfig] = {
    "togethercomputer/m2-bert-80M-2k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-2K-Retrieval", "80M", 768, 2048
    ),
    "togethercomputer/m2-bert-80M-8k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-8K-Retrieval", "80M", 768, 8192
    ),
    "togethercomputer/m2-bert-80M-32k-retrieval": EmbeddingModelConfig(
        "Together", "M2-BERT-80M-32K-Retrieval", "80M", 768, 32768
    ),
    "WhereIsAI/UAE-Large-V1": EmbeddingModelConfig(
        "WhereIsAI", "UAE-Large-v1", "326M", 1024, 512
    ),
    "BAAI/bge-large-en-v1.5": EmbeddingModelConfig(
        "BAAI", "BGE-Large-EN-v1.5", "326M", 1024, 512
    ),
    "BAAI/bge-base-en-v1.5": EmbeddingModelConfig(
        "BAAI", "BGE-Base-EN-v1.5", "102M", 768, 512
    ),
    "sentence-transformers/msmarco-bert-base-dot-v5": EmbeddingModelConfig(
        "sentence-transformers", "Sentence-BERT", "110M", 768, 512
    ),
    "bert-base-uncased": EmbeddingModelConfig(
        "Hugging Face", "BERT", "110M", 768, 512
    ),
}
70
+
71
+
72
def get_embedding_model_config(model_id: str) -> EmbeddingModelConfig:
    """Look up the configuration for a Together AI embedding model.

    Raises:
        ValueError: if *model_id* is not a key of ``TOGETHER_EMBEDDING_MODELS``.
    """
    found = TOGETHER_EMBEDDING_MODELS.get(model_id)
    if found is None:
        raise ValueError(f"Model {model_id} not found in Together AI embedding models")
    return found
77
+
78
+
79
def list_embedding_models_by_organization(
    organization: str,
) -> Dict[str, EmbeddingModelConfig]:
    """Return the embedding models published by *organization*.

    The comparison lower-cases both sides, so it is case-insensitive. The
    result is a new dict in table order and is empty when nothing matches.
    """
    matches: Dict[str, EmbeddingModelConfig] = {}
    for key, cfg in TOGETHER_EMBEDDING_MODELS.items():
        if cfg.organization.lower() == organization.lower():
            matches[key] = cfg
    return matches
88
+
89
+
90
def get_default_embedding_model() -> str:
    """Return the embedding model ID used when none is specified."""
    default_model_id = "togethercomputer/m2-bert-80M-32k-retrieval"
    return default_model_id
@@ -0,0 +1,69 @@
1
+ from typing import Dict, NamedTuple, Optional
2
+
3
+
4
class ImageModelConfig(NamedTuple):
    """Immutable descriptor of one Together AI image-generation model."""

    # Publisher of the model; used for case-insensitive grouping.
    organization: str
    # Human-readable name.
    display_name: str
    # Default number of generation steps, or None when no default is recorded.
    default_steps: Optional[int]
8
+
9
+
10
# Together AI image models known to this module.
# Columns: model id -> (organization, display_name, default_steps)
TOGETHER_IMAGE_MODELS: Dict[str, ImageModelConfig] = {
    "black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [schnell] (free)", None
    ),
    "black-forest-labs/FLUX.1-schnell": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [schnell] (Turbo)", 4
    ),
    "black-forest-labs/FLUX.1-dev": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Dev", 28
    ),
    "black-forest-labs/FLUX.1-canny": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Canny", 28
    ),
    "black-forest-labs/FLUX.1-depth": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Depth", 28
    ),
    "black-forest-labs/FLUX.1-redux": ImageModelConfig(
        "Black Forest Labs", "Flux.1 Redux", 28
    ),
    "black-forest-labs/FLUX.1.1-pro": ImageModelConfig(
        "Black Forest Labs", "Flux1.1 [pro]", None
    ),
    "black-forest-labs/FLUX.1-pro": ImageModelConfig(
        "Black Forest Labs", "Flux.1 [pro]", None
    ),
    "stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
        "Stability AI", "Stable Diffusion XL 1.0", None
    ),
}
49
+
50
+
51
def get_image_model_config(model_id: str) -> ImageModelConfig:
    """Look up the configuration for a Together AI image model.

    Raises:
        ValueError: if *model_id* is not a key of ``TOGETHER_IMAGE_MODELS``.
    """
    found = TOGETHER_IMAGE_MODELS.get(model_id)
    if found is None:
        raise ValueError(f"Model {model_id} not found in Together AI image models")
    return found
56
+
57
+
58
def list_image_models_by_organization(organization: str) -> Dict[str, ImageModelConfig]:
    """Return the image models published by *organization*.

    The comparison lower-cases both sides, so it is case-insensitive. The
    result is a new dict in table order and is empty when nothing matches.
    """
    matches: Dict[str, ImageModelConfig] = {}
    for key, cfg in TOGETHER_IMAGE_MODELS.items():
        if cfg.organization.lower() == organization.lower():
            matches[key] = cfg
    return matches
65
+
66
+
67
def get_default_image_model() -> str:
    """Return the image model ID used when none is specified."""
    default_model_id = "black-forest-labs/FLUX.1-schnell-Free"
    return default_model_id