airtrain 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. airtrain/__init__.py +146 -6
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  19. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  21. airtrain/core/credentials.py +62 -44
  22. airtrain/core/skills.py +102 -0
  23. airtrain/integrations/__init__.py +74 -0
  24. airtrain/integrations/anthropic/__init__.py +33 -0
  25. airtrain/integrations/anthropic/credentials.py +32 -0
  26. airtrain/integrations/anthropic/list_models.py +110 -0
  27. airtrain/integrations/anthropic/models_config.py +100 -0
  28. airtrain/integrations/anthropic/skills.py +155 -0
  29. airtrain/integrations/aws/__init__.py +6 -0
  30. airtrain/integrations/aws/credentials.py +36 -0
  31. airtrain/integrations/aws/skills.py +98 -0
  32. airtrain/integrations/cerebras/__init__.py +6 -0
  33. airtrain/integrations/cerebras/credentials.py +19 -0
  34. airtrain/integrations/cerebras/skills.py +127 -0
  35. airtrain/integrations/combined/__init__.py +21 -0
  36. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  37. airtrain/integrations/combined/list_models_factory.py +210 -0
  38. airtrain/integrations/fireworks/__init__.py +21 -0
  39. airtrain/integrations/fireworks/completion_skills.py +147 -0
  40. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  41. airtrain/integrations/fireworks/credentials.py +26 -0
  42. airtrain/integrations/fireworks/list_models.py +128 -0
  43. airtrain/integrations/fireworks/models.py +139 -0
  44. airtrain/integrations/fireworks/requests_skills.py +207 -0
  45. airtrain/integrations/fireworks/skills.py +181 -0
  46. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  47. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  48. airtrain/integrations/fireworks/structured_skills.py +102 -0
  49. airtrain/integrations/google/__init__.py +7 -0
  50. airtrain/integrations/google/credentials.py +58 -0
  51. airtrain/integrations/google/skills.py +122 -0
  52. airtrain/integrations/groq/__init__.py +23 -0
  53. airtrain/integrations/groq/credentials.py +24 -0
  54. airtrain/integrations/groq/models_config.py +162 -0
  55. airtrain/integrations/groq/skills.py +201 -0
  56. airtrain/integrations/ollama/__init__.py +6 -0
  57. airtrain/integrations/ollama/credentials.py +26 -0
  58. airtrain/integrations/ollama/skills.py +41 -0
  59. airtrain/integrations/openai/__init__.py +37 -0
  60. airtrain/integrations/openai/chinese_assistant.py +42 -0
  61. airtrain/integrations/openai/credentials.py +39 -0
  62. airtrain/integrations/openai/list_models.py +112 -0
  63. airtrain/integrations/openai/models_config.py +224 -0
  64. airtrain/integrations/openai/skills.py +342 -0
  65. airtrain/integrations/perplexity/__init__.py +49 -0
  66. airtrain/integrations/perplexity/credentials.py +43 -0
  67. airtrain/integrations/perplexity/list_models.py +112 -0
  68. airtrain/integrations/perplexity/models_config.py +128 -0
  69. airtrain/integrations/perplexity/skills.py +279 -0
  70. airtrain/integrations/sambanova/__init__.py +6 -0
  71. airtrain/integrations/sambanova/credentials.py +20 -0
  72. airtrain/integrations/sambanova/skills.py +129 -0
  73. airtrain/integrations/search/__init__.py +21 -0
  74. airtrain/integrations/search/exa/__init__.py +23 -0
  75. airtrain/integrations/search/exa/credentials.py +30 -0
  76. airtrain/integrations/search/exa/schemas.py +114 -0
  77. airtrain/integrations/search/exa/skills.py +115 -0
  78. airtrain/integrations/together/__init__.py +33 -0
  79. airtrain/integrations/together/audio_models_config.py +34 -0
  80. airtrain/integrations/together/credentials.py +22 -0
  81. airtrain/integrations/together/embedding_models_config.py +92 -0
  82. airtrain/integrations/together/image_models_config.py +69 -0
  83. airtrain/integrations/together/image_skill.py +143 -0
  84. airtrain/integrations/together/list_models.py +76 -0
  85. airtrain/integrations/together/models.py +95 -0
  86. airtrain/integrations/together/models_config.py +399 -0
  87. airtrain/integrations/together/rerank_models_config.py +43 -0
  88. airtrain/integrations/together/rerank_skill.py +49 -0
  89. airtrain/integrations/together/schemas.py +33 -0
  90. airtrain/integrations/together/skills.py +305 -0
  91. airtrain/integrations/together/vision_models_config.py +49 -0
  92. airtrain/telemetry/__init__.py +38 -0
  93. airtrain/telemetry/service.py +167 -0
  94. airtrain/telemetry/views.py +237 -0
  95. airtrain/tools/__init__.py +45 -0
  96. airtrain/tools/command.py +398 -0
  97. airtrain/tools/filesystem.py +166 -0
  98. airtrain/tools/network.py +111 -0
  99. airtrain/tools/registry.py +320 -0
  100. airtrain/tools/search.py +450 -0
  101. airtrain/tools/testing.py +135 -0
  102. airtrain-0.1.4.dist-info/METADATA +222 -0
  103. airtrain-0.1.4.dist-info/RECORD +108 -0
  104. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  105. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  106. airtrain-0.1.3.dist-info/METADATA +0 -106
  107. airtrain-0.1.3.dist-info/RECORD +0 -9
  108. {airtrain-0.1.3.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,162 @@
1
+ """Configuration of Groq model capabilities."""
2
+
3
+ from typing import Dict, Any
4
+
5
+
6
# Model configuration with capabilities for each model.
# Every entry maps a Groq model ID to a dict with the keys
# name / context_window / max_completion_tokens / tool_use /
# parallel_tool_use / json_mode (in that order).
def _groq_model(
    name: str,
    context_window: int,
    max_completion_tokens: int,
    *,
    tool_use: bool,
    parallel_tool_use: bool,
    json_mode: bool,
) -> Dict[str, Any]:
    """Build a single capability entry for ``GROQ_MODELS_CONFIG``."""
    return {
        "name": name,
        "context_window": context_window,
        "max_completion_tokens": max_completion_tokens,
        "tool_use": tool_use,
        "parallel_tool_use": parallel_tool_use,
        "json_mode": json_mode,
    }


GROQ_MODELS_CONFIG = {
    "llama-3.3-70b-versatile": _groq_model(
        "Llama 3.3 70B Versatile", 128000, 32768,
        tool_use=True, parallel_tool_use=True, json_mode=True,
    ),
    "llama-3.1-8b-instant": _groq_model(
        "Llama 3.1 8B Instant", 128000, 8192,
        tool_use=True, parallel_tool_use=True, json_mode=True,
    ),
    "mixtral-8x7b-32768": _groq_model(
        "Mixtral 8x7B (32K)", 32768, 8192,
        tool_use=True, parallel_tool_use=False, json_mode=True,
    ),
    "gemma2-9b-it": _groq_model(
        "Gemma 2 9B IT", 8192, 4096,
        tool_use=True, parallel_tool_use=False, json_mode=True,
    ),
    "qwen-qwq-32b": _groq_model(
        "Qwen QWQ 32B", 128000, 16384,
        tool_use=True, parallel_tool_use=True, json_mode=True,
    ),
    "qwen-2.5-coder-32b": _groq_model(
        "Qwen 2.5 Coder 32B", 128000, 16384,
        tool_use=True, parallel_tool_use=True, json_mode=True,
    ),
    "qwen-2.5-32b": _groq_model(
        "Qwen 2.5 32B", 128000, 16384,
        tool_use=True, parallel_tool_use=True, json_mode=True,
    ),
    "deepseek-r1-distill-qwen-32b": _groq_model(
        "DeepSeek R1 Distill Qwen 32B", 128000, 16384,
        tool_use=True, parallel_tool_use=True, json_mode=True,
    ),
    "deepseek-r1-distill-llama-70b": _groq_model(
        "DeepSeek R1 Distill Llama 70B", 128000, 16384,
        tool_use=True, parallel_tool_use=True, json_mode=True,
    ),
    "deepseek-r1-distill-llama-70b-specdec": _groq_model(
        "DeepSeek R1 Distill Llama 70B SpecDec", 128000, 16384,
        tool_use=False, parallel_tool_use=False, json_mode=False,
    ),
    "llama3-70b-8192": _groq_model(
        "Llama 3 70B (8K)", 8192, 4096,
        tool_use=False, parallel_tool_use=False, json_mode=False,
    ),
    "llama3-8b-8192": _groq_model(
        "Llama 3 8B (8K)", 8192, 4096,
        tool_use=False, parallel_tool_use=False, json_mode=False,
    ),
}
105
+
106
+
107
def _normalize_model_id(model_id: str) -> str:
    """Lower-case ``model_id`` and drop hyphens/underscores for fuzzy matching."""
    return model_id.lower().replace("-", "").replace("_", "")


def get_model_config(model_id: str) -> Dict[str, Any]:
    """
    Get the configuration for a specific model.

    Lookup is tried first by exact ID, then case-insensitively while ignoring
    hyphens and underscores (so ``"Llama_3.3_70B_Versatile"`` matches
    ``"llama-3.3-70b-versatile"``).

    Args:
        model_id: The model ID to get configuration for

    Returns:
        A new dict with the model configuration. Unknown models do NOT raise;
        they get a conservative default (4096-token context window, 1024
        completion tokens, no tool use, no parallel tool use, no JSON mode)
        whose ``name`` is ``model_id``.
    """
    config = GROQ_MODELS_CONFIG.get(model_id)

    if config is None:
        # Try to find a match with different format or case.
        wanted = _normalize_model_id(model_id)
        config = next(
            (
                candidate
                for candidate_id, candidate in GROQ_MODELS_CONFIG.items()
                if _normalize_model_id(candidate_id) == wanted
            ),
            None,
        )

    if config is None:
        # Default configuration for unknown models.
        return {
            "name": model_id,
            "context_window": 4096,  # Conservative default
            "max_completion_tokens": 1024,  # Conservative default
            "tool_use": False,
            "parallel_tool_use": False,
            "json_mode": False,
        }

    # Return a copy so callers cannot accidentally mutate the shared
    # module-level GROQ_MODELS_CONFIG table.
    return dict(config)
138
+
139
+
140
_DEFAULT_GROQ_MODEL = "llama-3.3-70b-versatile"


def get_default_model() -> str:
    """Return the Groq model ID to use when the caller does not pick one."""
    return _DEFAULT_GROQ_MODEL
143
+
144
+
145
def supports_tool_use(model_id: str) -> bool:
    """Report whether ``model_id`` is configured with tool-use support.

    Returns False when the looked-up configuration has no ``tool_use`` key.
    """
    model_config = get_model_config(model_id)
    return model_config.get("tool_use", False)
148
+
149
+
150
def supports_parallel_tool_use(model_id: str) -> bool:
    """Report whether ``model_id`` is configured for parallel tool calls.

    Returns False when the looked-up configuration has no
    ``parallel_tool_use`` key.
    """
    model_config = get_model_config(model_id)
    return model_config.get("parallel_tool_use", False)
153
+
154
+
155
def supports_json_mode(model_id: str) -> bool:
    """Report whether ``model_id`` is configured with JSON-mode support.

    Returns False when the looked-up configuration has no ``json_mode`` key.
    """
    model_config = get_model_config(model_id)
    return model_config.get("json_mode", False)
158
+
159
+
160
def get_max_completion_tokens(model_id: str) -> int:
    """Return the completion-token ceiling configured for ``model_id``.

    Falls back to 1024 when the looked-up configuration has no
    ``max_completion_tokens`` key.
    """
    model_config = get_model_config(model_id)
    return model_config.get("max_completion_tokens", 1024)
@@ -0,0 +1,201 @@
1
+ from typing import Generator, Optional, Dict, Any, List, Union
2
+ from pydantic import Field, validator
3
+ from airtrain.core.skills import Skill, ProcessingError
4
+ from airtrain.core.schemas import InputSchema, OutputSchema
5
+ from .credentials import GroqCredentials
6
+ from .models_config import get_max_completion_tokens, get_model_config
7
+ from groq import Groq
8
+
9
+
10
class GroqInput(InputSchema):
    """Input schema for the Groq chat skill.

    ``max_tokens`` is clamped (not rejected) to the completion-token limit
    that ``get_max_completion_tokens`` reports for the selected ``model``.
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description=(
            "List of previous conversation messages in [{'role': "
            "'user|assistant', 'content': 'message'}] format"
        ),
    )
    # NOTE: `model` must stay declared before `max_tokens`; the validator
    # below reads the already-validated model from `values`.
    model: str = Field(default="llama-3.3-70b-versatile", description="Groq model to use")
    max_tokens: int = Field(default=4096, description="Maximum tokens in response")
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    stream: bool = Field(
        default=False, description="Whether to stream the response progressively"
    )
    tools: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="A list of tools the model may use. Currently only functions supported.",
    )
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
        default=None,
        description="Controls which tool is called by the model. 'none', 'auto', or specific tool.",
    )

    @validator('max_tokens')
    def validate_max_tokens(cls, v, values):
        """Clamp ``max_tokens`` to the selected model's completion limit.

        When ``model`` failed validation (absent from ``values``) the value is
        returned unchanged.
        """
        if 'model' not in values:
            return v
        return min(v, get_max_completion_tokens(values['model']))
69
+
70
+
71
class GroqOutput(OutputSchema):
    """Output schema for the Groq chat skill."""

    response: str = Field(
        ...,
        description="Model's response text",
    )
    used_model: str = Field(
        ...,
        description="Model used for generation",
    )
    # Token counts; left empty when the response was streamed.
    usage: Dict[str, Any] = Field(
        default_factory=dict,
        description="Usage statistics from the API",
    )
    # OpenAI-style tool-call dicts (id / type / function{name, arguments}).
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Tool calls generated by the model",
    )
82
+
83
+
84
class GroqChatSkill(Skill[GroqInput, GroqOutput]):
    """Skill for chatting with Groq-hosted models.

    Supports single-shot completions (optionally with tool calling) and
    progressive streaming of the response text.
    """

    input_schema = GroqInput
    output_schema = GroqOutput

    def __init__(self, credentials: Optional[GroqCredentials] = None):
        """Create the skill and its Groq client.

        Args:
            credentials: Groq credentials; when omitted they are loaded via
                ``GroqCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or GroqCredentials.from_env()
        self.client = Groq(api_key=self.credentials.groq_api_key.get_secret_value())

    def _build_messages(self, input_data: GroqInput) -> List[Dict[str, str]]:
        """
        Build messages list from input data including conversation history.

        Order: the system prompt, then any ``conversation_history`` entries
        exactly as given, then the current ``user_input`` as a user message.

        Args:
            input_data: The input data containing system prompt, conversation history, and user input

        Returns:
            List[Dict[str, str]]: List of messages in the format required by Groq
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]

        # Add conversation history if present
        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        # Add current user input
        messages.append({"role": "user", "content": input_data.user_input})

        return messages

    def _build_api_params(self, input_data: GroqInput) -> Dict[str, Any]:
        """Assemble keyword arguments for a non-streaming completion request."""
        params: Dict[str, Any] = {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "stream": False,
        }
        # Only forward tool settings that were actually provided.
        if input_data.tools:
            params["tools"] = input_data.tools
        if input_data.tool_choice:
            params["tool_choice"] = input_data.tool_choice
        return params

    @staticmethod
    def _extract_usage(completion: Any) -> Dict[str, Any]:
        """Return token usage from a completion, or {} if the API omitted it."""
        usage = getattr(completion, "usage", None)
        if usage is None:
            return {}
        return {
            "total_tokens": usage.total_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "completion_tokens": usage.completion_tokens,
        }

    @staticmethod
    def _extract_tool_calls(message: Any) -> Optional[List[Dict[str, Any]]]:
        """Convert the message's tool calls to plain dicts (None if there are none)."""
        raw_calls = getattr(message, "tool_calls", None)
        if not raw_calls:
            return None
        return [
            {
                "id": tool_call.id,
                "type": tool_call.type,
                "function": {
                    "name": tool_call.function.name,
                    "arguments": tool_call.function.arguments,
                },
            }
            for tool_call in raw_calls
        ]

    def process_stream(self, input_data: GroqInput) -> Generator[str, None, None]:
        """Process the input and stream the response token by token.

        NOTE: ``tools`` / ``tool_choice`` are not forwarded in streaming mode.

        Yields:
            Non-empty text deltas as they arrive.

        Raises:
            ProcessingError: If the request or the stream fails.
        """
        try:
            stream = self.client.chat.completions.create(
                model=input_data.model,
                messages=self._build_messages(input_data),
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=True,
            )

            for chunk in stream:
                # Defensive: skip chunks that carry no choices instead of
                # failing with IndexError.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if content is not None:
                    yield content

        except Exception as e:
            raise ProcessingError(f"Groq streaming failed: {str(e)}") from e

    def process(self, input_data: GroqInput) -> GroqOutput:
        """Process the input and return the complete response.

        When ``input_data.stream`` is set the streamed deltas are joined; usage
        statistics and tool calls are then unavailable (``{}`` / ``None``).

        Raises:
            ProcessingError: If the Groq request fails.
        """
        try:
            if input_data.stream:
                response = "".join(self.process_stream(input_data))
                usage: Dict[str, Any] = {}  # Usage stats not available in streaming
                tool_calls = None  # Tool calls not available in streaming
            else:
                completion = self.client.chat.completions.create(
                    **self._build_api_params(input_data)
                )
                message = completion.choices[0].message
                response = message.content or ""
                usage = self._extract_usage(completion)
                tool_calls = self._extract_tool_calls(message)

            return GroqOutput(
                response=response,
                used_model=input_data.model,
                usage=usage,
                tool_calls=tool_calls,
            )

        except ProcessingError:
            # Already wrapped (e.g. by process_stream); don't double the prefix.
            raise
        except Exception as e:
            raise ProcessingError(f"Groq processing failed: {str(e)}") from e
@@ -0,0 +1,6 @@
1
+ """Ollama integration module"""
2
+
3
+ from .credentials import OllamaCredentials
4
+ from .skills import OllamaChatSkill
5
+
6
+ __all__ = ["OllamaCredentials", "OllamaChatSkill"]
@@ -0,0 +1,26 @@
1
+ from pydantic import Field
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+ from importlib.util import find_spec
4
+
5
+
6
class OllamaCredentials(BaseCredentials):
    """Connection settings for an Ollama server."""

    host: str = Field(default="http://localhost:11434", description="Ollama host URL")
    timeout: int = Field(default=30, description="Request timeout in seconds")

    async def validate_credentials(self) -> bool:
        """Validate Ollama credentials by listing models on the configured host.

        Returns:
            True when the server answered the model-list request.

        Raises:
            CredentialValidationError: If the ``ollama`` package is not
                installed or the server request fails.
        """
        if find_spec("ollama") is None:
            raise CredentialValidationError(
                "Ollama package is not installed. Please install it using: pip install ollama"
            )

        try:
            # Use the async client: the synchronous ``Client.list()`` returns a
            # plain value, so awaiting it raised TypeError and validation
            # always failed. Also honour the configured ``timeout``.
            from ollama import AsyncClient

            client = AsyncClient(host=self.host, timeout=self.timeout)
            await client.list()
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Ollama connection: {str(e)}"
            ) from e
@@ -0,0 +1,41 @@
1
+ from typing import Optional, Dict, Any
2
+ from pydantic import Field
3
+ from airtrain.core.skills import Skill, ProcessingError
4
+ from airtrain.core.schemas import InputSchema, OutputSchema
5
+ from .credentials import OllamaCredentials
6
+
7
+
8
class OllamaInput(InputSchema):
    """Input schema for the (not yet implemented) Ollama chat skill."""

    user_input: str = Field(
        ...,
        description="User's input text",
    )
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    model: str = Field(
        default="llama2",
        description="Ollama model to use",
    )
    max_tokens: int = Field(
        default=131072,
        description="Maximum tokens in response",
    )
    # Sampling temperature, constrained to the closed interval [0, 1].
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
21
+
22
+
23
class OllamaOutput(OutputSchema):
    """Output schema for the (not yet implemented) Ollama chat skill."""

    response: str = Field(
        ...,
        description="Model's response text",
    )
    used_model: str = Field(
        ...,
        description="Model used for generation",
    )
    usage: Dict[str, Any] = Field(
        default_factory=dict,
        description="Usage statistics",
    )
29
+
30
+
31
class OllamaChatSkill(Skill[OllamaInput, OllamaOutput]):
    """Placeholder skill for Ollama chat.

    Not implemented: both constructing the skill and calling ``process``
    raise ``NotImplementedError``.
    """

    input_schema = OllamaInput
    output_schema = OllamaOutput

    @staticmethod
    def _not_implemented_error() -> NotImplementedError:
        """Build the error raised by every entry point of this placeholder."""
        return NotImplementedError("OllamaChatSkill is not implemented yet")

    def __init__(self, credentials: Optional[OllamaCredentials] = None):
        raise self._not_implemented_error()

    def process(self, input_data: OllamaInput) -> OllamaOutput:
        raise self._not_implemented_error()
@@ -0,0 +1,37 @@
1
+ """OpenAI API integration."""
2
+
3
+ from .skills import (
4
+ OpenAIChatSkill,
5
+ OpenAIInput,
6
+ OpenAIParserSkill,
7
+ OpenAIOutput,
8
+ OpenAIParserInput,
9
+ OpenAIParserOutput,
10
+ OpenAIEmbeddingsSkill,
11
+ OpenAIEmbeddingsInput,
12
+ OpenAIEmbeddingsOutput,
13
+ )
14
+ from .credentials import OpenAICredentials
15
+ from .list_models import (
16
+ OpenAIListModelsSkill,
17
+ OpenAIListModelsInput,
18
+ OpenAIListModelsOutput,
19
+ OpenAIModel,
20
+ )
21
+
22
+ __all__ = [
23
+ "OpenAIChatSkill",
24
+ "OpenAIInput",
25
+ "OpenAIParserSkill",
26
+ "OpenAIParserInput",
27
+ "OpenAIParserOutput",
28
+ "OpenAICredentials",
29
+ "OpenAIOutput",
30
+ "OpenAIEmbeddingsSkill",
31
+ "OpenAIEmbeddingsInput",
32
+ "OpenAIEmbeddingsOutput",
33
+ "OpenAIListModelsSkill",
34
+ "OpenAIListModelsInput",
35
+ "OpenAIListModelsOutput",
36
+ "OpenAIModel",
37
+ ]
@@ -0,0 +1,42 @@
1
+ from typing import Optional, TypeVar
2
+ from pydantic import Field
3
+ from .skills import OpenAIChatSkill, OpenAIInput, OpenAIOutput
4
+ from .credentials import OpenAICredentials
5
+
6
# Any OpenAIInput (or subclass such as ChineseAssistantInput) accepted by
# ChineseAssistantSkill.process.
T = TypeVar("T", bound=OpenAIInput)


class ChineseAssistantInput(OpenAIInput):
    """Input schema for the Chinese-language assistant.

    The default system prompt (in Chinese) reads roughly: "You are a helpful
    assistant. Answer every question in Chinese, even if it was asked in
    another language. Answers should be accurate, polite and professional."
    """

    user_input: str = Field(
        ...,
        description="User's input text (can be in any language)",
    )
    system_prompt: str = Field(
        default="你是一个有帮助的助手。请用中文回答所有问题,即使问题是用其他语言问的。回答要准确、礼貌、专业。",
        description="System prompt in Chinese",
    )
    model: str = Field(
        default="gpt-4o",
        description="OpenAI model to use",
    )
    # NOTE(review): 131072 may exceed the selected model's output-token
    # limit — confirm whether OpenAIInput clamps it.
    max_tokens: int = Field(
        default=131072,
        description="Maximum tokens in response",
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
24
+
25
+
26
class ChineseAssistantSkill(OpenAIChatSkill):
    """Chat skill that instructs the model to answer in Chinese."""

    input_schema = ChineseAssistantInput
    output_schema = OpenAIOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Create the skill; ``credentials`` is passed through to OpenAIChatSkill."""
        super().__init__(credentials)

    def process(self, input_data: T) -> OpenAIOutput:
        """Run the chat, ensuring the system prompt asks for a Chinese answer.

        If the system prompt does not contain "你是" ("you are"), it is wrapped
        with "你是一个中文助手。" ("You are a Chinese assistant.") and
        "请用中文回答。" ("Please answer in Chinese."). The caller's
        ``input_data`` is left unmodified; a copy carries the wrapped prompt.
        """
        # Add language check to ensure response is in Chinese
        if "你是" not in input_data.system_prompt:
            wrapped_prompt = (
                "你是一个中文助手。" + input_data.system_prompt + "请用中文回答。"
            )
            # Copy instead of assigning in place so the caller's object is not
            # mutated. Prefer pydantic v2's ``model_copy``; fall back to v1 ``copy``.
            make_copy = getattr(input_data, "model_copy", None) or input_data.copy
            input_data = make_copy(update={"system_prompt": wrapped_prompt})

        return super().process(input_data)
@@ -0,0 +1,39 @@
1
+ from datetime import datetime, timedelta
2
+ from typing import Optional
3
+ from pydantic import Field, SecretStr, validator
4
+ from openai import OpenAI
5
+
6
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
7
+
8
+
9
class OpenAICredentials(BaseCredentials):
    """OpenAI API credentials with format validation and a live API check."""

    openai_api_key: SecretStr = Field(..., description="OpenAI API key")
    openai_organization_id: Optional[str] = Field(
        None, description="OpenAI organization ID", pattern="^org-[A-Za-z0-9]{24}$"
    )

    # Names of the fields that must be set (presumably consumed by
    # BaseCredentials — TODO confirm).
    _required_credentials = {"openai_api_key"}

    @validator("openai_api_key")
    def validate_api_key_format(cls, v: SecretStr) -> SecretStr:
        """Reject keys that do not look like OpenAI secret keys.

        Raises:
            ValueError: If the key lacks the ``sk-`` prefix or is shorter
                than 40 characters.
        """
        key = v.get_secret_value()
        if not key.startswith("sk-"):
            raise ValueError("OpenAI API key must start with 'sk-'")
        if len(key) < 40:
            raise ValueError("OpenAI API key appears to be too short")
        return v

    async def validate_credentials(self) -> bool:
        """Validate credentials by making a test API call.

        Returns:
            True when listing models succeeds.

        Raises:
            CredentialValidationError: If the API call fails for any reason.
        """
        try:
            # The synchronous ``OpenAI`` client cannot be awaited, and
            # ``models.list()`` takes no ``limit`` argument; both made the
            # original check always fail. Use the async client and close it.
            from openai import AsyncOpenAI

            async with AsyncOpenAI(
                api_key=self.openai_api_key.get_secret_value(),
                organization=self.openai_organization_id,
            ) as client:
                # Make minimal API call to validate
                await client.models.list()
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid OpenAI credentials: {str(e)}"
            ) from e
@@ -0,0 +1,112 @@
1
+ from typing import Optional, List, Dict, Any
2
+ from pydantic import Field
3
+
4
+ from airtrain.core.skills import Skill, ProcessingError
5
+ from airtrain.core.schemas import InputSchema, OutputSchema
6
+ from .credentials import OpenAICredentials
7
+ from .models_config import OPENAI_MODELS, OpenAIModelConfig
8
+
9
+
10
class OpenAIModel:
    """Lightweight view of one entry of the local OpenAI model configuration."""

    def __init__(self, model_id: str, config: OpenAIModelConfig):
        """Store ``model_id`` plus the naming and pricing fields of ``config``."""
        self.id = model_id
        self.display_name = config.display_name
        self.base_model = config.base_model
        self.input_price = config.input_price
        self.cached_input_price = config.cached_input_price
        self.output_price = config.output_price

    def dict(self, exclude_none=False):
        """Return a plain-dict representation with prices converted to float.

        ``cached_input_price`` is included as ``None`` when unset, unless
        ``exclude_none`` is true, in which case the key is omitted.
        """
        result = {
            "id": self.id,
            "display_name": self.display_name,
            "base_model": self.base_model,
        }
        result.update(
            input_price=float(self.input_price),
            output_price=float(self.output_price),
        )

        cached_price = self.cached_input_price
        if cached_price is None and exclude_none:
            return result
        result["cached_input_price"] = None if cached_price is None else float(cached_price)
        return result
36
+
37
+
38
class OpenAIListModelsInput(InputSchema):
    """Input schema for listing OpenAI models."""

    # True: query the live API (credentials required); False: local config.
    api_models_only: bool = Field(
        default=False,
        description="If True, fetch models from the API only. If False, use local config.",
    )
47
+
48
+
49
class OpenAIListModelsOutput(OutputSchema):
    """Output schema holding the listed OpenAI models as plain dicts."""

    models: List[Dict[str, Any]] = Field(
        default_factory=list, description="List of OpenAI models"
    )
56
+
57
+
58
class OpenAIListModelsSkill(Skill[OpenAIListModelsInput, OpenAIListModelsOutput]):
    """Skill for listing OpenAI models from the local config or the live API."""

    input_schema = OpenAIListModelsInput
    output_schema = OpenAIListModelsOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill with optional credentials.

        Credentials are only needed when listing models from the API.
        """
        super().__init__()
        self.credentials = credentials

    def _list_api_models(self) -> List[Dict[str, Any]]:
        """Fetch the model list from the OpenAI API (requires credentials)."""
        if not self.credentials:
            raise ProcessingError("OpenAI credentials required for API models")

        from openai import OpenAI

        client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )
        response = client.models.list()

        # The API exposes neither display_name/base_model nor pricing, so the
        # model ID stands in for the former and pricing is omitted.
        return [
            {
                "id": model.id,
                "display_name": model.id,
                "base_model": model.id,
                "created": model.created,
                "owned_by": model.owned_by,
            }
            for model in response.data
        ]

    def process(
        self, input_data: OpenAIListModelsInput
    ) -> OpenAIListModelsOutput:
        """Process the input and return a list of models.

        Raises:
            ProcessingError: If credentials are missing for an API listing, or
                the listing fails for any other reason.
        """
        try:
            if input_data.api_models_only:
                models = self._list_api_models()
            else:
                # Use local model config - no credentials needed
                models = [
                    OpenAIModel(model_id, config).dict()
                    for model_id, config in OPENAI_MODELS.items()
                ]
            return OpenAIListModelsOutput(models=models)

        except ProcessingError:
            # Already descriptive (e.g. missing credentials); don't re-wrap it.
            raise
        except Exception as e:
            raise ProcessingError(f"Failed to list OpenAI models: {str(e)}") from e
+ raise ProcessingError(f"Failed to list OpenAI models: {str(e)}")