airtrain 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff shows the changes between publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
- airtrain/__init__.py +148 -2
- airtrain/__main__.py +4 -0
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/agents/__init__.py +45 -0
- airtrain/agents/example_agent.py +348 -0
- airtrain/agents/groq_agent.py +289 -0
- airtrain/agents/memory.py +663 -0
- airtrain/agents/registry.py +465 -0
- airtrain/builder/__init__.py +3 -0
- airtrain/builder/agent_builder.py +122 -0
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/builder.py +23 -0
- airtrain/cli/main.py +120 -0
- airtrain/contrib/__init__.py +29 -0
- airtrain/contrib/travel/__init__.py +35 -0
- airtrain/contrib/travel/agents.py +243 -0
- airtrain/contrib/travel/models.py +59 -0
- airtrain/core/__init__.py +7 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/core/credentials.py +171 -0
- airtrain/core/schemas.py +237 -0
- airtrain/core/skills.py +269 -0
- airtrain/integrations/__init__.py +74 -0
- airtrain/integrations/anthropic/__init__.py +33 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/list_models.py +110 -0
- airtrain/integrations/anthropic/models_config.py +100 -0
- airtrain/integrations/anthropic/skills.py +155 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +19 -0
- airtrain/integrations/cerebras/skills.py +127 -0
- airtrain/integrations/combined/__init__.py +21 -0
- airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
- airtrain/integrations/combined/list_models_factory.py +210 -0
- airtrain/integrations/fireworks/__init__.py +21 -0
- airtrain/integrations/fireworks/completion_skills.py +147 -0
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/credentials.py +26 -0
- airtrain/integrations/fireworks/list_models.py +128 -0
- airtrain/integrations/fireworks/models.py +139 -0
- airtrain/integrations/fireworks/requests_skills.py +207 -0
- airtrain/integrations/fireworks/skills.py +181 -0
- airtrain/integrations/fireworks/structured_completion_skills.py +175 -0
- airtrain/integrations/fireworks/structured_requests_skills.py +291 -0
- airtrain/integrations/fireworks/structured_skills.py +102 -0
- airtrain/integrations/google/__init__.py +7 -0
- airtrain/integrations/google/credentials.py +58 -0
- airtrain/integrations/google/skills.py +122 -0
- airtrain/integrations/groq/__init__.py +23 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/models_config.py +162 -0
- airtrain/integrations/groq/skills.py +201 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +37 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/list_models.py +112 -0
- airtrain/integrations/openai/models_config.py +224 -0
- airtrain/integrations/openai/skills.py +342 -0
- airtrain/integrations/perplexity/__init__.py +49 -0
- airtrain/integrations/perplexity/credentials.py +43 -0
- airtrain/integrations/perplexity/list_models.py +112 -0
- airtrain/integrations/perplexity/models_config.py +128 -0
- airtrain/integrations/perplexity/skills.py +279 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +129 -0
- airtrain/integrations/search/__init__.py +21 -0
- airtrain/integrations/search/exa/__init__.py +23 -0
- airtrain/integrations/search/exa/credentials.py +30 -0
- airtrain/integrations/search/exa/schemas.py +114 -0
- airtrain/integrations/search/exa/skills.py +115 -0
- airtrain/integrations/together/__init__.py +33 -0
- airtrain/integrations/together/audio_models_config.py +34 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/embedding_models_config.py +92 -0
- airtrain/integrations/together/image_models_config.py +69 -0
- airtrain/integrations/together/image_skill.py +143 -0
- airtrain/integrations/together/list_models.py +76 -0
- airtrain/integrations/together/models.py +95 -0
- airtrain/integrations/together/models_config.py +399 -0
- airtrain/integrations/together/rerank_models_config.py +43 -0
- airtrain/integrations/together/rerank_skill.py +49 -0
- airtrain/integrations/together/schemas.py +33 -0
- airtrain/integrations/together/skills.py +305 -0
- airtrain/integrations/together/vision_models_config.py +49 -0
- airtrain/telemetry/__init__.py +38 -0
- airtrain/telemetry/service.py +167 -0
- airtrain/telemetry/views.py +237 -0
- airtrain/tools/__init__.py +45 -0
- airtrain/tools/command.py +398 -0
- airtrain/tools/filesystem.py +166 -0
- airtrain/tools/network.py +111 -0
- airtrain/tools/registry.py +320 -0
- airtrain/tools/search.py +450 -0
- airtrain/tools/testing.py +135 -0
- airtrain-0.1.4.dist-info/METADATA +222 -0
- airtrain-0.1.4.dist-info/RECORD +108 -0
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/WHEEL +1 -1
- airtrain-0.1.4.dist-info/entry_points.txt +2 -0
- airtrain-0.1.2.dist-info/METADATA +0 -106
- airtrain-0.1.2.dist-info/RECORD +0 -5
- {airtrain-0.1.2.dist-info → airtrain-0.1.4.dist-info}/top_level.txt +0 -0
airtrain/integrations/google/skills.py
@@ -0,0 +1,122 @@
+from typing import List, Optional, Dict, Any
+from pydantic import Field
+import google.generativeai as genai
+from loguru import logger
+
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import GeminiCredentials
+
+
+class GoogleGenerationConfig(InputSchema):
+    """Schema for Google generation config"""
+
+    temperature: float = Field(
+        default=1.0, description="Temperature for response generation", ge=0, le=1
+    )
+    top_p: float = Field(
+        default=0.95, description="Top p sampling parameter", ge=0, le=1
+    )
+    top_k: int = Field(default=40, description="Top k sampling parameter")
+    max_output_tokens: int = Field(
+        default=8192, description="Maximum tokens in response"
+    )
+    response_mime_type: str = Field(
+        default="text/plain", description="Response MIME type"
+    )
+
+
+class GoogleInput(InputSchema):
+    """Schema for Google chat input"""
+
+    user_input: str = Field(..., description="User's input text")
+    system_prompt: str = Field(
+        default="You are a helpful assistant.",
+        description="System prompt to guide the model's behavior",
+    )
+    conversation_history: List[Dict[str, str | List[Dict[str, str]]]] = Field(
+        default_factory=list,
+        description="List of conversation messages in Google's format",
+    )
+    model: str = Field(default="gemini-1.5-flash", description="Google model to use")
+    generation_config: GoogleGenerationConfig = Field(
+        default_factory=GoogleGenerationConfig, description="Generation configuration"
+    )
+
+
+class GoogleOutput(OutputSchema):
+    """Schema for Google chat output"""
+
+    response: str = Field(..., description="Model's response text")
+    used_model: str = Field(..., description="Model used for generation")
+    usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
+
+
+class GoogleChatSkill(Skill[GoogleInput, GoogleOutput]):
+    """Skill for Google chat"""
+
+    input_schema = GoogleInput
+    output_schema = GoogleOutput
+
+    def __init__(self, credentials: Optional[GeminiCredentials] = None):
+        super().__init__()
+        self.credentials = credentials or GeminiCredentials.from_env()
+        genai.configure(api_key=self.credentials.gemini_api_key.get_secret_value())
+
+    def _convert_history_format(
+        self, history: List[Dict[str, str]]
+    ) -> List[Dict[str, List[Dict[str, str]]]]:
+        """Convert standard history format to Google's format"""
+        google_history = []
+        for msg in history:
+            google_msg = {
+                "role": "user" if msg["role"] == "user" else "model",
+                "parts": [{"text": msg["content"]}],
+            }
+            google_history.append(google_msg)
+        return google_history
+
+    def process(self, input_data: GoogleInput) -> GoogleOutput:
+        try:
+            # Create generation config
+            generation_config = {
+                "temperature": input_data.generation_config.temperature,
+                "top_p": input_data.generation_config.top_p,
+                "top_k": input_data.generation_config.top_k,
+                "max_output_tokens": input_data.generation_config.max_output_tokens,
+                "response_mime_type": input_data.generation_config.response_mime_type,
+            }
+
+            # Initialize model
+            model = genai.GenerativeModel(
+                model_name=input_data.model,
+                generation_config=generation_config,
+                system_instruction=input_data.system_prompt,
+            )
+
+            # Convert history format if needed
+            history = (
+                input_data.conversation_history
+                if input_data.conversation_history
+                else self._convert_history_format([])
+            )
+
+            # Start chat session
+            chat = model.start_chat(history=history)
+
+            # Send message and get response
+            response = chat.send_message(input_data.user_input)
+
+            return GoogleOutput(
+                response=response.text,
+                used_model=input_data.model,
+                usage={
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                    "total_tokens": 0,
+                },  # Google API doesn't provide usage stats
+            )
+
+        except Exception as e:
+            logger.exception(f"Google processing failed: {str(e)}")
+            raise ProcessingError(f"Google processing failed: {str(e)}")
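The new skill follows the same pattern as the other providers: construct the skill, then call process with a typed input. A minimal usage sketch, assuming GeminiCredentials.from_env reads a GEMINI_API_KEY environment variable (the name is inferred from the gemini_api_key field, not confirmed by this diff):

    from airtrain.integrations.google.skills import GoogleChatSkill, GoogleInput

    # Omitting credentials falls back to GeminiCredentials.from_env().
    skill = GoogleChatSkill()

    output = skill.process(
        GoogleInput(
            user_input="Name one use for a 128K context window.",
            model="gemini-1.5-flash",
        )
    )
    print(output.response)  # generated text
    print(output.usage)     # all zeros; see the comment in process()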
airtrain/integrations/groq/__init__.py
@@ -0,0 +1,23 @@
+"""Groq integration module"""
+
+from .credentials import GroqCredentials
+from .skills import GroqChatSkill
+from .models_config import (
+    get_model_config,
+    get_default_model,
+    supports_tool_use,
+    supports_parallel_tool_use,
+    supports_json_mode,
+    GROQ_MODELS_CONFIG,
+)
+
+__all__ = [
+    "GroqCredentials",
+    "GroqChatSkill",
+    "get_model_config",
+    "get_default_model",
+    "supports_tool_use",
+    "supports_parallel_tool_use",
+    "supports_json_mode",
+    "GROQ_MODELS_CONFIG",
+]
airtrain/integrations/groq/credentials.py
@@ -0,0 +1,24 @@
+from pydantic import Field, SecretStr
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+from groq import AsyncGroq
+
+
+class GroqCredentials(BaseCredentials):
+    """Groq API credentials"""
+
+    groq_api_key: SecretStr = Field(..., description="Groq API key")
+
+    _required_credentials = {"groq_api_key"}
+
+    async def validate_credentials(self) -> bool:
+        """Validate Groq credentials"""
+        try:
+            client = AsyncGroq(api_key=self.groq_api_key.get_secret_value())
+            await client.chat.completions.create(
+                messages=[{"role": "user", "content": "Hi"}],
+                model="mixtral-8x7b-32768",
+                max_tokens=1,
+            )
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid Groq credentials: {str(e)}")
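Because validate_credentials is a coroutine (it awaits the completion call, hence AsyncGroq rather than the synchronous Groq client), callers need an event loop. A sketch, assuming from_env reads a GROQ_API_KEY environment variable mirroring the field name (not confirmed by this diff):

    import asyncio

    from airtrain.integrations.groq.credentials import GroqCredentials

    credentials = GroqCredentials.from_env()

    # Issues a one-token completion; raises CredentialValidationError on failure.
    assert asyncio.run(credentials.validate_credentials())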
airtrain/integrations/groq/models_config.py
@@ -0,0 +1,162 @@
+"""Configuration of Groq model capabilities."""
+
+from typing import Dict, Any
+
+
+# Model configuration with capabilities for each model
+GROQ_MODELS_CONFIG = {
+    "llama-3.3-70b-versatile": {
+        "name": "Llama 3.3 70B Versatile",
+        "context_window": 128000,
+        "max_completion_tokens": 32768,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "llama-3.1-8b-instant": {
+        "name": "Llama 3.1 8B Instant",
+        "context_window": 128000,
+        "max_completion_tokens": 8192,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "mixtral-8x7b-32768": {
+        "name": "Mixtral 8x7B (32K)",
+        "context_window": 32768,
+        "max_completion_tokens": 8192,
+        "tool_use": True,
+        "parallel_tool_use": False,
+        "json_mode": True,
+    },
+    "gemma2-9b-it": {
+        "name": "Gemma 2 9B IT",
+        "context_window": 8192,
+        "max_completion_tokens": 4096,
+        "tool_use": True,
+        "parallel_tool_use": False,
+        "json_mode": True,
+    },
+    "qwen-qwq-32b": {
+        "name": "Qwen QWQ 32B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "qwen-2.5-coder-32b": {
+        "name": "Qwen 2.5 Coder 32B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "qwen-2.5-32b": {
+        "name": "Qwen 2.5 32B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-r1-distill-qwen-32b": {
+        "name": "DeepSeek R1 Distill Qwen 32B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-r1-distill-llama-70b": {
+        "name": "DeepSeek R1 Distill Llama 70B",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": True,
+        "parallel_tool_use": True,
+        "json_mode": True,
+    },
+    "deepseek-r1-distill-llama-70b-specdec": {
+        "name": "DeepSeek R1 Distill Llama 70B SpecDec",
+        "context_window": 128000,
+        "max_completion_tokens": 16384,
+        "tool_use": False,
+        "parallel_tool_use": False,
+        "json_mode": False,
+    },
+    "llama3-70b-8192": {
+        "name": "Llama 3 70B (8K)",
+        "context_window": 8192,
+        "max_completion_tokens": 4096,
+        "tool_use": False,
+        "parallel_tool_use": False,
+        "json_mode": False,
+    },
+    "llama3-8b-8192": {
+        "name": "Llama 3 8B (8K)",
+        "context_window": 8192,
+        "max_completion_tokens": 4096,
+        "tool_use": False,
+        "parallel_tool_use": False,
+        "json_mode": False,
+    },
+}
+
+
+def get_model_config(model_id: str) -> Dict[str, Any]:
+    """
+    Get the configuration for a specific model.
+
+    Args:
+        model_id: The model ID to get configuration for
+
+    Returns:
+        Dict with model configuration
+
+    Note:
+        Unknown model IDs fall back to a conservative default configuration.
+    """
+    if model_id in GROQ_MODELS_CONFIG:
+        return GROQ_MODELS_CONFIG[model_id]
+
+    # Try to find a match with different format or case
+    normalized_id = model_id.lower().replace("-", "").replace("_", "")
+    for config_id, config in GROQ_MODELS_CONFIG.items():
+        if normalized_id == config_id.lower().replace("-", "").replace("_", ""):
+            return config
+
+    # Default configuration for unknown models
+    return {
+        "name": model_id,
+        "context_window": 4096,  # Conservative default
+        "max_completion_tokens": 1024,  # Conservative default
+        "tool_use": False,
+        "parallel_tool_use": False,
+        "json_mode": False,
+    }
+
+
+def get_default_model() -> str:
+    """Get the default model ID for Groq."""
+    return "llama-3.3-70b-versatile"
+
+
+def supports_tool_use(model_id: str) -> bool:
+    """Check if a model supports tool use."""
+    return get_model_config(model_id).get("tool_use", False)
+
+
+def supports_parallel_tool_use(model_id: str) -> bool:
+    """Check if a model supports parallel tool use."""
+    return get_model_config(model_id).get("parallel_tool_use", False)
+
+
+def supports_json_mode(model_id: str) -> bool:
+    """Check if a model supports JSON mode."""
+    return get_model_config(model_id).get("json_mode", False)
+
+
+def get_max_completion_tokens(model_id: str) -> int:
+    """Get the maximum number of completion tokens for a model."""
+    return get_model_config(model_id).get("max_completion_tokens", 1024)
|
|
1
|
+
from typing import Generator, Optional, Dict, Any, List, Union
|
2
|
+
from pydantic import Field, validator
|
3
|
+
from airtrain.core.skills import Skill, ProcessingError
|
4
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
5
|
+
from .credentials import GroqCredentials
|
6
|
+
from .models_config import get_max_completion_tokens, get_model_config
|
7
|
+
from groq import Groq
|
8
|
+
|
9
|
+
|
10
|
+
class GroqInput(InputSchema):
|
11
|
+
"""Schema for Groq input"""
|
12
|
+
|
13
|
+
user_input: str = Field(..., description="User's input text")
|
14
|
+
system_prompt: str = Field(
|
15
|
+
default="You are a helpful assistant.",
|
16
|
+
description=(
|
17
|
+
"System prompt to guide the model's behavior"
|
18
|
+
),
|
19
|
+
)
|
20
|
+
conversation_history: List[Dict[str, str]] = Field(
|
21
|
+
default_factory=list,
|
22
|
+
description=(
|
23
|
+
"List of previous conversation messages in "
|
24
|
+
"[{'role': 'user|assistant', 'content': 'message'}] format"
|
25
|
+
),
|
26
|
+
)
|
27
|
+
model: str = Field(
|
28
|
+
default="llama-3.3-70b-versatile",
|
29
|
+
description="Groq model to use"
|
30
|
+
)
|
31
|
+
max_tokens: int = Field(
|
32
|
+
default=4096,
|
33
|
+
description="Maximum tokens in response"
|
34
|
+
)
|
35
|
+
temperature: float = Field(
|
36
|
+
default=0.7,
|
37
|
+
description="Temperature for response generation",
|
38
|
+
ge=0,
|
39
|
+
le=1
|
40
|
+
)
|
41
|
+
stream: bool = Field(
|
42
|
+
default=False,
|
43
|
+
description="Whether to stream the response progressively"
|
44
|
+
)
|
45
|
+
tools: Optional[List[Dict[str, Any]]] = Field(
|
46
|
+
default=None,
|
47
|
+
description=(
|
48
|
+
"A list of tools the model may use. "
|
49
|
+
"Currently only functions supported."
|
50
|
+
),
|
51
|
+
)
|
52
|
+
tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(
|
53
|
+
default=None,
|
54
|
+
description=(
|
55
|
+
"Controls which tool is called by the model. "
|
56
|
+
"'none', 'auto', or specific tool."
|
57
|
+
),
|
58
|
+
)
|
59
|
+
|
60
|
+
@validator('max_tokens')
|
61
|
+
def validate_max_tokens(cls, v, values):
|
62
|
+
"""Validate that max_tokens doesn't exceed the model's limit."""
|
63
|
+
if 'model' in values:
|
64
|
+
model_id = values['model']
|
65
|
+
max_limit = get_max_completion_tokens(model_id)
|
66
|
+
if v > max_limit:
|
67
|
+
return max_limit
|
68
|
+
return v
|
69
|
+
|
70
|
+
|
71
|
+
class GroqOutput(OutputSchema):
|
72
|
+
"""Schema for Groq output"""
|
73
|
+
|
74
|
+
response: str = Field(..., description="Model's response text")
|
75
|
+
used_model: str = Field(..., description="Model used for generation")
|
76
|
+
usage: Dict[str, Any] = Field(
|
77
|
+
default_factory=dict, description="Usage statistics from the API"
|
78
|
+
)
|
79
|
+
tool_calls: Optional[List[Dict[str, Any]]] = Field(
|
80
|
+
default=None, description="Tool calls generated by the model"
|
81
|
+
)
|
82
|
+
|
83
|
+
|
84
|
+
class GroqChatSkill(Skill[GroqInput, GroqOutput]):
|
85
|
+
"""Skill for Groq chat"""
|
86
|
+
|
87
|
+
input_schema = GroqInput
|
88
|
+
output_schema = GroqOutput
|
89
|
+
|
90
|
+
def __init__(self, credentials: Optional[GroqCredentials] = None):
|
91
|
+
super().__init__()
|
92
|
+
self.credentials = credentials or GroqCredentials.from_env()
|
93
|
+
self.client = Groq(api_key=self.credentials.groq_api_key.get_secret_value())
|
94
|
+
|
95
|
+
def _build_messages(self, input_data: GroqInput) -> List[Dict[str, str]]:
|
96
|
+
"""
|
97
|
+
Build messages list from input data including conversation history.
|
98
|
+
|
99
|
+
Args:
|
100
|
+
input_data: The input data containing system prompt, conversation history, and user input
|
101
|
+
|
102
|
+
Returns:
|
103
|
+
List[Dict[str, str]]: List of messages in the format required by Groq
|
104
|
+
"""
|
105
|
+
messages = [{"role": "system", "content": input_data.system_prompt}]
|
106
|
+
|
107
|
+
# Add conversation history if present
|
108
|
+
if input_data.conversation_history:
|
109
|
+
messages.extend(input_data.conversation_history)
|
110
|
+
|
111
|
+
# Add current user input
|
112
|
+
messages.append({"role": "user", "content": input_data.user_input})
|
113
|
+
|
114
|
+
return messages
|
115
|
+
|
116
|
+
def process_stream(self, input_data: GroqInput) -> Generator[str, None, None]:
|
117
|
+
"""Process the input and stream the response token by token."""
|
118
|
+
try:
|
119
|
+
messages = self._build_messages(input_data)
|
120
|
+
|
121
|
+
stream = self.client.chat.completions.create(
|
122
|
+
model=input_data.model,
|
123
|
+
messages=messages,
|
124
|
+
temperature=input_data.temperature,
|
125
|
+
max_tokens=input_data.max_tokens,
|
126
|
+
stream=True,
|
127
|
+
)
|
128
|
+
|
129
|
+
for chunk in stream:
|
130
|
+
if chunk.choices[0].delta.content is not None:
|
131
|
+
yield chunk.choices[0].delta.content
|
132
|
+
|
133
|
+
except Exception as e:
|
134
|
+
raise ProcessingError(f"Groq streaming failed: {str(e)}")
|
135
|
+
|
136
|
+
def process(self, input_data: GroqInput) -> GroqOutput:
|
137
|
+
"""Process the input and return the complete response."""
|
138
|
+
try:
|
139
|
+
if input_data.stream:
|
140
|
+
response_chunks = []
|
141
|
+
for chunk in self.process_stream(input_data):
|
142
|
+
response_chunks.append(chunk)
|
143
|
+
response = "".join(response_chunks)
|
144
|
+
usage = {} # Usage stats not available in streaming
|
145
|
+
tool_calls = None # Tool calls not available in streaming
|
146
|
+
else:
|
147
|
+
messages = self._build_messages(input_data)
|
148
|
+
|
149
|
+
# Prepare API call parameters
|
150
|
+
api_params = {
|
151
|
+
"model": input_data.model,
|
152
|
+
"messages": messages,
|
153
|
+
"temperature": input_data.temperature,
|
154
|
+
"max_tokens": input_data.max_tokens,
|
155
|
+
"stream": False,
|
156
|
+
}
|
157
|
+
|
158
|
+
# Add tools and tool_choice if provided
|
159
|
+
if input_data.tools:
|
160
|
+
api_params["tools"] = input_data.tools
|
161
|
+
|
162
|
+
if input_data.tool_choice:
|
163
|
+
api_params["tool_choice"] = input_data.tool_choice
|
164
|
+
|
165
|
+
completion = self.client.chat.completions.create(**api_params)
|
166
|
+
response = completion.choices[0].message.content or ""
|
167
|
+
|
168
|
+
# Extract usage information
|
169
|
+
usage = {
|
170
|
+
"total_tokens": completion.usage.total_tokens,
|
171
|
+
"prompt_tokens": completion.usage.prompt_tokens,
|
172
|
+
"completion_tokens": completion.usage.completion_tokens,
|
173
|
+
}
|
174
|
+
|
175
|
+
# Check for tool calls in the response
|
176
|
+
tool_calls = None
|
177
|
+
if (
|
178
|
+
hasattr(completion.choices[0].message, "tool_calls")
|
179
|
+
and completion.choices[0].message.tool_calls
|
180
|
+
):
|
181
|
+
tool_calls = [
|
182
|
+
{
|
183
|
+
"id": tool_call.id,
|
184
|
+
"type": tool_call.type,
|
185
|
+
"function": {
|
186
|
+
"name": tool_call.function.name,
|
187
|
+
"arguments": tool_call.function.arguments
|
188
|
+
}
|
189
|
+
}
|
190
|
+
for tool_call in completion.choices[0].message.tool_calls
|
191
|
+
]
|
192
|
+
|
193
|
+
return GroqOutput(
|
194
|
+
response=response,
|
195
|
+
used_model=input_data.model,
|
196
|
+
usage=usage,
|
197
|
+
tool_calls=tool_calls
|
198
|
+
)
|
199
|
+
|
200
|
+
except Exception as e:
|
201
|
+
raise ProcessingError(f"Groq processing failed: {str(e)}")
|
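A usage sketch covering both code paths, assuming a GROQ_API_KEY reachable by GroqCredentials.from_env (the variable name is inferred, not confirmed by this diff):

    from airtrain.integrations.groq.skills import GroqChatSkill, GroqInput

    skill = GroqChatSkill()

    # Blocking call: returns a GroqOutput with usage statistics.
    result = skill.process(GroqInput(user_input="Define tokenization in one line."))
    print(result.response, result.usage)

    # process_stream always streams, regardless of the stream flag;
    # usage and tool_calls are unavailable on this path.
    for chunk in skill.process_stream(GroqInput(user_input="Hello!")):
        print(chunk, end="", flush=True)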
airtrain/integrations/ollama/credentials.py
@@ -0,0 +1,26 @@
+from pydantic import Field
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+from importlib.util import find_spec
+
+
+class OllamaCredentials(BaseCredentials):
+    """Ollama credentials"""
+
+    host: str = Field(default="http://localhost:11434", description="Ollama host URL")
+    timeout: int = Field(default=30, description="Request timeout in seconds")
+
+    async def validate_credentials(self) -> bool:
+        """Validate Ollama credentials"""
+        if find_spec("ollama") is None:
+            raise CredentialValidationError(
+                "Ollama package is not installed. Please install it using: pip install ollama"
+            )
+
+        try:
+            from ollama import AsyncClient
+
+            client = AsyncClient(host=self.host)
+            await client.list()
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid Ollama connection: {str(e)}")
airtrain/integrations/ollama/skills.py
@@ -0,0 +1,41 @@
+from typing import Optional, Dict, Any
+from pydantic import Field
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import OllamaCredentials
+
+
+class OllamaInput(InputSchema):
+    """Schema for Ollama input"""
+
+    user_input: str = Field(..., description="User's input text")
+    system_prompt: str = Field(
+        default="You are a helpful assistant.",
+        description="System prompt to guide the model's behavior",
+    )
+    model: str = Field(default="llama2", description="Ollama model to use")
+    max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+    temperature: float = Field(
+        default=0.7, description="Temperature for response generation", ge=0, le=1
+    )
+
+
+class OllamaOutput(OutputSchema):
+    """Schema for Ollama output"""
+
+    response: str = Field(..., description="Model's response text")
+    used_model: str = Field(..., description="Model used for generation")
+    usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
+
+
+class OllamaChatSkill(Skill[OllamaInput, OllamaOutput]):
+    """Skill for Ollama - Not Implemented"""
+
+    input_schema = OllamaInput
+    output_schema = OllamaOutput
+
+    def __init__(self, credentials: Optional[OllamaCredentials] = None):
+        raise NotImplementedError("OllamaChatSkill is not implemented yet")
+
+    def process(self, input_data: OllamaInput) -> OllamaOutput:
+        raise NotImplementedError("OllamaChatSkill is not implemented yet")
airtrain/integrations/openai/__init__.py
@@ -0,0 +1,37 @@
+"""OpenAI API integration."""
+
+from .skills import (
+    OpenAIChatSkill,
+    OpenAIInput,
+    OpenAIParserSkill,
+    OpenAIOutput,
+    OpenAIParserInput,
+    OpenAIParserOutput,
+    OpenAIEmbeddingsSkill,
+    OpenAIEmbeddingsInput,
+    OpenAIEmbeddingsOutput,
+)
+from .credentials import OpenAICredentials
+from .list_models import (
+    OpenAIListModelsSkill,
+    OpenAIListModelsInput,
+    OpenAIListModelsOutput,
+    OpenAIModel,
+)
+
+__all__ = [
+    "OpenAIChatSkill",
+    "OpenAIInput",
+    "OpenAIParserSkill",
+    "OpenAIParserInput",
+    "OpenAIParserOutput",
+    "OpenAICredentials",
+    "OpenAIOutput",
+    "OpenAIEmbeddingsSkill",
+    "OpenAIEmbeddingsInput",
+    "OpenAIEmbeddingsOutput",
+    "OpenAIListModelsSkill",
+    "OpenAIListModelsInput",
+    "OpenAIListModelsOutput",
+    "OpenAIModel",
+]
airtrain/integrations/openai/chinese_assistant.py
@@ -0,0 +1,42 @@
+from typing import Optional, TypeVar
+from pydantic import Field
+from .skills import OpenAIChatSkill, OpenAIInput, OpenAIOutput
+from .credentials import OpenAICredentials
+
+T = TypeVar("T", bound=OpenAIInput)
+
+
+class ChineseAssistantInput(OpenAIInput):
+    """Schema for Chinese Assistant input"""
+
+    user_input: str = Field(
+        ..., description="User's input text (can be in any language)"
+    )
+    system_prompt: str = Field(
+        default="你是一个有帮助的助手。请用中文回答所有问题,即使问题是用其他语言问的。回答要准确、礼貌、专业。",
+        description="System prompt in Chinese",
+    )
+    model: str = Field(default="gpt-4o", description="OpenAI model to use")
+    max_tokens: int = Field(default=131072, description="Maximum tokens in response")
+    temperature: float = Field(
+        default=0.7, description="Temperature for response generation", ge=0, le=1
+    )
+
+
+class ChineseAssistantSkill(OpenAIChatSkill):
+    """Skill for Chinese language assistance"""
+
+    input_schema = ChineseAssistantInput
+    output_schema = OpenAIOutput
+
+    def __init__(self, credentials: Optional[OpenAICredentials] = None):
+        super().__init__(credentials)
+
+    def process(self, input_data: T) -> OpenAIOutput:
+        # Add language check to ensure response is in Chinese
+        if "你是" not in input_data.system_prompt:
+            input_data.system_prompt = (
+                "你是一个中文助手。" + input_data.system_prompt + "请用中文回答。"
+            )
+
+        return super().process(input_data)
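The default system prompt translates roughly as: "You are a helpful assistant. Answer all questions in Chinese, even if they are asked in another language. Answers should be accurate, polite, and professional." The guard in process() prepends "你是一个中文助手。" ("You are a Chinese assistant.") and appends "请用中文回答。" ("Please answer in Chinese.") to any custom prompt missing the "你是" ("You are") marker. A usage sketch, assuming OpenAIChatSkill falls back to OpenAICredentials.from_env() with an OPENAI_API_KEY like the other providers (its skills.py is part of this release but not shown above):

    from airtrain.integrations.openai.chinese_assistant import (
        ChineseAssistantInput,
        ChineseAssistantSkill,
    )

    skill = ChineseAssistantSkill()

    # The input may be in any language; the reply comes back in Chinese.
    result = skill.process(ChineseAssistantInput(user_input="What is overfitting?"))
    print(result.response)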