airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108):
  1. airtrain/__init__.py +146 -6
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  19. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  21. airtrain/core/credentials.py +62 -44
  22. airtrain/core/skills.py +102 -0
  23. airtrain/integrations/__init__.py +74 -0
  24. airtrain/integrations/anthropic/__init__.py +33 -0
  25. airtrain/integrations/anthropic/credentials.py +32 -0
  26. airtrain/integrations/anthropic/list_models.py +110 -0
  27. airtrain/integrations/anthropic/models_config.py +100 -0
  28. airtrain/integrations/anthropic/skills.py +155 -0
  29. airtrain/integrations/aws/__init__.py +6 -0
  30. airtrain/integrations/aws/credentials.py +36 -0
  31. airtrain/integrations/aws/skills.py +98 -0
  32. airtrain/integrations/cerebras/__init__.py +6 -0
  33. airtrain/integrations/cerebras/credentials.py +19 -0
  34. airtrain/integrations/cerebras/skills.py +127 -0
  35. airtrain/integrations/combined/__init__.py +21 -0
  36. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  37. airtrain/integrations/combined/list_models_factory.py +210 -0
  38. airtrain/integrations/fireworks/__init__.py +21 -0
  39. airtrain/integrations/fireworks/completion_skills.py +147 -0
  40. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  41. airtrain/integrations/fireworks/credentials.py +26 -0
  42. airtrain/integrations/fireworks/list_models.py +128 -0
  43. airtrain/integrations/fireworks/models.py +139 -0
  44. airtrain/integrations/fireworks/requests_skills.py +207 -0
  45. airtrain/integrations/fireworks/skills.py +181 -0
  46. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  47. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  48. airtrain/integrations/fireworks/structured_skills.py +102 -0
  49. airtrain/integrations/google/__init__.py +7 -0
  50. airtrain/integrations/google/credentials.py +58 -0
  51. airtrain/integrations/google/skills.py +122 -0
  52. airtrain/integrations/groq/__init__.py +23 -0
  53. airtrain/integrations/groq/credentials.py +24 -0
  54. airtrain/integrations/groq/models_config.py +162 -0
  55. airtrain/integrations/groq/skills.py +201 -0
  56. airtrain/integrations/ollama/__init__.py +6 -0
  57. airtrain/integrations/ollama/credentials.py +26 -0
  58. airtrain/integrations/ollama/skills.py +41 -0
  59. airtrain/integrations/openai/__init__.py +37 -0
  60. airtrain/integrations/openai/chinese_assistant.py +42 -0
  61. airtrain/integrations/openai/credentials.py +39 -0
  62. airtrain/integrations/openai/list_models.py +112 -0
  63. airtrain/integrations/openai/models_config.py +224 -0
  64. airtrain/integrations/openai/skills.py +342 -0
  65. airtrain/integrations/perplexity/__init__.py +49 -0
  66. airtrain/integrations/perplexity/credentials.py +43 -0
  67. airtrain/integrations/perplexity/list_models.py +112 -0
  68. airtrain/integrations/perplexity/models_config.py +128 -0
  69. airtrain/integrations/perplexity/skills.py +279 -0
  70. airtrain/integrations/sambanova/__init__.py +6 -0
  71. airtrain/integrations/sambanova/credentials.py +20 -0
  72. airtrain/integrations/sambanova/skills.py +129 -0
  73. airtrain/integrations/search/__init__.py +21 -0
  74. airtrain/integrations/search/exa/__init__.py +23 -0
  75. airtrain/integrations/search/exa/credentials.py +30 -0
  76. airtrain/integrations/search/exa/schemas.py +114 -0
  77. airtrain/integrations/search/exa/skills.py +115 -0
  78. airtrain/integrations/together/__init__.py +33 -0
  79. airtrain/integrations/together/audio_models_config.py +34 -0
  80. airtrain/integrations/together/credentials.py +22 -0
  81. airtrain/integrations/together/embedding_models_config.py +92 -0
  82. airtrain/integrations/together/image_models_config.py +69 -0
  83. airtrain/integrations/together/image_skill.py +143 -0
  84. airtrain/integrations/together/list_models.py +76 -0
  85. airtrain/integrations/together/models.py +95 -0
  86. airtrain/integrations/together/models_config.py +399 -0
  87. airtrain/integrations/together/rerank_models_config.py +43 -0
  88. airtrain/integrations/together/rerank_skill.py +49 -0
  89. airtrain/integrations/together/schemas.py +33 -0
  90. airtrain/integrations/together/skills.py +305 -0
  91. airtrain/integrations/together/vision_models_config.py +49 -0
  92. airtrain/telemetry/__init__.py +38 -0
  93. airtrain/telemetry/service.py +167 -0
  94. airtrain/telemetry/views.py +237 -0
  95. airtrain/tools/__init__.py +45 -0
  96. airtrain/tools/command.py +398 -0
  97. airtrain/tools/filesystem.py +166 -0
  98. airtrain/tools/network.py +111 -0
  99. airtrain/tools/registry.py +320 -0
  100. airtrain/tools/search.py +450 -0
  101. airtrain/tools/testing.py +135 -0
  102. airtrain-0.1.4.dist-info/METADATA +222 -0
  103. airtrain-0.1.4.dist-info/RECORD +108 -0
  104. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  105. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  106. airtrain-0.1.3.dist-info/METADATA +0 -106
  107. airtrain-0.1.3.dist-info/RECORD +0 -9
  108. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,98 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import Field
3
+ import boto3
4
+ from pathlib import Path
5
+ from loguru import logger
6
+
7
+ from airtrain.core.skills import Skill, ProcessingError
8
+ from airtrain.core.schemas import InputSchema, OutputSchema
9
+ from .credentials import AWSCredentials
10
+
11
+
12
class AWSBedrockInput(InputSchema):
    """Schema for AWS Bedrock chat input.

    Describes a single-turn chat request: a system prompt, one user message,
    the Bedrock model ID and sampling parameters.
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    model: str = Field(
        default="anthropic.claude-3-sonnet-20240229-v1:0",
        description="AWS Bedrock model to use",
    )
    # Claude 3 models on Bedrock accept at most 4096 output tokens; the former
    # default of 131072 made every request relying on the default be rejected.
    max_tokens: int = Field(default=4096, description="Maximum tokens in response")
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    # NOTE: AWSBedrockSkill in this module builds a text-only message and does
    # not currently send these images to the model.
    images: Optional[List[Path]] = Field(
        default=None,
        description="Optional list of image paths to include in the message",
    )
32
+
33
+
34
class AWSBedrockOutput(OutputSchema):
    """Result of an AWS Bedrock chat call: reply text, model ID and usage."""

    response: str = Field(..., description="Model's response text")
    used_model: str = Field(..., description="Model used for generation")
    # Token counts reported by the API; empty when nothing was reported.
    usage: Dict[str, Any] = Field(
        description="Usage statistics from the API", default_factory=dict
    )
42
+
43
+
44
class AWSBedrockSkill(Skill[AWSBedrockInput, AWSBedrockOutput]):
    """Skill for interacting with AWS Bedrock models.

    Only model IDs containing ``"anthropic"`` are supported; they are called
    through the Anthropic Messages API format on the ``bedrock-runtime``
    service. Any other model raises :class:`ProcessingError`.
    """

    input_schema = AWSBedrockInput
    output_schema = AWSBedrockOutput

    def __init__(self, credentials: Optional[AWSCredentials] = None):
        """Initialize the skill and its ``bedrock-runtime`` boto3 client.

        Args:
            credentials: AWS credentials to use. When omitted they are loaded
                with ``AWSCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or AWSCredentials.from_env()
        self.client = boto3.client(
            "bedrock-runtime",
            aws_access_key_id=self.credentials.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=self.credentials.aws_secret_access_key.get_secret_value(),
            region_name=self.credentials.aws_region,
        )

    def process(self, input_data: AWSBedrockInput) -> AWSBedrockOutput:
        """Send a single-turn chat request to AWS Bedrock.

        Args:
            input_data: Prompt, model ID and sampling parameters.

        Returns:
            The concatenated text of the reply, the model used, and the
            ``input_tokens``/``output_tokens`` counts reported by the API.

        Raises:
            ProcessingError: If the model is unsupported or the call or the
                response parsing fails.
        """
        import json

        try:
            logger.info(f"Processing request with model {input_data.model}")

            if "anthropic" not in input_data.model:
                raise ProcessingError(f"Unsupported model: {input_data.model}")

            request_body = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": input_data.max_tokens,
                "temperature": input_data.temperature,
                "system": input_data.system_prompt,
                "messages": [{"role": "user", "content": input_data.user_input}],
            }

            # invoke_model expects the body as serialized JSON, not a dict.
            response = self.client.invoke_model(
                modelId=input_data.model,
                body=json.dumps(request_body),
                contentType="application/json",
                accept="application/json",
            )

            # The returned body is a stream that must be read and decoded.
            payload = json.loads(response["body"].read())

            # The Messages API returns a list of content blocks; join the
            # text blocks to form the reply.
            response_text = "".join(
                block.get("text", "")
                for block in payload.get("content", [])
                if block.get("type") == "text"
            )
            raw_usage = payload.get("usage") or {}
            usage = {
                "input_tokens": raw_usage.get("input_tokens"),
                "output_tokens": raw_usage.get("output_tokens"),
            }

            return AWSBedrockOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except ProcessingError:
            raise
        except Exception as e:
            logger.exception(f"AWS Bedrock processing failed: {str(e)}")
            raise ProcessingError(f"AWS Bedrock processing failed: {str(e)}") from e
@@ -0,0 +1,6 @@
1
+ """Cerebras integration module"""
2
+
3
+ from .credentials import CerebrasCredentials
4
+ from .skills import CerebrasChatSkill
5
+
6
# Public API of the Cerebras integration package.
__all__ = [
    "CerebrasCredentials",
    "CerebrasChatSkill",
]
@@ -0,0 +1,19 @@
1
+ from pydantic import Field, SecretStr
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+
4
+
5
class CerebrasCredentials(BaseCredentials):
    """Credentials for the Cerebras API: a single secret API key."""

    cerebras_api_key: SecretStr = Field(description="Cerebras API key", default=...)

    _required_credentials = {"cerebras_api_key"}

    async def validate_credentials(self) -> bool:
        """Report whether the stored Cerebras credentials are valid.

        NOTE(review): no Cerebras-specific check is implemented yet; this
        always returns True without contacting the API.
        """
        return True
@@ -0,0 +1,127 @@
1
+ from typing import List, Optional, Dict, Any, Generator
2
+ from pydantic import Field
3
+ from cerebras.cloud.sdk import Cerebras
4
+ from loguru import logger
5
+
6
+ from airtrain.core.skills import Skill, ProcessingError
7
+ from airtrain.core.schemas import InputSchema, OutputSchema
8
+ from .credentials import CerebrasCredentials
9
+
10
+
11
class CerebrasInput(InputSchema):
    """Input for a Cerebras chat request.

    Contains the new user message, a system prompt, optional prior turns,
    the model name and generation settings.
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    # Prior turns, inserted between the system prompt and the new message.
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
        default_factory=list,
    )
    model: str = Field(description="Cerebras model to use", default="llama3.1-8b")
    max_tokens: Optional[int] = Field(
        description="Maximum tokens in response", default=131072
    )
    temperature: float = Field(
        description="Temperature for response generation", default=0.7, ge=0, le=1
    )
    stream: bool = Field(
        description="Whether to stream the response progressively", default=False
    )
33
+
34
+
35
class CerebrasOutput(OutputSchema):
    """Result of a Cerebras chat call: reply text, model name and usage."""

    response: str = Field(..., description="Model's response text")
    used_model: str = Field(..., description="Model used for generation")
    # Empty when the call did not report usage (e.g. streamed replies).
    usage: Dict[str, Any] = Field(description="Usage statistics", default_factory=dict)
41
+
42
+
43
class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
    """Skill for Cerebras chat.

    Supports both one-shot completions (:meth:`process`) and token streaming
    (:meth:`process_stream`).
    """

    input_schema = CerebrasInput
    output_schema = CerebrasOutput

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        """Create the Cerebras client.

        Args:
            credentials: Cerebras credentials; loaded with
                ``CerebrasCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or CerebrasCredentials.from_env()
        self.client = Cerebras(
            api_key=self.credentials.cerebras_api_key.get_secret_value()
        )

    def _build_messages(self, input_data: CerebrasInput) -> List[Dict[str, str]]:
        """
        Build messages list from input data including conversation history.

        Args:
            input_data: The input data containing system prompt, conversation history, and user input

        Returns:
            List[Dict[str, str]]: System prompt, then any history, then the new
            user message.
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]
        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)
        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def process_stream(self, input_data: CerebrasInput) -> Generator[str, None, None]:
        """Process the input and stream the response token by token.

        Yields:
            Non-empty text deltas in the order they arrive.

        Raises:
            ProcessingError: If the request or iteration over the stream fails.
        """
        try:
            messages = self._build_messages(input_data)

            stream = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )

            for chunk in stream:
                # Guard against chunks whose `choices` list is empty (possible
                # in OpenAI-style streams, e.g. usage-only chunks).
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content

        except Exception as e:
            logger.exception(f"Cerebras streaming failed: {str(e)}")
            raise ProcessingError(f"Cerebras streaming failed: {str(e)}") from e

    def process(self, input_data: CerebrasInput) -> CerebrasOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is True the streamed deltas are joined into
        the final text and ``usage`` is empty (not reported while streaming).

        Raises:
            ProcessingError: If the request or response handling fails.
        """
        try:
            if input_data.stream:
                # The streamed deltas already are the response text.
                response_text = "".join(self.process_stream(input_data))
                usage: Dict[str, Any] = {}
            else:
                messages = self._build_messages(input_data)
                completion = self.client.chat.completions.create(
                    model=input_data.model,
                    messages=messages,
                    temperature=input_data.temperature,
                    max_tokens=input_data.max_tokens,
                )
                # `content` may be None; the output schema requires a str.
                response_text = completion.choices[0].message.content or ""
                raw_usage = getattr(completion, "usage", None)
                usage = raw_usage.model_dump() if raw_usage is not None else {}

            return CerebrasOutput(
                response=response_text,
                used_model=input_data.model,
                usage=usage,
            )

        except ProcessingError:
            # Already logged and wrapped by process_stream.
            raise
        except Exception as e:
            logger.exception(f"Cerebras processing failed: {str(e)}")
            raise ProcessingError(f"Cerebras processing failed: {str(e)}") from e
@@ -0,0 +1,21 @@
1
+ """Combined integration modules for Airtrain"""
2
+
3
+ from .groq_fireworks_skills import (
4
+ GroqFireworksSkill,
5
+ GroqFireworksInput,
6
+ GroqFireworksOutput
7
+ )
8
+ from .list_models_factory import (
9
+ ListModelsSkillFactory,
10
+ GenericListModelsInput,
11
+ GenericListModelsOutput
12
+ )
13
+
14
# Public API of the combined-integrations package.
__all__ = [
    # Groq + Fireworks two-stage skill
    "GroqFireworksSkill",
    "GroqFireworksInput",
    "GroqFireworksOutput",
    # Provider-agnostic model listing
    "ListModelsSkillFactory",
    "GenericListModelsInput",
    "GenericListModelsOutput",
]
@@ -0,0 +1,126 @@
1
+ from typing import Optional, Dict, Any, List
2
+ from pydantic import Field
3
+ import requests
4
+ from groq import Groq
5
+
6
+ from airtrain.core.skills import Skill, ProcessingError
7
+ from airtrain.core.schemas import InputSchema, OutputSchema
8
+ from airtrain.integrations.fireworks.completion_skills import (
9
+ FireworksCompletionSkill,
10
+ FireworksCompletionInput,
11
+ )
12
+
13
+
14
class GroqFireworksInput(InputSchema):
    """Input for the two-stage Groq -> Fireworks skill.

    The same temperature and token limit are applied to both providers.
    """

    user_input: str = Field(..., description="User's input text")
    groq_model: str = Field(
        description="Groq model to use", default="mixtral-8x7b-32768"
    )
    fireworks_model: str = Field(
        description="Fireworks model to use",
        default="accounts/fireworks/models/deepseek-r1",
    )
    temperature: float = Field(
        description="Temperature for response generation", default=0.7
    )
    max_tokens: int = Field(description="Maximum tokens in response", default=131072)
+ )
28
+ max_tokens: int = Field(default=131072, description="Maximum tokens in response")
29
+
30
+
31
class GroqFireworksOutput(OutputSchema):
    """Schema for combined Groq and Fireworks output.

    ``used_models`` and ``usage`` are keyed by provider name
    (``"groq"`` / ``"fireworks"``).
    """

    combined_response: str
    groq_response: str
    fireworks_response: str
    used_models: Dict[str, str]
    # Values are Dict[str, Any] rather than Dict[str, int]: Groq's usage payload
    # includes float (and possibly None) timing fields such as `queue_time`
    # and `total_time`, which failed validation against `int`.
    usage: Dict[str, Dict[str, Any]]
39
+
40
+
41
class GroqFireworksSkill(Skill[GroqFireworksInput, GroqFireworksOutput]):
    """Skill combining Groq and Fireworks responses.

    Groq first answers the user. The Groq answer is then placed in a
    ``<USER>...</USER>\\n<ASSISTANT>...`` prompt that a Fireworks completion
    model continues. ``combined_response`` is the Groq answer followed by the
    Fireworks continuation.
    """

    input_schema = GroqFireworksInput
    output_schema = GroqFireworksOutput

    def __init__(
        self,
        groq_api_key: Optional[str] = None,
        fireworks_skill: Optional[FireworksCompletionSkill] = None,
    ):
        """Initialize the skill with optional API keys.

        Args:
            groq_api_key: Passed to ``Groq(api_key=...)``; when None the Groq
                client applies its own default (presumably the
                ``GROQ_API_KEY`` environment variable).
            fireworks_skill: Completion skill to reuse; a new
                ``FireworksCompletionSkill()`` is created when omitted.
        """
        super().__init__()
        self.groq_client = Groq(api_key=groq_api_key)
        self.fireworks_skill = fireworks_skill or FireworksCompletionSkill()

    def _get_groq_response(self, input_data: GroqFireworksInput) -> Dict[str, Any]:
        """Get the Groq answer and its usage dict.

        Raises:
            ProcessingError: If the Groq request fails.
        """
        try:
            completion = self.groq_client.chat.completions.create(
                model=input_data.groq_model,
                messages=[{"role": "user", "content": input_data.user_input}],
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )
            usage = completion.usage
            return {
                # `content` may be None; downstream fields require a str.
                "response": completion.choices[0].message.content or "",
                "usage": usage.model_dump() if usage is not None else {},
            }
        except Exception as e:
            raise ProcessingError(f"Groq request failed: {str(e)}") from e

    def _get_fireworks_response(
        self, groq_response: str, input_data: GroqFireworksInput
    ) -> Dict[str, Any]:
        """Have Fireworks continue the Groq answer.

        Raises:
            ProcessingError: If the Fireworks request fails.
        """
        try:
            # The <ASSISTANT> tag is intentionally left open so the completion
            # model continues the Groq answer.
            formatted_prompt = (
                f"<USER>{input_data.user_input}</USER>\n<ASSISTANT>{groq_response}"
            )

            fireworks_input = FireworksCompletionInput(
                prompt=formatted_prompt,
                model=input_data.fireworks_model,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            result = self.fireworks_skill.process(fireworks_input)
            return {"response": result.response, "usage": result.usage}
        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e

    def process(self, input_data: GroqFireworksInput) -> GroqFireworksOutput:
        """Process the input using both Groq and Fireworks.

        Raises:
            ProcessingError: If either provider call or output assembly fails.
        """
        try:
            groq_result = self._get_groq_response(input_data)
            fireworks_result = self._get_fireworks_response(
                groq_result["response"], input_data
            )

            combined_response = (
                f"<USER>{input_data.user_input}</USER>\n"
                f"<ASSISTANT>{groq_result['response']} {fireworks_result['response']}"
            )

            return GroqFireworksOutput(
                combined_response=combined_response,
                groq_response=groq_result["response"],
                fireworks_response=fireworks_result["response"],
                used_models={
                    "groq": input_data.groq_model,
                    "fireworks": input_data.fireworks_model,
                },
                usage={
                    "groq": groq_result["usage"],
                    "fireworks": fireworks_result["usage"],
                },
            )

        except ProcessingError:
            # Provider helpers already produce a descriptive ProcessingError.
            raise
        except Exception as e:
            raise ProcessingError(f"Combined processing failed: {str(e)}") from e
@@ -0,0 +1,210 @@
1
+ from typing import Optional, Dict, Any, List
2
+ from pydantic import Field
3
+
4
+ from airtrain.core.skills import Skill, ProcessingError
5
+ from airtrain.core.schemas import InputSchema, OutputSchema
6
+ from airtrain.core.credentials import BaseCredentials
7
+
8
+ # Import existing list models skills
9
+ from airtrain.integrations.openai.list_models import OpenAIListModelsSkill
10
+ from airtrain.integrations.anthropic.list_models import AnthropicListModelsSkill
11
+ from airtrain.integrations.together.list_models import TogetherListModelsSkill
12
+ from airtrain.integrations.fireworks.list_models import FireworksListModelsSkill
13
+
14
+ # Import credentials
15
+ from airtrain.integrations.groq.credentials import GroqCredentials
16
+ from airtrain.integrations.cerebras.credentials import CerebrasCredentials
17
+ from airtrain.integrations.sambanova.credentials import SambanovaCredentials
18
+ from airtrain.integrations.perplexity.credentials import PerplexityCredentials
19
+
20
+ # Import Perplexity list models
21
+ from airtrain.integrations.perplexity.list_models import PerplexityListModelsSkill
22
+
23
+
24
+ # Generic list models input schema
25
class GenericListModelsInput(InputSchema):
    """Provider-agnostic input for the list-models skills.

    Extra fields are accepted so that provider-specific options can be passed
    through unchanged.
    """

    api_models_only: bool = Field(
        description="If True, fetch models from the API only. If False, use local config.",
        default=False,
    )

    class Config:
        extra = "allow"
        arbitrary_types_allowed = True
38
+
39
+
40
+ # Generic list models output schema
41
class GenericListModelsOutput(OutputSchema):
    """Provider-agnostic list-models result: the provider name and its models."""

    # Each entry is a plain dict; the list-model skills in this module use
    # "id" and "display_name" keys.
    models: List[Dict[str, Any]] = Field(
        description="List of models", default_factory=list
    )
    provider: str = Field(description="Provider name", default=...)
48
+
49
+
50
+ # Base class for stub implementations
51
class BaseListModelsSkill(Skill[GenericListModelsInput, GenericListModelsOutput]):
    """Base skill for providers whose model list is returned by ``get_models``.

    Subclasses pass their provider name and override :meth:`get_models`.
    """

    input_schema = GenericListModelsInput
    output_schema = GenericListModelsOutput

    def __init__(self, provider: str, credentials: Optional[BaseCredentials] = None):
        """Initialize the skill with provider name and optional credentials.

        Args:
            provider: Provider name reported in the output.
            credentials: Stored on the instance for subclasses that need them.
        """
        super().__init__()
        self.provider = provider
        self.credentials = credentials

    def get_models(self) -> List[Dict[str, Any]]:
        """Return list of models. To be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement get_models()")

    def process(self, input_data: GenericListModelsInput) -> GenericListModelsOutput:
        """Return this provider's models wrapped in a GenericListModelsOutput.

        ``input_data`` is accepted for interface compatibility but is not
        consulted by this base implementation.

        Raises:
            ProcessingError: If ``get_models`` fails; the original exception
                is attached as ``__cause__``.
        """
        try:
            models = self.get_models()
            return GenericListModelsOutput(models=models, provider=self.provider)
        except Exception as e:
            raise ProcessingError(
                f"Failed to list {self.provider} models: {str(e)}"
            ) from e
74
+
75
+
76
+ # Groq implementation
77
class GroqListModelsSkill(BaseListModelsSkill):
    """List Groq models from a static, hard-coded catalog."""

    def __init__(self, credentials: Optional[GroqCredentials] = None):
        """Create the skill; credentials are stored but not used by get_models."""
        super().__init__(provider="groq", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the Groq catalog as fresh ``{"id", "display_name"}`` dicts.

        No API call is made.
        """
        # (model id, display name) pairs, taken from the trmx_agent config.
        catalog = (
            ("llama-3.3-70b-versatile", "Llama 3.3 70B Versatile (Tool Use)"),
            ("llama-3.1-8b-instant", "Llama 3.1 8B Instant (Tool Use)"),
            ("mixtral-8x7b-32768", "Mixtral 8x7B (32K) (Tool Use)"),
            ("gemma2-9b-it", "Gemma 2 9B IT (Tool Use)"),
            ("qwen-qwq-32b", "Qwen QWQ 32B (Tool Use)"),
            ("qwen-2.5-coder-32b", "Qwen 2.5 Coder 32B (Tool Use)"),
            ("qwen-2.5-32b", "Qwen 2.5 32B (Tool Use)"),
            (
                "deepseek-r1-distill-qwen-32b",
                "DeepSeek R1 Distill Qwen 32B (Tool Use)",
            ),
            (
                "deepseek-r1-distill-llama-70b",
                "DeepSeek R1 Distill Llama 70B (Tool Use)",
            ),
        )
        return [
            {"id": model_id, "display_name": label} for model_id, label in catalog
        ]
117
+
118
+
119
+ # Cerebras implementation
120
class CerebrasListModelsSkill(BaseListModelsSkill):
    """Skill for listing Cerebras Inference models (static catalog)."""

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        """Initialize the skill with optional credentials (not used by get_models)."""
        super().__init__(provider="cerebras", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return Cerebras model IDs usable with the Cerebras chat API.

        No API call is made.
        """
        # These are Cerebras Inference (cloud API) model IDs, matching the
        # `llama3.1-8b` default used by the Cerebras chat skill. The previous
        # `cerebras/Cerebras-GPT-*` entries were Hugging Face checkpoint names
        # that the chat endpoint does not serve.
        # NOTE(review): static list; keep in sync with Cerebras' model docs.
        return [
            {"id": "llama3.1-8b", "display_name": "Llama 3.1 8B"},
            {"id": "llama-3.3-70b", "display_name": "Llama 3.3 70B"},
        ]
145
+
146
+
147
+ # Sambanova implementation
148
class SambanovaListModelsSkill(BaseListModelsSkill):
    """List SambaNova models from a small static catalog."""

    def __init__(self, credentials: Optional[SambanovaCredentials] = None):
        """Create the skill; credentials are stored but not used by get_models."""
        super().__init__(provider="sambanova", credentials=credentials)

    def get_models(self) -> List[Dict[str, Any]]:
        """Return the hard-coded SambaNova catalog (no API call is made)."""
        # NOTE(review): these IDs look like placeholders rather than SambaNova
        # Cloud API model names -- confirm against the sambanova chat skill.
        names = {"sambanova/samba-1": "Samba-1", "sambanova/samba-2": "Samba-2"}
        return [{"id": key, "display_name": value} for key, value in names.items()]
163
+
164
+
165
+ # Factory class
166
class ListModelsSkillFactory:
    """Build the list-models skill that matches a provider name."""

    # Lower-case provider name -> skill class that lists that provider's models.
    _PROVIDER_MAP = {
        "openai": OpenAIListModelsSkill,
        "anthropic": AnthropicListModelsSkill,
        "together": TogetherListModelsSkill,
        "fireworks": FireworksListModelsSkill,
        "groq": GroqListModelsSkill,
        "cerebras": CerebrasListModelsSkill,
        "sambanova": SambanovaListModelsSkill,
        "perplexity": PerplexityListModelsSkill,
    }

    @classmethod
    def get_skill(cls, provider: str, credentials=None):
        """Instantiate the list-models skill registered for *provider*.

        Args:
            provider (str): Provider name; matched case-insensitively.
            credentials: Optional credentials forwarded to the skill's
                constructor as ``credentials=``.

        Returns:
            A new instance of the matching list-models skill.

        Raises:
            ValueError: If no skill is registered for the provider.
        """
        key = provider.lower()
        skill_class = cls._PROVIDER_MAP.get(key)
        if skill_class is None:
            supported = ", ".join(cls.get_supported_providers())
            raise ValueError(
                f"Unsupported provider: {key}. "
                f"Supported providers are: {supported}"
            )
        return skill_class(credentials=credentials)

    @classmethod
    def get_supported_providers(cls):
        """List the registered provider names in registration order."""
        return [*cls._PROVIDER_MAP]
@@ -0,0 +1,21 @@
1
+ """Fireworks AI integration module"""
2
+
3
+ from .credentials import FireworksCredentials
4
+ from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
5
+ from .list_models import (
6
+ FireworksListModelsSkill,
7
+ FireworksListModelsInput,
8
+ FireworksListModelsOutput,
9
+ )
10
+ from .models import FireworksModel
11
+
12
# Public API of the Fireworks integration package.
__all__ = [
    # Credentials and chat
    "FireworksCredentials",
    "FireworksChatSkill",
    "FireworksInput",
    "FireworksOutput",
    # Model listing
    "FireworksListModelsSkill",
    "FireworksListModelsInput",
    "FireworksListModelsOutput",
    # Model metadata
    "FireworksModel",
]
+ ]