airtrain 0.1.44__py3-none-any.whl → 0.1.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
airtrain/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """Airtrain - A platform for building and deploying AI agents with structured skills"""
2
2
 
3
- __version__ = "0.1.44"
3
+ __version__ = "0.1.45"
4
4
 
5
5
  # Core imports
6
6
  from .core.skills import Skill, ProcessingError
@@ -26,6 +26,13 @@ from .cerebras.skills import CerebrasChatSkill
26
26
  from .openai.models_config import OPENAI_MODELS, OpenAIModelConfig
27
27
  from .anthropic.models_config import ANTHROPIC_MODELS, AnthropicModelConfig
28
28
 
29
+ # Combined modules
30
+ from .combined.list_models_factory import (
31
+ ListModelsSkillFactory,
32
+ GenericListModelsInput,
33
+ GenericListModelsOutput
34
+ )
35
+
29
36
  __all__ = [
30
37
  # Credentials
31
38
  "OpenAICredentials",
@@ -53,4 +60,8 @@ __all__ = [
53
60
  "OpenAIModelConfig",
54
61
  "ANTHROPIC_MODELS",
55
62
  "AnthropicModelConfig",
63
+ # Combined modules
64
+ "ListModelsSkillFactory",
65
+ "GenericListModelsInput",
66
+ "GenericListModelsOutput",
56
67
  ]
@@ -0,0 +1,21 @@
1
+ """Combined integration modules for Airtrain"""
2
+
3
+ from .groq_fireworks_skills import (
4
+ GroqFireworksSkill,
5
+ GroqFireworksInput,
6
+ GroqFireworksOutput
7
+ )
8
+ from .list_models_factory import (
9
+ ListModelsSkillFactory,
10
+ GenericListModelsInput,
11
+ GenericListModelsOutput
12
+ )
13
+
14
+ __all__ = [
15
+ "GroqFireworksSkill",
16
+ "GroqFireworksInput",
17
+ "GroqFireworksOutput",
18
+ "ListModelsSkillFactory",
19
+ "GenericListModelsInput",
20
+ "GenericListModelsOutput"
21
+ ]
@@ -0,0 +1,126 @@
1
+ from typing import Optional, Dict, Any, List
2
+ from pydantic import Field
3
+ import requests
4
+ from groq import Groq
5
+
6
+ from airtrain.core.skills import Skill, ProcessingError
7
+ from airtrain.core.schemas import InputSchema, OutputSchema
8
+ from airtrain.integrations.fireworks.completion_skills import (
9
+ FireworksCompletionSkill,
10
+ FireworksCompletionInput,
11
+ )
12
+
13
+
14
class GroqFireworksInput(InputSchema):
    """Schema for combined Groq and Fireworks input"""

    # Prompt text supplied by the caller (required).
    user_input: str = Field(..., description="User's input text")

    # Model identifier used for the Groq chat completion.
    groq_model: str = Field(
        "mixtral-8x7b-32768",
        description="Groq model to use",
    )

    # Model identifier used for the Fireworks completion.
    fireworks_model: str = Field(
        "accounts/fireworks/models/deepseek-r1",
        description="Fireworks model to use",
    )

    # Sampling temperature.
    temperature: float = Field(
        0.7,
        description="Temperature for response generation",
    )

    # Upper bound on generated tokens.
    # NOTE(review): this default is larger than the 32768 in the default Groq
    # model name suggests that model supports -- confirm the provider accepts it.
    max_tokens: int = Field(
        131072,
        description="Maximum tokens in response",
    )
29
+
30
+
31
class GroqFireworksOutput(OutputSchema):
    """Schema for combined Groq and Fireworks output"""

    # Text assembled from the user input and both provider responses.
    combined_response: str

    # Raw text returned by each provider.
    groq_response: str
    fireworks_response: str

    # Provider name -> model identifier that served the request.
    used_models: Dict[str, str]

    # Provider name -> usage counters reported by that provider.
    usage: Dict[str, Dict[str, int]]
39
+
40
+
41
class GroqFireworksSkill(Skill[GroqFireworksInput, GroqFireworksOutput]):
    """Skill combining Groq and Fireworks responses.

    The user input is first sent to Groq as a single-message chat completion.
    The Groq answer is then embedded in a ``<USER>...</USER>`` /
    ``<ASSISTANT>...`` prompt that is passed to the Fireworks completion
    skill. Both answers, the models used and the reported usage are returned
    in a :class:`GroqFireworksOutput`.
    """

    input_schema = GroqFireworksInput
    output_schema = GroqFireworksOutput

    def __init__(
        self,
        groq_api_key: Optional[str] = None,
        fireworks_skill: Optional[FireworksCompletionSkill] = None,
    ):
        """Initialize the skill.

        Args:
            groq_api_key: API key handed to the ``Groq`` client. ``None`` is
                passed through unchanged (presumably the client then falls
                back to its own configuration -- verify against the groq SDK).
            fireworks_skill: Pre-built Fireworks completion skill to use.
                When ``None`` a default ``FireworksCompletionSkill()`` is
                created.
        """
        super().__init__()
        self.groq_client = Groq(api_key=groq_api_key)
        # Explicit None check so a caller-supplied skill is never replaced
        # just because it happens to evaluate as falsy.
        self.fireworks_skill = (
            fireworks_skill
            if fireworks_skill is not None
            else FireworksCompletionSkill()
        )

    @staticmethod
    def _token_counts(usage: Dict[str, Any]) -> Dict[str, int]:
        """Return only the integer-valued entries of a usage mapping."""
        return {
            key: value
            for key, value in usage.items()
            # bool is a subclass of int; it is not a counter, so exclude it.
            if isinstance(value, int) and not isinstance(value, bool)
        }

    def _get_groq_response(self, input_data: GroqFireworksInput) -> Dict[str, Any]:
        """Get response from Groq.

        Args:
            input_data: Request parameters (model, prompt, sampling settings).

        Returns:
            A dict with ``"response"`` (the first choice's message text, or
            ``""`` when the provider returned no text) and ``"usage"`` (the
            integer counters from the completion's usage object).

        Raises:
            ProcessingError: If the request or response handling fails.
        """
        try:
            completion = self.groq_client.chat.completions.create(
                model=input_data.groq_model,
                messages=[{"role": "user", "content": input_data.user_input}],
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            usage = completion.usage
            raw_usage = usage.model_dump() if usage is not None else {}

            return {
                # The message content can be None; the output schema and the
                # downstream prompt both need a string.
                "response": completion.choices[0].message.content or "",
                # The dumped usage may carry non-integer entries (timing
                # floats, unset fields dumped as None -- confirm against the
                # installed groq version). Those would not validate against
                # GroqFireworksOutput.usage (Dict[str, Dict[str, int]]), so
                # only the integer counters are kept.
                "usage": self._token_counts(raw_usage),
            }
        except Exception as e:
            raise ProcessingError(f"Groq request failed: {str(e)}") from e

    def _get_fireworks_response(
        self, groq_response: str, input_data: GroqFireworksInput
    ) -> Dict[str, Any]:
        """Get response from Fireworks.

        Args:
            groq_response: Text previously produced by Groq; it is placed
                after the ``<ASSISTANT>`` tag of the Fireworks prompt.
            input_data: Request parameters (model, prompt, sampling settings).

        Returns:
            A dict with ``"response"`` and ``"usage"`` taken from the
            Fireworks skill result.

        Raises:
            ProcessingError: If building the input or the request fails.
        """
        try:
            formatted_prompt = (
                f"<USER>{input_data.user_input}</USER>\n<ASSISTANT>{groq_response}"
            )

            fireworks_input = FireworksCompletionInput(
                prompt=formatted_prompt,
                model=input_data.fireworks_model,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            result = self.fireworks_skill.process(fireworks_input)
            return {"response": result.response, "usage": result.usage}
        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e

    def process(self, input_data: GroqFireworksInput) -> GroqFireworksOutput:
        """Process the input using both Groq and Fireworks.

        Args:
            input_data: Request parameters for both providers.

        Returns:
            The combined output: both individual responses, the combined
            text, the models used and per-provider usage.

        Raises:
            ProcessingError: If either provider call or building the output
                fails. The original exception is attached as ``__cause__``.
        """
        try:
            # Groq answers first; Fireworks then continues from that answer.
            groq_result = self._get_groq_response(input_data)
            fireworks_result = self._get_fireworks_response(
                groq_result["response"], input_data
            )

            # Combine responses in the required format
            combined_response = (
                f"<USER>{input_data.user_input}</USER>\n"
                f"<ASSISTANT>{groq_result['response']} {fireworks_result['response']}"
            )

            return GroqFireworksOutput(
                combined_response=combined_response,
                groq_response=groq_result["response"],
                fireworks_response=fireworks_result["response"],
                used_models={
                    "groq": input_data.groq_model,
                    "fireworks": input_data.fireworks_model,
                },
                usage={
                    "groq": groq_result["usage"],
                    "fireworks": fireworks_result["usage"],
                },
            )

        except Exception as e:
            raise ProcessingError(f"Combined processing failed: {str(e)}") from e
@@ -0,0 +1,172 @@
1
+ from typing import Optional, Dict, Any, List
2
+ from pydantic import Field
3
+
4
+ from airtrain.core.skills import Skill, ProcessingError
5
+ from airtrain.core.schemas import InputSchema, OutputSchema
6
+ from airtrain.core.credentials import BaseCredentials
7
+
8
+ # Import existing list models skills
9
+ from airtrain.integrations.openai.list_models import OpenAIListModelsSkill
10
+ from airtrain.integrations.anthropic.list_models import AnthropicListModelsSkill
11
+ from airtrain.integrations.together.list_models import TogetherListModelsSkill
12
+ from airtrain.integrations.fireworks.list_models import FireworksListModelsSkill
13
+
14
+ # Import credentials
15
+ from airtrain.integrations.groq.credentials import GroqCredentials
16
+ from airtrain.integrations.cerebras.credentials import CerebrasCredentials
17
+ from airtrain.integrations.sambanova.credentials import SambanovaCredentials
18
+
19
+
20
+ # Generic list models input schema
21
class GenericListModelsInput(InputSchema):
    """Generic schema for listing models from any provider"""

    # Selects between API-fetched models and the local configuration, as
    # described in the field description below.
    api_models_only: bool = Field(
        False,
        description="If True, fetch models from the API only. If False, use local config.",
    )
30
+
31
+
32
+ # Generic list models output schema
33
class GenericListModelsOutput(OutputSchema):
    """Generic schema for list models output from any provider"""

    # One dict per model; defaults to a fresh empty list per instance.
    models: List[Dict[str, Any]] = Field(default_factory=list, description="List of models")

    # Name of the provider the models belong to (required).
    provider: str = Field(..., description="Provider name")
44
+
45
+
46
+ # Base class for stub implementations
47
class BaseListModelsSkill(Skill[GenericListModelsInput, GenericListModelsOutput]):
    """Base skill for listing models.

    Subclasses supply the model list by overriding :meth:`get_models`;
    :meth:`process` wraps that list in a :class:`GenericListModelsOutput`
    tagged with the provider name.
    """

    input_schema = GenericListModelsInput
    output_schema = GenericListModelsOutput

    def __init__(self, provider: str, credentials: Optional[BaseCredentials] = None):
        """Initialize the skill with provider name and optional credentials.

        Args:
            provider: Provider name reported in the output and error messages.
            credentials: Optional credentials; stored on the instance but not
                used by this base class.
        """
        super().__init__()
        self.provider = provider
        self.credentials = credentials

    def get_models(self) -> List[Dict[str, Any]]:
        """Return list of models. To be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement get_models()")

    def process(self, input_data: GenericListModelsInput) -> GenericListModelsOutput:
        """Process the input and return a list of models.

        Note: ``input_data`` is accepted for interface compatibility but is
        not consulted here; the result always comes from :meth:`get_models`.

        Raises:
            ProcessingError: If ``get_models`` or building the output fails.
                The original exception is attached as ``__cause__``.
        """
        try:
            models = self.get_models()
            return GenericListModelsOutput(models=models, provider=self.provider)
        except Exception as e:
            raise ProcessingError(
                f"Failed to list {self.provider} models: {str(e)}"
            ) from e
70
+
71
+
72
+ # Groq implementation
73
class GroqListModelsSkill(BaseListModelsSkill):
    """List-models skill that reports a fixed set of Groq models."""

    def __init__(self, credentials: Optional[GroqCredentials] = None):
        """Create the skill for provider ``"groq"`` with optional credentials."""
        super().__init__(provider="groq", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the hard-coded Groq model entries (``id`` / ``display_name``)."""
        # Defaults taken from the trmx_agent config; no API call is made.
        # NOTE(review): confirm these ids match what the Groq API expects.
        return [
            {"id": "llama-3-70b-8192", "display_name": "Llama 3 70B (8K)"},
            {"id": "mixtral-8x7b-32768", "display_name": "Mixtral 8x7B (32K)"},
            {"id": "gemma-7b-it", "display_name": "Gemma 7B Instruct"},
        ]
89
+
90
+
91
+ # Cerebras implementation
92
class CerebrasListModelsSkill(BaseListModelsSkill):
    """List-models skill that reports a fixed set of Cerebras models."""

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        """Create the skill for provider ``"cerebras"`` with optional credentials."""
        super().__init__(provider="cerebras", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the hard-coded Cerebras model entries (``id`` / ``display_name``)."""
        # Defaults taken from the trmx_agent config; no API call is made.
        return [
            {
                "id": "cerebras/Cerebras-GPT-13B-v0.1",
                "display_name": "Cerebras GPT 13B v0.1",
            },
            {
                "id": "cerebras/Cerebras-GPT-111M-v0.9",
                "display_name": "Cerebras GPT 111M v0.9",
            },
            {
                "id": "cerebras/Cerebras-GPT-590M-v0.7",
                "display_name": "Cerebras GPT 590M v0.7",
            },
        ]
108
+
109
+
110
+ # Sambanova implementation
111
class SambanovaListModelsSkill(BaseListModelsSkill):
    """List-models skill that reports a fixed set of Sambanova models."""

    def __init__(self, credentials: Optional[SambanovaCredentials] = None):
        """Create the skill for provider ``"sambanova"`` with optional credentials."""
        super().__init__(provider="sambanova", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the hard-coded Sambanova model entries (``id`` / ``display_name``)."""
        # Only limited Sambanova model information is available here.
        return [
            {"id": "sambanova/samba-1", "display_name": "Samba-1"},
            {"id": "sambanova/samba-2", "display_name": "Samba-2"},
        ]
126
+
127
+
128
+ # Factory class
129
class ListModelsSkillFactory:
    """Factory for creating list models skills for different providers"""

    # Map provider names (lower-case) to their corresponding list models skills
    _PROVIDER_MAP = {
        "openai": OpenAIListModelsSkill,
        "anthropic": AnthropicListModelsSkill,
        "together": TogetherListModelsSkill,
        "fireworks": FireworksListModelsSkill,
        "groq": GroqListModelsSkill,
        "cerebras": CerebrasListModelsSkill,
        "sambanova": SambanovaListModelsSkill,
    }

    @classmethod
    def get_skill(cls, provider: str, credentials=None):
        """Return a list models skill for the specified provider

        Args:
            provider (str): The provider name (case-insensitive; surrounding
                whitespace is ignored)
            credentials: Optional credentials for the provider; forwarded to
                the skill constructor as ``credentials=``

        Returns:
            A ListModelsSkill instance for the specified provider

        Raises:
            ValueError: If the provider is not supported
        """
        # Normalize so that " OpenAI " and "openai" select the same skill.
        provider = provider.strip().lower()

        skill_class = cls._PROVIDER_MAP.get(provider)
        if skill_class is None:
            supported = ", ".join(cls.get_supported_providers())
            raise ValueError(
                f"Unsupported provider: {provider}. "
                f"Supported providers are: {supported}"
            )

        return skill_class(credentials=credentials)

    @classmethod
    def get_supported_providers(cls):
        """Return a list of supported provider names"""
        return list(cls._PROVIDER_MAP)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: airtrain
3
- Version: 0.1.44
3
+ Version: 0.1.45
4
4
  Summary: A platform for building and deploying AI agents with structured skills
5
5
  Home-page: https://github.com/rosaboyle/airtrain.dev
6
6
  Author: Dheeraj Pai
@@ -1,4 +1,4 @@
1
- airtrain/__init__.py,sha256=_i8k5dricoc_30LRe5yFCrObuhrmuImelQnbT_Kkmmw,2099
1
+ airtrain/__init__.py,sha256=r8knBRelQBR1dv6YknyXNm00o3Eo7UmfPDuHYDxLtL4,2099
2
2
  airtrain/__main__.py,sha256=EU8ffFmCdC1G-UcHHt0Oo3lB1PGqfC6kwzH39CnYSwU,72
3
3
  airtrain/builder/__init__.py,sha256=D33sr0k_WAe6FAJkk8rUaivEzFaeVqLXkQgyFWEhfPU,110
4
4
  airtrain/builder/agent_builder.py,sha256=3XnGUAcK_6lWoUDtL0TanliQZuh7u0unhNbnrz1z2-I,5018
@@ -13,7 +13,7 @@ airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
13
13
  airtrain/core/credentials.py,sha256=PgQotrQc46J5djidKnkK1znUv3fyNkUFDO-m2Kn_Gzo,4006
14
14
  airtrain/core/schemas.py,sha256=MMXrDviC4gRea_QaPpbjgO--B_UKxnD7YrxqZOLJZZU,7003
15
15
  airtrain/core/skills.py,sha256=LljalzeSHK5eQPTAOEAYc5D8Qn1kVSfiz9WgziTD5UM,4688
16
- airtrain/integrations/__init__.py,sha256=rk9QFl0Dd7Qp4rULhi_u4smwsJwk69Kg_-fv0GQ43iw,1782
16
+ airtrain/integrations/__init__.py,sha256=Y0yJt6c6uKOkMXdGPELSc6NYiLrEYr_38acViNPXWNk,2046
17
17
  airtrain/integrations/anthropic/__init__.py,sha256=K741w3v7fWsCknTo38ARqDL0D3HPlwDIvDuuBao9Tto,800
18
18
  airtrain/integrations/anthropic/credentials.py,sha256=hlTSw9HX66kYNaeQUtn0JjdZQBMNkzzFOJOoLOOzvcY,1246
19
19
  airtrain/integrations/anthropic/list_models.py,sha256=o7FABp0Cq3gs76zOF-CM9ohmYslWT6vK9qtQabV9XzI,3973
@@ -25,6 +25,9 @@ airtrain/integrations/aws/skills.py,sha256=2l16Y5zYeNd9trrPca6Rbhvl6a-GJBuCQMu7R
25
25
  airtrain/integrations/cerebras/__init__.py,sha256=zAD-qV38OzHhMCz1z-NvjjqcYEhURbm8RWTOKHNqbew,174
26
26
  airtrain/integrations/cerebras/credentials.py,sha256=KDEH4r8FGT68L9p34MLZWK65wq_a703pqIF3ODaSbts,694
27
27
  airtrain/integrations/cerebras/skills.py,sha256=BJEb_7TglCYAukD3kcx37R8ibnJWdxVrBrwf3ZTYP-4,4924
28
+ airtrain/integrations/combined/__init__.py,sha256=EL_uZs8436abeZA6NbM-gLdur7kJr4YfGGOYybKOhjc,467
29
+ airtrain/integrations/combined/groq_fireworks_skills.py,sha256=Kz8UDU4-Rl71znz3ml9qVMRz669iWx4sUl37iafW0NU,4612
30
+ airtrain/integrations/combined/list_models_factory.py,sha256=tDhfu7-MdAgPXCDkASNezSu4gfNjIC7Zf4XLw_Rx0sk,6500
28
31
  airtrain/integrations/fireworks/__init__.py,sha256=GstUg0rYC-7Pg0DVbDXwL5eO1hp3WCSfroWazbGpfi0,545
29
32
  airtrain/integrations/fireworks/completion_skills.py,sha256=zxx7aNlum9scQMri5Ek0qN8VfAomhyUp3u8JJo_AFWM,5615
30
33
  airtrain/integrations/fireworks/conversation_manager.py,sha256=ifscKHYKWM_NDElin-oTzpRhyoh6pzBnklmMuH5geOY,3706
@@ -68,8 +71,8 @@ airtrain/integrations/together/rerank_skill.py,sha256=gjH24hLWCweWKPyyfKZMG3K_g9
68
71
  airtrain/integrations/together/schemas.py,sha256=pBMrbX67oxPCr-sg4K8_Xqu1DWbaC4uLCloVSascROg,1210
69
72
  airtrain/integrations/together/skills.py,sha256=8DwkexMJu1Gm6QmNDfNasYStQ31QsXBbFP99zR-YCf0,7598
70
73
  airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1FtynIa2ZVcbl37kltfoHnK8zxs,1544
71
- airtrain-0.1.44.dist-info/METADATA,sha256=EV6dwjtrrrN_YvL7KpCP9FNGpjZYNGWVwXN7tzP-czw,5375
72
- airtrain-0.1.44.dist-info/WHEEL,sha256=tTnHoFhvKQHCh4jz3yCn0WPTYIy7wXx3CJtJ7SJGV7c,91
73
- airtrain-0.1.44.dist-info/entry_points.txt,sha256=rrJ36IUsyq6n1dSfTWXqVAgpQLPRWDfCqwd6_3B-G0U,52
74
- airtrain-0.1.44.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
75
- airtrain-0.1.44.dist-info/RECORD,,
74
+ airtrain-0.1.45.dist-info/METADATA,sha256=PXbwhd_qiTdj56DAAG1iwZ9CBgThiTRaaf4xw6A_G7o,5375
75
+ airtrain-0.1.45.dist-info/WHEEL,sha256=tTnHoFhvKQHCh4jz3yCn0WPTYIy7wXx3CJtJ7SJGV7c,91
76
+ airtrain-0.1.45.dist-info/entry_points.txt,sha256=rrJ36IUsyq6n1dSfTWXqVAgpQLPRWDfCqwd6_3B-G0U,52
77
+ airtrain-0.1.45.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
78
+ airtrain-0.1.45.dist-info/RECORD,,