airtrain-0.1.2-py3-none-any.whl → airtrain-0.1.4-py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
Files changed (110)
  1. airtrain/__init__.py +148 -2
  2. airtrain/__main__.py +4 -0
  3. airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
  4. airtrain/agents/__init__.py +45 -0
  5. airtrain/agents/example_agent.py +348 -0
  6. airtrain/agents/groq_agent.py +289 -0
  7. airtrain/agents/memory.py +663 -0
  8. airtrain/agents/registry.py +465 -0
  9. airtrain/builder/__init__.py +3 -0
  10. airtrain/builder/agent_builder.py +122 -0
  11. airtrain/cli/__init__.py +0 -0
  12. airtrain/cli/builder.py +23 -0
  13. airtrain/cli/main.py +120 -0
  14. airtrain/contrib/__init__.py +29 -0
  15. airtrain/contrib/travel/__init__.py +35 -0
  16. airtrain/contrib/travel/agents.py +243 -0
  17. airtrain/contrib/travel/models.py +59 -0
  18. airtrain/core/__init__.py +7 -0
  19. airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
  20. airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
  21. airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
  22. airtrain/core/credentials.py +171 -0
  23. airtrain/core/schemas.py +237 -0
  24. airtrain/core/skills.py +269 -0
  25. airtrain/integrations/__init__.py +74 -0
  26. airtrain/integrations/anthropic/__init__.py +33 -0
  27. airtrain/integrations/anthropic/credentials.py +32 -0
  28. airtrain/integrations/anthropic/list_models.py +110 -0
  29. airtrain/integrations/anthropic/models_config.py +100 -0
  30. airtrain/integrations/anthropic/skills.py +155 -0
  31. airtrain/integrations/aws/__init__.py +6 -0
  32. airtrain/integrations/aws/credentials.py +36 -0
  33. airtrain/integrations/aws/skills.py +98 -0
  34. airtrain/integrations/cerebras/__init__.py +6 -0
  35. airtrain/integrations/cerebras/credentials.py +19 -0
  36. airtrain/integrations/cerebras/skills.py +127 -0
  37. airtrain/integrations/combined/__init__.py +21 -0
  38. airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  39. airtrain/integrations/combined/list_models_factory.py +210 -0
  40. airtrain/integrations/fireworks/__init__.py +21 -0
  41. airtrain/integrations/fireworks/completion_skills.py +147 -0
  42. airtrain/integrations/fireworks/conversation_manager.py +109 -0
  43. airtrain/integrations/fireworks/credentials.py +26 -0
  44. airtrain/integrations/fireworks/list_models.py +128 -0
  45. airtrain/integrations/fireworks/models.py +139 -0
  46. airtrain/integrations/fireworks/requests_skills.py +207 -0
  47. airtrain/integrations/fireworks/skills.py +181 -0
  48. airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
  49. airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
  50. airtrain/integrations/fireworks/structured_skills.py +102 -0
  51. airtrain/integrations/google/__init__.py +7 -0
  52. airtrain/integrations/google/credentials.py +58 -0
  53. airtrain/integrations/google/skills.py +122 -0
  54. airtrain/integrations/groq/__init__.py +23 -0
  55. airtrain/integrations/groq/credentials.py +24 -0
  56. airtrain/integrations/groq/models_config.py +162 -0
  57. airtrain/integrations/groq/skills.py +201 -0
  58. airtrain/integrations/ollama/__init__.py +6 -0
  59. airtrain/integrations/ollama/credentials.py +26 -0
  60. airtrain/integrations/ollama/skills.py +41 -0
  61. airtrain/integrations/openai/__init__.py +37 -0
  62. airtrain/integrations/openai/chinese_assistant.py +42 -0
  63. airtrain/integrations/openai/credentials.py +39 -0
  64. airtrain/integrations/openai/list_models.py +112 -0
  65. airtrain/integrations/openai/models_config.py +224 -0
  66. airtrain/integrations/openai/skills.py +342 -0
  67. airtrain/integrations/perplexity/__init__.py +49 -0
  68. airtrain/integrations/perplexity/credentials.py +43 -0
  69. airtrain/integrations/perplexity/list_models.py +112 -0
  70. airtrain/integrations/perplexity/models_config.py +128 -0
  71. airtrain/integrations/perplexity/skills.py +279 -0
  72. airtrain/integrations/sambanova/__init__.py +6 -0
  73. airtrain/integrations/sambanova/credentials.py +20 -0
  74. airtrain/integrations/sambanova/skills.py +129 -0
  75. airtrain/integrations/search/__init__.py +21 -0
  76. airtrain/integrations/search/exa/__init__.py +23 -0
  77. airtrain/integrations/search/exa/credentials.py +30 -0
  78. airtrain/integrations/search/exa/schemas.py +114 -0
  79. airtrain/integrations/search/exa/skills.py +115 -0
  80. airtrain/integrations/together/__init__.py +33 -0
  81. airtrain/integrations/together/audio_models_config.py +34 -0
  82. airtrain/integrations/together/credentials.py +22 -0
  83. airtrain/integrations/together/embedding_models_config.py +92 -0
  84. airtrain/integrations/together/image_models_config.py +69 -0
  85. airtrain/integrations/together/image_skill.py +143 -0
  86. airtrain/integrations/together/list_models.py +76 -0
  87. airtrain/integrations/together/models.py +95 -0
  88. airtrain/integrations/together/models_config.py +399 -0
  89. airtrain/integrations/together/rerank_models_config.py +43 -0
  90. airtrain/integrations/together/rerank_skill.py +49 -0
  91. airtrain/integrations/together/schemas.py +33 -0
  92. airtrain/integrations/together/skills.py +305 -0
  93. airtrain/integrations/together/vision_models_config.py +49 -0
  94. airtrain/telemetry/__init__.py +38 -0
  95. airtrain/telemetry/service.py +167 -0
  96. airtrain/telemetry/views.py +237 -0
  97. airtrain/tools/__init__.py +45 -0
  98. airtrain/tools/command.py +398 -0
  99. airtrain/tools/filesystem.py +166 -0
  100. airtrain/tools/network.py +111 -0
  101. airtrain/tools/registry.py +320 -0
  102. airtrain/tools/search.py +450 -0
  103. airtrain/tools/testing.py +135 -0
  104. airtrain-0.1.4.dist-info/METADATA +222 -0
  105. airtrain-0.1.4.dist-info/RECORD +108 -0
  106. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
  107. airtrain-0.1.4.dist-info/entry_points.txt +2 -0
  108. airtrain-0.1.2.dist-info/METADATA +0 -106
  109. airtrain-0.1.2.dist-info/RECORD +0 -5
  110. {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/integrations/google/skills.py
@@ -0,0 +1,122 @@
+ from typing import List, Optional, Dict, Any
+ from pydantic import Field
+ import google.generativeai as genai
+ from loguru import logger
+
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import GeminiCredentials
+
+
+ class GoogleGenerationConfig(InputSchema):
+     """Schema for Google generation config"""
+
+     temperature: float = Field(
+         default=1.0, description="Temperature for response generation", ge=0, le=1
+     )
+     top_p: float = Field(
+         default=0.95, description="Top p sampling parameter", ge=0, le=1
+     )
+     top_k: int = Field(default=40, description="Top k sampling parameter")
+     max_output_tokens: int = Field(
+         default=8192, description="Maximum tokens in response"
+     )
+     response_mime_type: str = Field(
+         default="text/plain", description="Response MIME type"
+     )
+
+
+ class GoogleInput(InputSchema):
+     """Schema for Google chat input"""
+
+     user_input: str = Field(..., description="User's input text")
+     system_prompt: str = Field(
+         default="You are a helpful assistant.",
+         description="System prompt to guide the model's behavior",
+     )
+     conversation_history: List[Dict[str, str | List[Dict[str, str]]]] = Field(
+         default_factory=list,
+         description="List of conversation messages in Google's format",
+     )
+     model: str = Field(default="gemini-1.5-flash", description="Google model to use")
+     generation_config: GoogleGenerationConfig = Field(
+         default_factory=GoogleGenerationConfig, description="Generation configuration"
+     )
+
+
+ class GoogleOutput(OutputSchema):
+     """Schema for Google chat output"""
+
+     response: str = Field(..., description="Model's response text")
+     used_model: str = Field(..., description="Model used for generation")
+     usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
+
+
+ class GoogleChatSkill(Skill[GoogleInput, GoogleOutput]):
+     """Skill for Google chat"""
+
+     input_schema = GoogleInput
+     output_schema = GoogleOutput
+
+     def __init__(self, credentials: Optional[GeminiCredentials] = None):
+         super().__init__()
+         self.credentials = credentials or GeminiCredentials.from_env()
+         genai.configure(api_key=self.credentials.gemini_api_key.get_secret_value())
+
+     def _convert_history_format(
+         self, history: List[Dict[str, str]]
+     ) -> List[Dict[str, List[Dict[str, str]]]]:
+         """Convert standard history format to Google's format"""
+         google_history = []
+         for msg in history:
+             google_msg = {
+                 "role": "user" if msg["role"] == "user" else "model",
+                 "parts": [{"text": msg["content"]}],
+             }
+             google_history.append(google_msg)
+         return google_history
+
+     def process(self, input_data: GoogleInput) -> GoogleOutput:
+         try:
+             # Create generation config
+             generation_config = {
+                 "temperature": input_data.generation_config.temperature,
+                 "top_p": input_data.generation_config.top_p,
+                 "top_k": input_data.generation_config.top_k,
+                 "max_output_tokens": input_data.generation_config.max_output_tokens,
+                 "response_mime_type": input_data.generation_config.response_mime_type,
+             }
+
+             # Initialize model
+             model = genai.GenerativeModel(
+                 model_name=input_data.model,
+                 generation_config=generation_config,
+                 system_instruction=input_data.system_prompt,
+             )
+
+             # Convert history format if needed
+             history = (
+                 input_data.conversation_history
+                 if input_data.conversation_history
+                 else self._convert_history_format([])
+             )
+
+             # Start chat session
+             chat = model.start_chat(history=history)
+
+             # Send message and get response
+             response = chat.send_message(input_data.user_input)
+
+             return GoogleOutput(
+                 response=response.text,
+                 used_model=input_data.model,
+                 usage={
+                     "prompt_tokens": 0,
+                     "completion_tokens": 0,
+                     "total_tokens": 0,
+                 },  # Google API doesn't provide usage stats
+             )
+
+         except Exception as e:
+             logger.exception(f"Google processing failed: {str(e)}")
+             raise ProcessingError(f"Google processing failed: {str(e)}")
airtrain/integrations/groq/__init__.py
@@ -0,0 +1,23 @@
+ """Groq integration module"""
+
+ from .credentials import GroqCredentials
+ from .skills import GroqChatSkill
+ from .models_config import (
+     get_model_config,
+     get_default_model,
+     supports_tool_use,
+     supports_parallel_tool_use,
+     supports_json_mode,
+     GROQ_MODELS_CONFIG,
+ )
+
+ __all__ = [
+     "GroqCredentials",
+     "GroqChatSkill",
+     "get_model_config",
+     "get_default_model",
+     "supports_tool_use",
+     "supports_parallel_tool_use",
+     "supports_json_mode",
+     "GROQ_MODELS_CONFIG",
+ ]
airtrain/integrations/groq/credentials.py
@@ -0,0 +1,24 @@
+ from pydantic import Field, SecretStr
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+ from groq import AsyncGroq
+
+
+ class GroqCredentials(BaseCredentials):
+     """Groq API credentials"""
+
+     groq_api_key: SecretStr = Field(..., description="Groq API key")
+
+     _required_credentials = {"groq_api_key"}
+
+     async def validate_credentials(self) -> bool:
+         """Validate Groq credentials"""
+         try:
+             # The async client is required here: the synchronous Groq
+             # client's create() call is not awaitable.
+             client = AsyncGroq(api_key=self.groq_api_key.get_secret_value())
+             await client.chat.completions.create(
+                 messages=[{"role": "user", "content": "Hi"}],
+                 model="mixtral-8x7b-32768",
+                 max_tokens=1,
+             )
+             return True
+         except Exception as e:
+             raise CredentialValidationError(f"Invalid Groq credentials: {str(e)}")
airtrain/integrations/groq/models_config.py
@@ -0,0 +1,162 @@
+ """Configuration of Groq model capabilities."""
+
+ from typing import Dict, Any
+
+
+ # Model configuration with capabilities for each model
+ GROQ_MODELS_CONFIG = {
+     "llama-3.3-70b-versatile": {
+         "name": "Llama 3.3 70B Versatile",
+         "context_window": 128000,
+         "max_completion_tokens": 32768,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "llama-3.1-8b-instant": {
+         "name": "Llama 3.1 8B Instant",
+         "context_window": 128000,
+         "max_completion_tokens": 8192,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "mixtral-8x7b-32768": {
+         "name": "Mixtral 8x7B (32K)",
+         "context_window": 32768,
+         "max_completion_tokens": 8192,
+         "tool_use": True,
+         "parallel_tool_use": False,
+         "json_mode": True,
+     },
+     "gemma2-9b-it": {
+         "name": "Gemma 2 9B IT",
+         "context_window": 8192,
+         "max_completion_tokens": 4096,
+         "tool_use": True,
+         "parallel_tool_use": False,
+         "json_mode": True,
+     },
+     "qwen-qwq-32b": {
+         "name": "Qwen QWQ 32B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "qwen-2.5-coder-32b": {
+         "name": "Qwen 2.5 Coder 32B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "qwen-2.5-32b": {
+         "name": "Qwen 2.5 32B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-r1-distill-qwen-32b": {
+         "name": "DeepSeek R1 Distill Qwen 32B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-r1-distill-llama-70b": {
+         "name": "DeepSeek R1 Distill Llama 70B",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": True,
+         "parallel_tool_use": True,
+         "json_mode": True,
+     },
+     "deepseek-r1-distill-llama-70b-specdec": {
+         "name": "DeepSeek R1 Distill Llama 70B SpecDec",
+         "context_window": 128000,
+         "max_completion_tokens": 16384,
+         "tool_use": False,
+         "parallel_tool_use": False,
+         "json_mode": False,
+     },
+     "llama3-70b-8192": {
+         "name": "Llama 3 70B (8K)",
+         "context_window": 8192,
+         "max_completion_tokens": 4096,
+         "tool_use": False,
+         "parallel_tool_use": False,
+         "json_mode": False,
+     },
+     "llama3-8b-8192": {
+         "name": "Llama 3 8B (8K)",
+         "context_window": 8192,
+         "max_completion_tokens": 4096,
+         "tool_use": False,
+         "parallel_tool_use": False,
+         "json_mode": False,
+     },
+ }
+
+
+ def get_model_config(model_id: str) -> Dict[str, Any]:
+     """
+     Get the configuration for a specific model.
+
+     Args:
+         model_id: The model ID to get configuration for
+
+     Returns:
+         Dict with model configuration; unknown model IDs fall back to a
+         conservative default configuration rather than raising.
+     """
+     if model_id in GROQ_MODELS_CONFIG:
+         return GROQ_MODELS_CONFIG[model_id]
+
+     # Try to find a match with different format or case
+     normalized_id = model_id.lower().replace("-", "").replace("_", "")
+     for config_id, config in GROQ_MODELS_CONFIG.items():
+         if normalized_id == config_id.lower().replace("-", "").replace("_", ""):
+             return config
+
+     # Default configuration for unknown models
+     return {
+         "name": model_id,
+         "context_window": 4096,  # Conservative default
+         "max_completion_tokens": 1024,  # Conservative default
+         "tool_use": False,
+         "parallel_tool_use": False,
+         "json_mode": False,
+     }
+
+
+ def get_default_model() -> str:
+     """Get the default model ID for Groq."""
+     return "llama-3.3-70b-versatile"
+
+
+ def supports_tool_use(model_id: str) -> bool:
+     """Check if a model supports tool use."""
+     return get_model_config(model_id).get("tool_use", False)
+
+
+ def supports_parallel_tool_use(model_id: str) -> bool:
+     """Check if a model supports parallel tool use."""
+     return get_model_config(model_id).get("parallel_tool_use", False)
+
+
+ def supports_json_mode(model_id: str) -> bool:
+     """Check if a model supports JSON mode."""
+     return get_model_config(model_id).get("json_mode", False)
+
+
+ def get_max_completion_tokens(model_id: str) -> int:
+     """Get the maximum number of completion tokens for a model."""
+     return get_model_config(model_id).get("max_completion_tokens", 1024)
airtrain/integrations/groq/skills.py
@@ -0,0 +1,201 @@
+ from typing import Generator, Optional, Dict, Any, List, Union
+ from pydantic import Field, validator
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import GroqCredentials
+ from .models_config import get_max_completion_tokens, get_model_config
+ from groq import Groq
+
+
+ class GroqInput(InputSchema):
+     """Schema for Groq input"""
+
+     user_input: str = Field(..., description="User's input text")
+     system_prompt: str = Field(
+         default="You are a helpful assistant.",
+         description=(
+             "System prompt to guide the model's behavior"
+         ),
+     )
+     conversation_history: List[Dict[str, str]] = Field(
+         default_factory=list,
+         description=(
+             "List of previous conversation messages in "
+             "[{'role': 'user|assistant', 'content': 'message'}] format"
+         ),
+     )
+     model: str = Field(
+         default="llama-3.3-70b-versatile",
+         description="Groq model to use"
+     )
+     max_tokens: int = Field(
+         default=4096,
+         description="Maximum tokens in response"
+     )
+     temperature: float = Field(
+         default=0.7,
+         description="Temperature for response generation",
+         ge=0,
+         le=1
+     )
+     stream: bool = Field(
+         default=False,
+         description="Whether to stream the response progressively"
+     )
+     tools: Optional[List[Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "A list of tools the model may use. "
+             "Currently only functions supported."
+         ),
+     )
+     tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
+         default=None,
+         description=(
+             "Controls which tool is called by the model. "
+             "'none', 'auto', or specific tool."
+         ),
+     )
+
+     @validator('max_tokens')
+     def validate_max_tokens(cls, v, values):
+         """Validate that max_tokens doesn't exceed the model's limit."""
+         if 'model' in values:
+             model_id = values['model']
+             max_limit = get_max_completion_tokens(model_id)
+             if v > max_limit:
+                 return max_limit
+         return v
+
+
+ class GroqOutput(OutputSchema):
+     """Schema for Groq output"""
+
+     response: str = Field(..., description="Model's response text")
+     used_model: str = Field(..., description="Model used for generation")
+     usage: Dict[str, Any] = Field(
+         default_factory=dict, description="Usage statistics from the API"
+     )
+     tool_calls: Optional[List[Dict[str, Any]]] = Field(
+         default=None, description="Tool calls generated by the model"
+     )
+
+
+ class GroqChatSkill(Skill[GroqInput, GroqOutput]):
+     """Skill for Groq chat"""
+
+     input_schema = GroqInput
+     output_schema = GroqOutput
+
+     def __init__(self, credentials: Optional[GroqCredentials] = None):
+         super().__init__()
+         self.credentials = credentials or GroqCredentials.from_env()
+         self.client = Groq(api_key=self.credentials.groq_api_key.get_secret_value())
+
+     def _build_messages(self, input_data: GroqInput) -> List[Dict[str, str]]:
+         """
+         Build messages list from input data including conversation history.
+
+         Args:
+             input_data: The input data containing system prompt, conversation history, and user input
+
+         Returns:
+             List[Dict[str, str]]: List of messages in the format required by Groq
+         """
+         messages = [{"role": "system", "content": input_data.system_prompt}]
+
+         # Add conversation history if present
+         if input_data.conversation_history:
+             messages.extend(input_data.conversation_history)
+
+         # Add current user input
+         messages.append({"role": "user", "content": input_data.user_input})
+
+         return messages
+
+     def process_stream(self, input_data: GroqInput) -> Generator[str, None, None]:
+         """Process the input and stream the response token by token."""
+         try:
+             messages = self._build_messages(input_data)
+
+             stream = self.client.chat.completions.create(
+                 model=input_data.model,
+                 messages=messages,
+                 temperature=input_data.temperature,
+                 max_tokens=input_data.max_tokens,
+                 stream=True,
+             )
+
+             for chunk in stream:
+                 if chunk.choices[0].delta.content is not None:
+                     yield chunk.choices[0].delta.content
+
+         except Exception as e:
+             raise ProcessingError(f"Groq streaming failed: {str(e)}")
+
+     def process(self, input_data: GroqInput) -> GroqOutput:
+         """Process the input and return the complete response."""
+         try:
+             if input_data.stream:
+                 response_chunks = []
+                 for chunk in self.process_stream(input_data):
+                     response_chunks.append(chunk)
+                 response = "".join(response_chunks)
+                 usage = {}  # Usage stats not available in streaming
+                 tool_calls = None  # Tool calls not available in streaming
+             else:
+                 messages = self._build_messages(input_data)
+
+                 # Prepare API call parameters
+                 api_params = {
+                     "model": input_data.model,
+                     "messages": messages,
+                     "temperature": input_data.temperature,
+                     "max_tokens": input_data.max_tokens,
+                     "stream": False,
+                 }
+
+                 # Add tools and tool_choice if provided
+                 if input_data.tools:
+                     api_params["tools"] = input_data.tools
+
+                 if input_data.tool_choice:
+                     api_params["tool_choice"] = input_data.tool_choice
+
+                 completion = self.client.chat.completions.create(**api_params)
+                 response = completion.choices[0].message.content or ""
+
+                 # Extract usage information
+                 usage = {
+                     "total_tokens": completion.usage.total_tokens,
+                     "prompt_tokens": completion.usage.prompt_tokens,
+                     "completion_tokens": completion.usage.completion_tokens,
+                 }
+
+                 # Check for tool calls in the response
+                 tool_calls = None
+                 if (
+                     hasattr(completion.choices[0].message, "tool_calls")
+                     and completion.choices[0].message.tool_calls
+                 ):
+                     tool_calls = [
+                         {
+                             "id": tool_call.id,
+                             "type": tool_call.type,
+                             "function": {
+                                 "name": tool_call.function.name,
+                                 "arguments": tool_call.function.arguments
+                             }
+                         }
+                         for tool_call in completion.choices[0].message.tool_calls
+                     ]
+
+             return GroqOutput(
+                 response=response,
+                 used_model=input_data.model,
+                 usage=usage,
+                 tool_calls=tool_calls
+             )
+
+         except Exception as e:
+             raise ProcessingError(f"Groq processing failed: {str(e)}")
airtrain/integrations/ollama/__init__.py
@@ -0,0 +1,6 @@
+ """Ollama integration module"""
+
+ from .credentials import OllamaCredentials
+ from .skills import OllamaChatSkill
+
+ __all__ = ["OllamaCredentials", "OllamaChatSkill"]
airtrain/integrations/ollama/credentials.py
@@ -0,0 +1,26 @@
+ from pydantic import Field
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+ from importlib.util import find_spec
+
+
+ class OllamaCredentials(BaseCredentials):
+     """Ollama credentials"""
+
+     host: str = Field(default="http://localhost:11434", description="Ollama host URL")
+     timeout: int = Field(default=30, description="Request timeout in seconds")
+
+     async def validate_credentials(self) -> bool:
+         """Validate Ollama credentials"""
+         if find_spec("ollama") is None:
+             raise CredentialValidationError(
+                 "Ollama package is not installed. Please install it using: pip install ollama"
+             )
+
+         try:
+             # AsyncClient is needed so that list() can be awaited; the
+             # synchronous Client would return a plain dict here.
+             from ollama import AsyncClient
+
+             client = AsyncClient(host=self.host)
+             await client.list()
+             return True
+         except Exception as e:
+             raise CredentialValidationError(f"Invalid Ollama connection: {str(e)}")
airtrain/integrations/ollama/skills.py
@@ -0,0 +1,41 @@
+ from typing import Optional, Dict, Any
+ from pydantic import Field
+ from airtrain.core.skills import Skill, ProcessingError
+ from airtrain.core.schemas import InputSchema, OutputSchema
+ from .credentials import OllamaCredentials
+
+
+ class OllamaInput(InputSchema):
+     """Schema for Ollama input"""
+
+     user_input: str = Field(..., description="User's input text")
+     system_prompt: str = Field(
+         default="You are a helpful assistant.",
+         description="System prompt to guide the model's behavior",
+     )
+     model: str = Field(default="llama2", description="Ollama model to use")
+     max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation", ge=0, le=1
+     )
+
+
+ class OllamaOutput(OutputSchema):
+     """Schema for Ollama output"""
+
+     response: str = Field(..., description="Model's response text")
+     used_model: str = Field(..., description="Model used for generation")
+     usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
+
+
+ class OllamaChatSkill(Skill[OllamaInput, OllamaOutput]):
+     """Skill for Ollama - Not Implemented"""
+
+     input_schema = OllamaInput
+     output_schema = OllamaOutput
+
+     def __init__(self, credentials: Optional[OllamaCredentials] = None):
+         raise NotImplementedError("OllamaChatSkill is not implemented yet")
+
+     def process(self, input_data: OllamaInput) -> OllamaOutput:
+         raise NotImplementedError("OllamaChatSkill is not implemented yet")
airtrain/integrations/openai/__init__.py
@@ -0,0 +1,37 @@
+ """OpenAI API integration."""
+
+ from .skills import (
+     OpenAIChatSkill,
+     OpenAIInput,
+     OpenAIParserSkill,
+     OpenAIOutput,
+     OpenAIParserInput,
+     OpenAIParserOutput,
+     OpenAIEmbeddingsSkill,
+     OpenAIEmbeddingsInput,
+     OpenAIEmbeddingsOutput,
+ )
+ from .credentials import OpenAICredentials
+ from .list_models import (
+     OpenAIListModelsSkill,
+     OpenAIListModelsInput,
+     OpenAIListModelsOutput,
+     OpenAIModel,
+ )
+
+ __all__ = [
+     "OpenAIChatSkill",
+     "OpenAIInput",
+     "OpenAIParserSkill",
+     "OpenAIParserInput",
+     "OpenAIParserOutput",
+     "OpenAICredentials",
+     "OpenAIOutput",
+     "OpenAIEmbeddingsSkill",
+     "OpenAIEmbeddingsInput",
+     "OpenAIEmbeddingsOutput",
+     "OpenAIListModelsSkill",
+     "OpenAIListModelsInput",
+     "OpenAIListModelsOutput",
+     "OpenAIModel",
+ ]
airtrain/integrations/openai/chinese_assistant.py
@@ -0,0 +1,42 @@
+ from typing import Optional, TypeVar
+ from pydantic import Field
+ from .skills import OpenAIChatSkill, OpenAIInput, OpenAIOutput
+ from .credentials import OpenAICredentials
+
+ T = TypeVar("T", bound=OpenAIInput)
+
+
+ class ChineseAssistantInput(OpenAIInput):
+     """Schema for Chinese Assistant input"""
+
+     user_input: str = Field(
+         ..., description="User's input text (can be in any language)"
+     )
+     system_prompt: str = Field(
+         default="你是一个有帮助的助手。请用中文回答所有问题,即使问题是用其他语言问的。回答要准确、礼貌、专业。",
+         description="System prompt in Chinese",
+     )
+     model: str = Field(default="gpt-4o", description="OpenAI model to use")
+     max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+     temperature: float = Field(
+         default=0.7, description="Temperature for response generation", ge=0, le=1
+     )
+
+
+ class ChineseAssistantSkill(OpenAIChatSkill):
+     """Skill for Chinese language assistance"""
+
+     input_schema = ChineseAssistantInput
+     output_schema = OpenAIOutput
+
+     def __init__(self, credentials: Optional[OpenAICredentials] = None):
+         super().__init__(credentials)
+
+     def process(self, input_data: T) -> OpenAIOutput:
+         # Add language check to ensure response is in Chinese
+         if "你是" not in input_data.system_prompt:
+             input_data.system_prompt = (
+                 "你是一个中文助手。" + input_data.system_prompt + "请用中文回答。"
+             )
+
+         return super().process(input_data)
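Finally, a sketch of the Chinese assistant, which rides on OpenAIChatSkill with a Chinese default system prompt. It assumes OPENAI_API_KEY is set for OpenAICredentials.from_env(); max_tokens is overridden here because the class default of 131072 exceeds typical completion limits:

from airtrain.integrations.openai.chinese_assistant import (
    ChineseAssistantInput,
    ChineseAssistantSkill,
)

skill = ChineseAssistantSkill()
result = skill.process(
    ChineseAssistantInput(user_input="What is machine learning?", max_tokens=1024)
)
print(result.response)  # answered in Chinese, per the default system prompt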