airtrain 0.1.12__tar.gz → 0.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. {airtrain-0.1.12/airtrain.egg-info → airtrain-0.1.14}/PKG-INFO +1 -1
  2. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/__init__.py +1 -1
  3. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/contrib/travel/__init__.py +5 -5
  4. airtrain-0.1.14/airtrain/contrib/travel/agentlib/verification_agent.py +96 -0
  5. airtrain-0.1.14/airtrain/contrib/travel/modellib/verification.py +32 -0
  6. airtrain-0.1.14/airtrain/integrations/fireworks/__init__.py +11 -0
  7. airtrain-0.1.14/airtrain/integrations/fireworks/credentials.py +18 -0
  8. airtrain-0.1.14/airtrain/integrations/fireworks/models.py +27 -0
  9. airtrain-0.1.14/airtrain/integrations/fireworks/skills.py +107 -0
  10. airtrain-0.1.14/airtrain/integrations/openai/models_config.py +119 -0
  11. airtrain-0.1.14/airtrain/integrations/together/audio_models_config.py +34 -0
  12. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/together/credentials.py +3 -3
  13. airtrain-0.1.14/airtrain/integrations/together/embedding_models_config.py +92 -0
  14. airtrain-0.1.14/airtrain/integrations/together/image_models_config.py +69 -0
  15. airtrain-0.1.14/airtrain/integrations/together/image_skill.py +171 -0
  16. airtrain-0.1.14/airtrain/integrations/together/models.py +56 -0
  17. airtrain-0.1.14/airtrain/integrations/together/models_config.py +277 -0
  18. airtrain-0.1.14/airtrain/integrations/together/rerank_models_config.py +43 -0
  19. airtrain-0.1.14/airtrain/integrations/together/rerank_skill.py +49 -0
  20. airtrain-0.1.14/airtrain/integrations/together/schemas.py +33 -0
  21. airtrain-0.1.14/airtrain/integrations/together/skills.py +129 -0
  22. airtrain-0.1.14/airtrain/integrations/together/vision_models_config.py +49 -0
  23. {airtrain-0.1.12 → airtrain-0.1.14/airtrain.egg-info}/PKG-INFO +1 -1
  24. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain.egg-info/SOURCES.txt +23 -0
  25. airtrain-0.1.14/examples/creating-skills/fireworks_skills_usage.py +69 -0
  26. airtrain-0.1.14/examples/creating-skills/together_rerank_skills.py +58 -0
  27. airtrain-0.1.14/examples/creating-skills/together_rerank_skills_async.py +1 -0
  28. airtrain-0.1.14/examples/together/image_generation.py +64 -0
  29. airtrain-0.1.14/examples/together/image_generation_example.py +81 -0
  30. airtrain-0.1.14/examples/travel/verification_agent_usage.py +104 -0
  31. airtrain-0.1.12/airtrain/integrations/together/skills.py +0 -43
  32. {airtrain-0.1.12 → airtrain-0.1.14}/.flake8 +0 -0
  33. {airtrain-0.1.12 → airtrain-0.1.14}/.github/workflows/publish.yml +0 -0
  34. {airtrain-0.1.12 → airtrain-0.1.14}/.gitignore +0 -0
  35. {airtrain-0.1.12 → airtrain-0.1.14}/.mypy.ini +0 -0
  36. {airtrain-0.1.12 → airtrain-0.1.14}/.pre-commit-config.yaml +0 -0
  37. {airtrain-0.1.12 → airtrain-0.1.14}/.vscode/extensions.json +0 -0
  38. {airtrain-0.1.12 → airtrain-0.1.14}/.vscode/launch.json +0 -0
  39. {airtrain-0.1.12 → airtrain-0.1.14}/.vscode/settings.json +0 -0
  40. {airtrain-0.1.12 → airtrain-0.1.14}/EXPERIMENTS/integrations_examples/anthropic_with_image.py +0 -0
  41. {airtrain-0.1.12 → airtrain-0.1.14}/EXPERIMENTS/schema_exps/pydantic_schemas.py +0 -0
  42. {airtrain-0.1.12 → airtrain-0.1.14}/README.md +0 -0
  43. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/agents/travel/agents.py +0 -0
  44. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/agents/travel/models.py +0 -0
  45. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/contrib/__init__.py +0 -0
  46. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/contrib/travel/agents.py +0 -0
  47. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/contrib/travel/models.py +0 -0
  48. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/core/__init__.py +0 -0
  49. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/core/__pycache__/credentials.cpython-310.pyc +0 -0
  50. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/core/__pycache__/schemas.cpython-310.pyc +0 -0
  51. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/core/__pycache__/skills.cpython-310.pyc +0 -0
  52. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/core/credentials.py +0 -0
  53. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/core/schemas.py +0 -0
  54. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/core/skills.py +0 -0
  55. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/__init__.py +0 -0
  56. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/anthropic/__init__.py +0 -0
  57. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/anthropic/credentials.py +0 -0
  58. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/anthropic/skills.py +0 -0
  59. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/aws/__init__.py +0 -0
  60. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/aws/credentials.py +0 -0
  61. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/aws/skills.py +0 -0
  62. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/cerebras/__init__.py +0 -0
  63. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/cerebras/credentials.py +0 -0
  64. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/cerebras/skills.py +0 -0
  65. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/google/__init__.py +0 -0
  66. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/google/credentials.py +0 -0
  67. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/google/skills.py +0 -0
  68. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/groq/__init__.py +0 -0
  69. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/groq/credentials.py +0 -0
  70. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/groq/skills.py +0 -0
  71. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/ollama/__init__.py +0 -0
  72. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/ollama/credentials.py +0 -0
  73. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/ollama/skills.py +0 -0
  74. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/openai/__init__.py +0 -0
  75. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/openai/chinese_assistant.py +0 -0
  76. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/openai/credentials.py +0 -0
  77. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/openai/skills.py +0 -0
  78. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/sambanova/__init__.py +0 -0
  79. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/sambanova/credentials.py +0 -0
  80. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/sambanova/skills.py +0 -0
  81. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/together/__init__.py +0 -0
  82. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain.egg-info/dependency_links.txt +0 -0
  83. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain.egg-info/requires.txt +0 -0
  84. {airtrain-0.1.12 → airtrain-0.1.14}/airtrain.egg-info/top_level.txt +0 -0
  85. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/anthropic_skills_usage.py +0 -0
  86. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/chinese_anthropic_assistant.py +0 -0
  87. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/chinese_anthropic_usage.py +0 -0
  88. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/chinese_assistant_usage.py +0 -0
  89. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/icon128.png +0 -0
  90. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/icon16.png +0 -0
  91. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/image1.jpg +0 -0
  92. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/image2.jpg +0 -0
  93. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/openai_skills.py +0 -0
  94. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/openai_skills_usage.py +0 -0
  95. {airtrain-0.1.12 → airtrain-0.1.14}/examples/creating-skills/openai_structured_skills.py +0 -0
  96. {airtrain-0.1.12 → airtrain-0.1.14}/examples/credentials_usage.py +0 -0
  97. {airtrain-0.1.12 → airtrain-0.1.14}/examples/images/quantum-circuit.png +0 -0
  98. {airtrain-0.1.12 → airtrain-0.1.14}/examples/schema_usage.py +0 -0
  99. {airtrain-0.1.12 → airtrain-0.1.14}/examples/skill_usage.py +0 -0
  100. {airtrain-0.1.12 → airtrain-0.1.14}/pyproject.toml +0 -0
  101. {airtrain-0.1.12 → airtrain-0.1.14}/requirements.txt +0 -0
  102. {airtrain-0.1.12 → airtrain-0.1.14}/scripts/build.sh +0 -0
  103. {airtrain-0.1.12 → airtrain-0.1.14}/scripts/bump_version.py +0 -0
  104. {airtrain-0.1.12 → airtrain-0.1.14}/scripts/publish.sh +0 -0
  105. {airtrain-0.1.12 → airtrain-0.1.14}/scripts/release.py +0 -0
  106. {airtrain-0.1.12 → airtrain-0.1.14}/services/firebase_service.py +0 -0
  107. {airtrain-0.1.12 → airtrain-0.1.14}/services/openai_service.py +0 -0
  108. {airtrain-0.1.12 → airtrain-0.1.14}/setup.cfg +0 -0
  109. {airtrain-0.1.12 → airtrain-0.1.14}/setup.py +0 -0
{airtrain-0.1.12/airtrain.egg-info → airtrain-0.1.14}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: airtrain
- Version: 0.1.12
+ Version: 0.1.14
  Summary: A platform for building and deploying AI agents with structured skills
  Home-page: https://github.com/rosaboyle/airtrain.dev
  Author: Dheeraj Pai
{airtrain-0.1.12 → airtrain-0.1.14}/airtrain/__init__.py
@@ -1,6 +1,6 @@
  """Airtrain - A platform for building and deploying AI agents with structured skills"""

- __version__ = "0.1.12"
+ __version__ = "0.1.14"

  # Core imports
  from .core.skills import Skill, ProcessingError
{airtrain-0.1.12 → airtrain-0.1.14}/airtrain/contrib/travel/__init__.py
@@ -4,9 +4,9 @@ from .agents import (
      TravelAgentBase,
      ClothingAgent,
      HikingAgent,
-     InternetConnectivityAgent,
-     FoodRecommendationAgent,
-     PersonalizedRecommendationAgent,
+     InternetAgent,
+     FoodAgent,
+     PersonalizedAgent,
  )
  from .models import (
      ClothingRecommendation,
@@ -14,8 +14,8 @@ from .models import (
      InternetAvailability,
      FoodOption,
  )
- from .agents.verification_agent import UserVerificationAgent
- from .models.verification import UserTravelInfo, TravelCompanion, HealthCondition
+ from .agentlib.verification_agent import UserVerificationAgent
+ from .modellib.verification import UserTravelInfo, TravelCompanion, HealthCondition

  __all__ = [
      "TravelAgentBase",
airtrain-0.1.14/airtrain/contrib/travel/agentlib/verification_agent.py
@@ -0,0 +1,96 @@
+ from typing import Optional, List, Tuple
+ from datetime import date
+ from pydantic import Field
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.integrations.openai.skills import OpenAIParserSkill, OpenAIParserInput
+ from ..modellib.verification import UserTravelInfo
+ from pydantic import BaseModel
+
+
+ class VerificationInput(OpenAIParserInput):
+     conversation_history: List[str] = Field(
+         default_factory=list, description="History of conversation with user"
+     )
+     ask_followup: bool = Field(
+         default=True, description="Whether to ask follow-up questions"
+     )
+     followup_question: Optional[str] = Field(
+         default=None, description="Specific follow-up question to ask"
+     )
+
+
+ class VerificationOutput(BaseModel):
+     travel_info: UserTravelInfo
+     needs_followup: bool
+     next_question: Optional[str] = None
+     missing_fields: List[str] = Field(default_factory=list)
+
+
+ class UserVerificationAgent(OpenAIParserSkill):
+     """Agent for verifying and collecting user travel information"""
+
+     input_schema = VerificationInput
+     output_schema = VerificationOutput
+
+     def __init__(self, credentials=None):
+         super().__init__(credentials)
+         self.model = "gpt-4o"
+         self.temperature = 0.2
+
+     def process(self, input_data: VerificationInput) -> VerificationOutput:
+         system_prompt = """
+         You are a travel information verification agent. Your role is to:
+         1. Extract travel information from the conversation
+         2. Identify missing required information
+         3. Generate appropriate follow-up questions
+         4. Ensure all necessary details are collected for safe travel planning
+
+         Required fields:
+         - Origin location
+         - Destination
+         - Travel dates
+         - Companions (if any)
+         - Preferred outdoor activities
+         - Health conditions (if any)
+
+         Provide structured output and indicate if follow-up questions are needed.
+         """
+
+         # Combine conversation history into a single string
+         conversation = "\n".join(input_data.conversation_history)
+
+         # Add follow-up question if present
+         if input_data.followup_question:
+             conversation += f"\nFollow-up question: {input_data.followup_question}"
+
+         input_data = OpenAIParserInput(
+             user_input=conversation,
+             system_prompt=system_prompt,
+             response_model=VerificationOutput,
+             model=self.model,
+             temperature=self.temperature,
+         )
+
+         try:
+             result = self.process(input_data)
+             return result.parsed_response
+         except Exception as e:
+             raise ProcessingError(f"Failed to process verification: {str(e)}")
+
+     def get_next_question(self, missing_fields: List[str]) -> str:
+         """Generate appropriate follow-up question based on missing fields"""
+         questions = {
+             "origin": "What is your starting location?",
+             "destination": "Where would you like to travel to?",
+             "start_date": "When do you plan to start your journey?",
+             "end_date": "When do you plan to return?",
+             "companions": "Will anyone be traveling with you (children, pets, other adults)?",
+             "outdoor_activities": "What types of outdoor activities are you interested in?",
+             "health_conditions": "Do you or your companions have any health conditions we should be aware of?",
+         }
+
+         for field in missing_fields:
+             if field in questions:
+                 return questions[field]
+
+         return "Is there anything else you'd like to share about your travel plans?"
airtrain-0.1.14/airtrain/contrib/travel/modellib/verification.py
@@ -0,0 +1,32 @@
+ from typing import Optional, List, Dict
+ from pydantic import BaseModel, Field
+ from datetime import date
+
+
+ class TravelCompanion(BaseModel):
+     type: str = Field(..., description="Type of companion (kid/pet/adult)")
+     count: int = Field(..., description="Number of companions of this type")
+     details: Optional[Dict[str, str]] = Field(
+         default=None, description="Additional details like ages, special needs"
+     )
+
+
+ class HealthCondition(BaseModel):
+     condition: str = Field(..., description="Name of health condition")
+     severity: str = Field(..., description="Severity level (mild/moderate/severe)")
+     requirements: List[str] = Field(
+         ..., description="Special requirements or precautions"
+     )
+
+
+ class UserTravelInfo(BaseModel):
+     origin: str = Field(..., description="Starting location")
+     destination: str = Field(..., description="Travel destination")
+     start_date: date = Field(..., description="Travel start date")
+     end_date: date = Field(..., description="Travel end date")
+     companions: List[TravelCompanion] = Field(default_factory=list)
+     outdoor_activities: List[str] = Field(default_factory=list)
+     health_conditions: List[HealthCondition] = Field(default_factory=list)
+     complete: bool = Field(
+         default=False, description="Whether all required info is collected"
+     )
airtrain-0.1.14/airtrain/integrations/fireworks/__init__.py
@@ -0,0 +1,11 @@
+ """Fireworks AI integration module"""
+
+ from .credentials import FireworksCredentials
+ from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
+
+ __all__ = [
+     "FireworksCredentials",
+     "FireworksChatSkill",
+     "FireworksInput",
+     "FireworksOutput",
+ ]
airtrain-0.1.14/airtrain/integrations/fireworks/credentials.py
@@ -0,0 +1,18 @@
+ from pydantic import SecretStr, BaseModel
+ from typing import Optional
+ import os
+
+
+ class FireworksCredentials(BaseModel):
+     """Credentials for Fireworks AI API"""
+
+     fireworks_api_key: SecretStr
+
+     @classmethod
+     def from_env(cls) -> "FireworksCredentials":
+         """Create credentials from environment variables"""
+         api_key = os.getenv("FIREWORKS_API_KEY")
+         if not api_key:
+             raise ValueError("FIREWORKS_API_KEY environment variable not set")
+
+         return cls(fireworks_api_key=api_key)
airtrain-0.1.14/airtrain/integrations/fireworks/models.py
@@ -0,0 +1,27 @@
+ from typing import List, Optional, Dict, Any
+ from pydantic import Field, BaseModel
+
+
+ class FireworksMessage(BaseModel):
+     """Schema for Fireworks chat message"""
+
+     content: str
+     role: str = Field(..., pattern="^(system|user|assistant)$")
+
+
+ class FireworksUsage(BaseModel):
+     """Schema for Fireworks API usage statistics"""
+
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+
+
+ class FireworksResponse(BaseModel):
+     """Schema for Fireworks API response"""
+
+     id: str
+     choices: List[Dict[str, Any]]
+     created: int
+     model: str
+     usage: FireworksUsage
airtrain-0.1.14/airtrain/integrations/fireworks/skills.py
@@ -0,0 +1,107 @@
+ from typing import List, Optional, Dict, Any
+ from pydantic import Field
+ import requests
+ from loguru import logger
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import FireworksCredentials
+ from .models import FireworksMessage, FireworksResponse
+
+
+ class FireworksInput(InputSchema):
+     """Schema for Fireworks AI chat input"""
+
+     user_input: str = Field(..., description="User's input text")
+     system_prompt: str = Field(
+         default="You are a helpful assistant.",
+         description="System prompt to guide the model's behavior",
+     )
+     model: str = Field(
+         default="accounts/fireworks/models/deepseek-r1",
+         description="Fireworks AI model to use",
+     )
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation", ge=0, le=1
+     )
+     max_tokens: Optional[int] = Field(
+         default=None, description="Maximum tokens in response"
+     )
+     context_length_exceeded_behavior: str = Field(
+         default="truncate", description="Behavior when context length is exceeded"
+     )
+
+
+ class FireworksOutput(OutputSchema):
+     """Schema for Fireworks AI output"""
+
+     response: str = Field(..., description="Model's response text")
+     used_model: str = Field(..., description="Model used for generation")
+     usage: Dict[str, int] = Field(default_factory=dict, description="Usage statistics")
+
+
+ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
+     """Skill for interacting with Fireworks AI models"""
+
+     input_schema = FireworksInput
+     output_schema = FireworksOutput
+
+     def __init__(self, credentials: Optional[FireworksCredentials] = None):
+         """Initialize the skill with optional credentials"""
+         super().__init__()
+         self.credentials = credentials or FireworksCredentials.from_env()
+         self.base_url = "https://api.fireworks.ai/inference/v1"
+
+     def process(self, input_data: FireworksInput) -> FireworksOutput:
+         """Process the input using Fireworks AI API"""
+         try:
+             logger.info(f"Processing request with model {input_data.model}")
+
+             # Prepare messages
+             messages = [
+                 {"role": "system", "content": input_data.system_prompt},
+                 {"role": "user", "content": input_data.user_input},
+             ]
+
+             # Prepare request payload
+             payload = {
+                 "messages": messages,
+                 "model": input_data.model,
+                 "context_length_exceeded_behavior": input_data.context_length_exceeded_behavior,
+                 "temperature": input_data.temperature,
+                 "n": 1,
+                 "response_format": {"type": "text"},
+                 "stream": False,
+             }
+
+             if input_data.max_tokens:
+                 payload["max_tokens"] = input_data.max_tokens
+
+             # Make API request
+             response = requests.post(
+                 f"{self.base_url}/chat/completions",
+                 json=payload,
+                 headers={
+                     "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
+                     "Content-Type": "application/json",
+                 },
+             )
+
+             response.raise_for_status()
+             response_data = FireworksResponse(**response.json())
+
+             logger.success("Successfully processed Fireworks AI request")
+
+             return FireworksOutput(
+                 response=response_data.choices[0]["message"]["content"],
+                 used_model=response_data.model,
+                 usage={
+                     "prompt_tokens": response_data.usage.prompt_tokens,
+                     "completion_tokens": response_data.usage.completion_tokens,
+                     "total_tokens": response_data.usage.total_tokens,
+                 },
+             )
+
+         except Exception as e:
+             logger.exception(f"Fireworks AI processing failed: {str(e)}")
+             raise ProcessingError(f"Fireworks AI processing failed: {str(e)}")
airtrain-0.1.14/airtrain/integrations/openai/models_config.py
@@ -0,0 +1,119 @@
+ from typing import Dict, NamedTuple, Optional
+ from decimal import Decimal
+
+
+ class OpenAIModelConfig(NamedTuple):
+     display_name: str
+     base_model: str
+     input_price: Decimal
+     cached_input_price: Optional[Decimal]
+     output_price: Decimal
+
+
+ OPENAI_MODELS: Dict[str, OpenAIModelConfig] = {
+     "gpt-4o": OpenAIModelConfig(
+         display_name="GPT-4 Optimized",
+         base_model="gpt-4o",
+         input_price=Decimal("2.50"),
+         cached_input_price=Decimal("1.25"),
+         output_price=Decimal("10.00"),
+     ),
+     "gpt-4o-2024-08-06": OpenAIModelConfig(
+         display_name="GPT-4 Optimized (2024-08-06)",
+         base_model="gpt-4o",
+         input_price=Decimal("2.50"),
+         cached_input_price=Decimal("1.25"),
+         output_price=Decimal("10.00"),
+     ),
+     "gpt-4o-2024-05-13": OpenAIModelConfig(
+         display_name="GPT-4 Optimized (2024-05-13)",
+         base_model="gpt-4o",
+         input_price=Decimal("5.00"),
+         cached_input_price=None,
+         output_price=Decimal("15.00"),
+     ),
+     "gpt-4o-audio-preview-2024-12-17": OpenAIModelConfig(
+         display_name="GPT-4 Optimized Audio Preview",
+         base_model="gpt-4o-audio-preview",
+         input_price=Decimal("2.50"),
+         cached_input_price=None,
+         output_price=Decimal("10.00"),
+     ),
+     "gpt-4o-realtime-preview-2024-12-17": OpenAIModelConfig(
+         display_name="GPT-4 Optimized Realtime Preview",
+         base_model="gpt-4o-realtime-preview",
+         input_price=Decimal("5.00"),
+         cached_input_price=Decimal("2.50"),
+         output_price=Decimal("20.00"),
+     ),
+     "gpt-4o-mini-2024-07-18": OpenAIModelConfig(
+         display_name="GPT-4 Optimized Mini",
+         base_model="gpt-4o-mini",
+         input_price=Decimal("0.15"),
+         cached_input_price=Decimal("0.075"),
+         output_price=Decimal("0.60"),
+     ),
+     "gpt-4o-mini-audio-preview-2024-12-17": OpenAIModelConfig(
+         display_name="GPT-4 Optimized Mini Audio Preview",
+         base_model="gpt-4o-mini-audio-preview",
+         input_price=Decimal("0.15"),
+         cached_input_price=None,
+         output_price=Decimal("0.60"),
+     ),
+     "gpt-4o-mini-realtime-preview-2024-12-17": OpenAIModelConfig(
+         display_name="GPT-4 Optimized Mini Realtime Preview",
+         base_model="gpt-4o-mini-realtime-preview",
+         input_price=Decimal("0.60"),
+         cached_input_price=Decimal("0.30"),
+         output_price=Decimal("2.40"),
+     ),
+     "o1-2024-12-17": OpenAIModelConfig(
+         display_name="O1",
+         base_model="o1",
+         input_price=Decimal("15.00"),
+         cached_input_price=Decimal("7.50"),
+         output_price=Decimal("60.00"),
+     ),
+     "o3-mini-2025-01-31": OpenAIModelConfig(
+         display_name="O3 Mini",
+         base_model="o3-mini",
+         input_price=Decimal("1.10"),
+         cached_input_price=Decimal("0.55"),
+         output_price=Decimal("4.40"),
+     ),
+     "o1-mini-2024-09-12": OpenAIModelConfig(
+         display_name="O1 Mini",
+         base_model="o1-mini",
+         input_price=Decimal("1.10"),
+         cached_input_price=Decimal("0.55"),
+         output_price=Decimal("4.40"),
+     ),
+ }
+
+
+ def get_model_config(model_id: str) -> OpenAIModelConfig:
+     """Get model configuration by model ID"""
+     if model_id not in OPENAI_MODELS:
+         raise ValueError(f"Model {model_id} not found in OpenAI models")
+     return OPENAI_MODELS[model_id]
+
+
+ def get_default_model() -> str:
+     """Get the default model ID"""
+     return "gpt-4o"
+
+
+ def calculate_cost(
+     model_id: str, input_tokens: int, output_tokens: int, use_cached: bool = False
+ ) -> Decimal:
+     """Calculate cost for token usage"""
+     config = get_model_config(model_id)
+     input_price = (
+         config.cached_input_price
+         if (use_cached and config.cached_input_price is not None)
+         else config.input_price
+     )
+     return (
+         input_price * Decimal(str(input_tokens))
+         + config.output_price * Decimal(str(output_tokens))
+     ) / Decimal("1000")
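A quick worked example of the new calculate_cost helper (not from the diff; the arithmetic follows directly from the code above, which divides the listed prices by 1,000 tokens):

# Hypothetical check of calculate_cost, using the gpt-4o entry above.
from decimal import Decimal
from airtrain.integrations.openai.models_config import calculate_cost

# (2.50 * 1000 + 10.00 * 500) / 1000 = 7.50
assert calculate_cost("gpt-4o", input_tokens=1000, output_tokens=500) == Decimal("7.5")

# With use_cached=True the cached input price (1.25) is used instead:
# (1.25 * 1000 + 10.00 * 500) / 1000 = 6.25
assert calculate_cost("gpt-4o", 1000, 500, use_cached=True) == Decimal("6.25")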
airtrain-0.1.14/airtrain/integrations/together/audio_models_config.py
@@ -0,0 +1,34 @@
+ from typing import Dict, NamedTuple
+
+
+ class AudioModelConfig(NamedTuple):
+     organization: str
+     display_name: str
+
+
+ TOGETHER_AUDIO_MODELS: Dict[str, AudioModelConfig] = {
+     "Cartesia/Sonic": AudioModelConfig(
+         organization="Cartesia", display_name="Cartesia/Sonic"
+     )
+ }
+
+
+ def get_audio_model_config(model_id: str) -> AudioModelConfig:
+     """Get audio model configuration by model ID"""
+     if model_id not in TOGETHER_AUDIO_MODELS:
+         raise ValueError(f"Model {model_id} not found in Together AI audio models")
+     return TOGETHER_AUDIO_MODELS[model_id]
+
+
+ def list_audio_models_by_organization(organization: str) -> Dict[str, AudioModelConfig]:
+     """Get all audio models for a specific organization"""
+     return {
+         model_id: config
+         for model_id, config in TOGETHER_AUDIO_MODELS.items()
+         if config.organization.lower() == organization.lower()
+     }
+
+
+ def get_default_audio_model() -> str:
+     """Get the default audio model ID"""
+     return "Cartesia/Sonic"
{airtrain-0.1.12 → airtrain-0.1.14}/airtrain/integrations/together/credentials.py
@@ -6,14 +6,14 @@ import together
  class TogetherAICredentials(BaseCredentials):
      """Together AI credentials"""

-     api_key: SecretStr = Field(..., description="Together AI API key")
+     together_api_key: SecretStr = Field(..., description="Together AI API key")

-     _required_credentials = {"api_key"}
+     _required_credentials = {"together_api_key"}

      async def validate_credentials(self) -> bool:
          """Validate Together AI credentials"""
          try:
-             together.api_key = self.api_key.get_secret_value()
+             together.api_key = self.together_api_key.get_secret_value()
              await together.Models.list()
              return True
          except Exception as e:
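The rename from api_key to together_api_key is a breaking change for code that constructed TogetherAICredentials directly; a before/after sketch (field names from the diff, everything else illustrative):

# Illustration only: the credential field name changed in 0.1.14.
from pydantic import SecretStr
from airtrain.integrations.together.credentials import TogetherAICredentials

# 0.1.12: TogetherAICredentials(api_key=SecretStr("..."))
creds = TogetherAICredentials(together_api_key=SecretStr("together-..."))  # 0.1.14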
airtrain-0.1.14/airtrain/integrations/together/embedding_models_config.py
@@ -0,0 +1,92 @@
+ from typing import Dict, NamedTuple
+
+
+ class EmbeddingModelConfig(NamedTuple):
+     organization: str
+     display_name: str
+     model_size: str
+     embedding_dimension: int
+     context_window: int
+
+
+ TOGETHER_EMBEDDING_MODELS: Dict[str, EmbeddingModelConfig] = {
+     "togethercomputer/m2-bert-80M-2k-retrieval": EmbeddingModelConfig(
+         organization="Together",
+         display_name="M2-BERT-80M-2K-Retrieval",
+         model_size="80M",
+         embedding_dimension=768,
+         context_window=2048,
+     ),
+     "togethercomputer/m2-bert-80M-8k-retrieval": EmbeddingModelConfig(
+         organization="Together",
+         display_name="M2-BERT-80M-8K-Retrieval",
+         model_size="80M",
+         embedding_dimension=768,
+         context_window=8192,
+     ),
+     "togethercomputer/m2-bert-80M-32k-retrieval": EmbeddingModelConfig(
+         organization="Together",
+         display_name="M2-BERT-80M-32K-Retrieval",
+         model_size="80M",
+         embedding_dimension=768,
+         context_window=32768,
+     ),
+     "WhereIsAI/UAE-Large-V1": EmbeddingModelConfig(
+         organization="WhereIsAI",
+         display_name="UAE-Large-v1",
+         model_size="326M",
+         embedding_dimension=1024,
+         context_window=512,
+     ),
+     "BAAI/bge-large-en-v1.5": EmbeddingModelConfig(
+         organization="BAAI",
+         display_name="BGE-Large-EN-v1.5",
+         model_size="326M",
+         embedding_dimension=1024,
+         context_window=512,
+     ),
+     "BAAI/bge-base-en-v1.5": EmbeddingModelConfig(
+         organization="BAAI",
+         display_name="BGE-Base-EN-v1.5",
+         model_size="102M",
+         embedding_dimension=768,
+         context_window=512,
+     ),
+     "sentence-transformers/msmarco-bert-base-dot-v5": EmbeddingModelConfig(
+         organization="sentence-transformers",
+         display_name="Sentence-BERT",
+         model_size="110M",
+         embedding_dimension=768,
+         context_window=512,
+     ),
+     "bert-base-uncased": EmbeddingModelConfig(
+         organization="Hugging Face",
+         display_name="BERT",
+         model_size="110M",
+         embedding_dimension=768,
+         context_window=512,
+     ),
+ }
+
+
+ def get_embedding_model_config(model_id: str) -> EmbeddingModelConfig:
+     """Get embedding model configuration by model ID"""
+     if model_id not in TOGETHER_EMBEDDING_MODELS:
+         raise ValueError(f"Model {model_id} not found in Together AI embedding models")
+     return TOGETHER_EMBEDDING_MODELS[model_id]
+
+
+ def list_embedding_models_by_organization(
+     organization: str,
+ ) -> Dict[str, EmbeddingModelConfig]:
+     """Get all embedding models for a specific organization"""
+     return {
+         model_id: config
+         for model_id, config in TOGETHER_EMBEDDING_MODELS.items()
+         if config.organization.lower() == organization.lower()
+     }
+
+
+ def get_default_embedding_model() -> str:
+     """Get the default embedding model ID"""
+     return "togethercomputer/m2-bert-80M-32k-retrieval"
airtrain-0.1.14/airtrain/integrations/together/image_models_config.py
@@ -0,0 +1,69 @@
+ from typing import Dict, NamedTuple, Optional
+
+
+ class ImageModelConfig(NamedTuple):
+     organization: str
+     display_name: str
+     default_steps: Optional[int]
+
+
+ TOGETHER_IMAGE_MODELS: Dict[str, ImageModelConfig] = {
+     "black-forest-labs/FLUX.1-schnell-Free": ImageModelConfig(
+         organization="Black Forest Labs",
+         display_name="Flux.1 [schnell] (free)",
+         default_steps=None,
+     ),
+     "black-forest-labs/FLUX.1-schnell": ImageModelConfig(
+         organization="Black Forest Labs",
+         display_name="Flux.1 [schnell] (Turbo)",
+         default_steps=4,
+     ),
+     "black-forest-labs/FLUX.1-dev": ImageModelConfig(
+         organization="Black Forest Labs", display_name="Flux.1 Dev", default_steps=28
+     ),
+     "black-forest-labs/FLUX.1-canny": ImageModelConfig(
+         organization="Black Forest Labs", display_name="Flux.1 Canny", default_steps=28
+     ),
+     "black-forest-labs/FLUX.1-depth": ImageModelConfig(
+         organization="Black Forest Labs", display_name="Flux.1 Depth", default_steps=28
+     ),
+     "black-forest-labs/FLUX.1-redux": ImageModelConfig(
+         organization="Black Forest Labs", display_name="Flux.1 Redux", default_steps=28
+     ),
+     "black-forest-labs/FLUX.1.1-pro": ImageModelConfig(
+         organization="Black Forest Labs",
+         display_name="Flux1.1 [pro]",
+         default_steps=None,
+     ),
+     "black-forest-labs/FLUX.1-pro": ImageModelConfig(
+         organization="Black Forest Labs",
+         display_name="Flux.1 [pro]",
+         default_steps=None,
+     ),
+     "stabilityai/stable-diffusion-xl-base-1.0": ImageModelConfig(
+         organization="Stability AI",
+         display_name="Stable Diffusion XL 1.0",
+         default_steps=None,
+     ),
+ }
+
+
+ def get_image_model_config(model_id: str) -> ImageModelConfig:
+     """Get image model configuration by model ID"""
+     if model_id not in TOGETHER_IMAGE_MODELS:
+         raise ValueError(f"Model {model_id} not found in Together AI image models")
+     return TOGETHER_IMAGE_MODELS[model_id]
+
+
+ def list_image_models_by_organization(organization: str) -> Dict[str, ImageModelConfig]:
+     """Get all image models for a specific organization"""
+     return {
+         model_id: config
+         for model_id, config in TOGETHER_IMAGE_MODELS.items()
+         if config.organization.lower() == organization.lower()
+     }
+
+
+ def get_default_image_model() -> str:
+     """Get the default image model ID"""
+     return "black-forest-labs/FLUX.1-schnell-Free"