airtrain 0.1.14__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +3 -3
- airtrain/integrations/__init__.py +1 -1
- airtrain/integrations/anthropic/skills.py +59 -67
- airtrain/integrations/cerebras/credentials.py +3 -6
- airtrain/integrations/cerebras/skills.py +62 -8
- airtrain/integrations/fireworks/conversation_manager.py +109 -0
- airtrain/integrations/fireworks/skills.py +27 -5
- airtrain/integrations/google/__init__.py +2 -1
- airtrain/integrations/google/credentials.py +30 -0
- airtrain/integrations/google/skills.py +99 -18
- airtrain/integrations/groq/credentials.py +3 -3
- airtrain/integrations/groq/skills.py +51 -4
- airtrain/integrations/sambanova/credentials.py +3 -3
- airtrain/integrations/sambanova/skills.py +61 -5
- airtrain/integrations/together/image_skill.py +5 -33
- airtrain/integrations/together/skills.py +52 -4
- {airtrain-0.1.14.dist-info → airtrain-0.1.17.dist-info}/METADATA +1 -1
- {airtrain-0.1.14.dist-info → airtrain-0.1.17.dist-info}/RECORD +20 -19
- {airtrain-0.1.14.dist-info → airtrain-0.1.17.dist-info}/WHEEL +0 -0
- {airtrain-0.1.14.dist-info → airtrain-0.1.17.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 """Airtrain - A platform for building and deploying AI agents with structured skills"""
 
-__version__ = "0.1.14"
+__version__ = "0.1.17"
 
 # Core imports
 from .core.skills import Skill, ProcessingError
@@ -22,7 +22,7 @@ from .integrations.cerebras.credentials import CerebrasCredentials
 from .integrations.openai.skills import OpenAIChatSkill, OpenAIParserSkill
 from .integrations.anthropic.skills import AnthropicChatSkill
 from .integrations.aws.skills import AWSBedrockSkill
-from .integrations.google.skills import
+from .integrations.google.skills import GoogleChatSkill
 from .integrations.groq.skills import GroqChatSkill
 from .integrations.together.skills import TogetherAIChatSkill
 from .integrations.ollama.skills import OllamaChatSkill
@@ -51,7 +51,7 @@ __all__ = [
     "OpenAIParserSkill",
     "AnthropicChatSkill",
     "AWSBedrockSkill",
-    "
+    "GoogleChatSkill",
     "GroqChatSkill",
     "TogetherAIChatSkill",
    "OllamaChatSkill",
airtrain/integrations/__init__.py
CHANGED
@@ -15,7 +15,7 @@ from .cerebras.credentials import CerebrasCredentials
 from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
 from .anthropic.skills import AnthropicChatSkill
 from .aws.skills import AWSBedrockSkill
-from .google.skills import
+from .google.skills import GoogleChatSkill
 from .groq.skills import GroqChatSkill
 from .together.skills import TogetherAIChatSkill
 from .ollama.skills import OllamaChatSkill
airtrain/integrations/anthropic/skills.py
CHANGED
@@ -18,16 +18,22 @@ class AnthropicInput(InputSchema):
         default="You are a helpful assistant.",
         description="System prompt to guide the model's behavior",
     )
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+    )
     model: str = Field(
         default="claude-3-opus-20240229", description="Anthropic model to use"
     )
-    max_tokens: int = Field(
+    max_tokens: Optional[int] = Field(
+        default=1024, description="Maximum tokens in response"
+    )
     temperature: float = Field(
         default=0.7, description="Temperature for response generation", ge=0, le=1
     )
-    images:
-
-    description="
+    images: List[Path] = Field(
+        default_factory=list,
+        description="List of image paths to include in the message",
     )
 
 
@@ -42,94 +48,80 @@ class AnthropicOutput(OutputSchema):
 
 
 class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
-    """Skill for
+    """Skill for Anthropic chat"""
 
     input_schema = AnthropicInput
     output_schema = AnthropicOutput
 
     def __init__(self, credentials: Optional[AnthropicCredentials] = None):
-        """Initialize the skill with optional credentials"""
         super().__init__()
         self.credentials = credentials or AnthropicCredentials.from_env()
         self.client = Anthropic(
             api_key=self.credentials.anthropic_api_key.get_secret_value()
         )
 
-    def
-        """
-
-        if not image_path.exists():
-            raise FileNotFoundError(f"Image file not found: {image_path}")
-
-        with open(image_path, "rb") as img_file:
-            encoded = base64.b64encode(img_file.read()).decode()
-            return {
-                "type": "image",
-                "source": {
-                    "type": "base64",
-                    "media_type": f"image/{image_path.suffix[1:]}",
-                    "data": encoded,
-                },
-            }
-        except Exception as e:
-            logger.error(f"Failed to encode image {image_path}: {str(e)}")
-            raise ProcessingError(f"Image encoding failed: {str(e)}")
+    def _build_messages(self, input_data: AnthropicInput) -> List[Dict[str, Any]]:
+        """
+        Build messages list from input data including conversation history.
 
-
-
-        try:
-            logger.info(f"Processing request with model {input_data.model}")
+        Args:
+            input_data: The input data containing system prompt, conversation history, and user input
 
-
-
+        Returns:
+            List[Dict[str, Any]]: List of messages in the format required by Anthropic
+        """
+        messages = []
+
+        # Add conversation history if present
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
 
-
-
+        # Prepare user message content
+        user_message = {"type": "text", "text": input_data.user_input}
 
-
-
-
-
-
+        # Add images if present
+        if input_data.images:
+            content = []
+            for image_path in input_data.images:
+                with open(image_path, "rb") as img_file:
+                    base64_image = base64.b64encode(img_file.read()).decode("utf-8")
+                content.append(
+                    {
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/jpeg",
+                            "data": base64_image,
+                        },
+                    }
+                )
+            content.append(user_message)
+            messages.append({"role": "user", "content": content})
+        else:
+            messages.append({"role": "user", "content": [user_message]})
+
+        return messages
+
+    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
+        try:
+            # Build messages using the helper method
+            messages = self._build_messages(input_data)
 
-            # Create
+            # Create chat completion with system prompt as a separate parameter
             response = self.client.messages.create(
                 model=input_data.model,
+                system=input_data.system_prompt,  # System prompt passed directly
+                messages=messages,
                 max_tokens=input_data.max_tokens,
                 temperature=input_data.temperature,
-                system=input_data.system_prompt,
-                messages=[{"role": "user", "content": content}],
             )
 
-            # Validate response content
-            if not response.content:
-                logger.error("Empty response received from Anthropic API")
-                raise ProcessingError("Empty response received from Anthropic API")
-
-            if not isinstance(response.content, list) or not response.content:
-                logger.error("Invalid response format from Anthropic API")
-                raise ProcessingError("Invalid response format from Anthropic API")
-
-            first_content = response.content[0]
-            if not hasattr(first_content, "text"):
-                logger.error("Response content does not contain text")
-                raise ProcessingError("Response content does not contain text")
-
-            logger.success("Successfully processed Anthropic request")
-
-            # Create output
             return AnthropicOutput(
-                response=
-                used_model=
-                usage=
-                    "input_tokens": response.usage.input_tokens,
-                    "output_tokens": response.usage.output_tokens,
-                },
+                response=response.content[0].text,
+                used_model=input_data.model,
+                usage=response.usage.model_dump(),
             )
 
-        except ProcessingError:
-            # Re-raise ProcessingError without modification
-            raise
         except Exception as e:
             logger.exception(f"Anthropic processing failed: {str(e)}")
             raise ProcessingError(f"Anthropic processing failed: {str(e)}")
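To see the new conversation_history field in use, here is a minimal multi-turn sketch (prompts are hypothetical; it assumes an Anthropic API key is available in the environment for AnthropicCredentials.from_env()):

from airtrain.integrations.anthropic.skills import AnthropicChatSkill, AnthropicInput

# Credentials are read from the environment when none are passed explicitly
skill = AnthropicChatSkill()

# Prior turns are threaded through the new conversation_history field;
# the skill prepends them before the current user_input.
result = skill.process(
    AnthropicInput(
        user_input="And what is its population?",
        system_prompt="You are a concise geography assistant.",
        conversation_history=[
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris."},
        ],
    )
)
print(result.response, result.usage)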
airtrain/integrations/cerebras/credentials.py
CHANGED
@@ -1,16 +1,13 @@
-from pydantic import Field, SecretStr
+from pydantic import Field, SecretStr
 from airtrain.core.credentials import BaseCredentials, CredentialValidationError
-from typing import Optional
 
 
 class CerebrasCredentials(BaseCredentials):
     """Cerebras credentials"""
 
-
-    endpoint_url: HttpUrl = Field(..., description="Cerebras API endpoint")
-    project_id: Optional[str] = Field(None, description="Cerebras Project ID")
+    cerebras_api_key: SecretStr = Field(..., description="Cerebras API key")
 
-    _required_credentials = {"
+    _required_credentials = {"cerebras_api_key"}
 
     async def validate_credentials(self) -> bool:
         """Validate Cerebras credentials"""
airtrain/integrations/cerebras/skills.py
CHANGED
@@ -1,27 +1,36 @@
-from typing import Optional, Dict, Any
+from typing import List, Optional, Dict, Any
 from pydantic import Field
+from cerebras.cloud.sdk import Cerebras
+from loguru import logger
+
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
 from .credentials import CerebrasCredentials
 
 
 class CerebrasInput(InputSchema):
-    """Schema for Cerebras input"""
+    """Schema for Cerebras chat input"""
 
     user_input: str = Field(..., description="User's input text")
     system_prompt: str = Field(
         default="You are a helpful assistant.",
         description="System prompt to guide the model's behavior",
     )
-
-
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+    )
+    model: str = Field(default="llama3.1-8b", description="Cerebras model to use")
+    max_tokens: Optional[int] = Field(
+        default=1024, description="Maximum tokens in response"
+    )
     temperature: float = Field(
         default=0.7, description="Temperature for response generation", ge=0, le=1
     )
 
 
 class CerebrasOutput(OutputSchema):
-    """Schema for Cerebras output"""
+    """Schema for Cerebras chat output"""
 
     response: str = Field(..., description="Model's response text")
     used_model: str = Field(..., description="Model used for generation")
@@ -29,13 +38,58 @@ class CerebrasOutput(OutputSchema):
 
 
 class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
-    """Skill for Cerebras
+    """Skill for Cerebras chat"""
 
     input_schema = CerebrasInput
     output_schema = CerebrasOutput
 
     def __init__(self, credentials: Optional[CerebrasCredentials] = None):
-
+        super().__init__()
+        self.credentials = credentials or CerebrasCredentials.from_env()
+        self.client = Cerebras(
+            api_key=self.credentials.cerebras_api_key.get_secret_value()
+        )
+
+    def _build_messages(self, input_data: CerebrasInput) -> List[Dict[str, str]]:
+        """
+        Build messages list from input data including conversation history.
+
+        Args:
+            input_data: The input data containing system prompt, conversation history, and user input
+
+        Returns:
+            List[Dict[str, str]]: List of messages in the format required by Cerebras
+        """
+        messages = [{"role": "system", "content": input_data.system_prompt}]
+
+        # Add conversation history if present
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
+
+        # Add current user input
+        messages.append({"role": "user", "content": input_data.user_input})
+
+        return messages
 
     def process(self, input_data: CerebrasInput) -> CerebrasOutput:
-
+        try:
+            # Build messages using the helper method
+            messages = self._build_messages(input_data)
+
+            # Create chat completion
+            response = self.client.chat.completions.create(
+                model=input_data.model,
+                messages=messages,
+                temperature=input_data.temperature,
+                max_tokens=input_data.max_tokens,
+            )
+
+            return CerebrasOutput(
+                response=response.choices[0].message.content,
+                used_model=input_data.model,
+                usage=response.usage.model_dump(),
+            )
+
+        except Exception as e:
+            logger.exception(f"Cerebras processing failed: {str(e)}")
+            raise ProcessingError(f"Cerebras processing failed: {str(e)}")
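The _build_messages helper introduced here (and repeated in the Fireworks, Groq, Sambanova, and Together skills below) always assembles the same ordering: system prompt first, then the stored history, then the new user turn. A small sketch of the resulting structure, with hypothetical content:

# Order produced by _build_messages: system prompt, prior turns, current input
history = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]
messages = [{"role": "system", "content": "You are a helpful assistant."}]
messages.extend(history)
messages.append({"role": "user", "content": "Tell me a joke."})
# messages now holds four dicts: one system, two history turns, one new user turn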
airtrain/integrations/fireworks/conversation_manager.py
ADDED
@@ -0,0 +1,109 @@
+from typing import List, Dict, Optional
+from pydantic import BaseModel, Field
+from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
+
+# TODO: Test this thing.
+
+
+class ConversationState(BaseModel):
+    """Model to track conversation state"""
+
+    messages: List[Dict[str, str]] = Field(
+        default_factory=list, description="List of conversation messages"
+    )
+    system_prompt: str = Field(
+        default="You are a helpful assistant.",
+        description="System prompt for the conversation",
+    )
+    model: str = Field(
+        default="accounts/fireworks/models/deepseek-r1",
+        description="Model being used for the conversation",
+    )
+    temperature: float = Field(default=0.7, description="Temperature setting")
+    max_tokens: Optional[int] = Field(default=None, description="Max tokens setting")
+
+
+class FireworksConversationManager:
+    """Manager for handling conversation state with Fireworks AI"""
+
+    def __init__(
+        self,
+        skill: Optional[FireworksChatSkill] = None,
+        system_prompt: str = "You are a helpful assistant.",
+        model: str = "accounts/fireworks/models/deepseek-r1",
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+    ):
+        """
+        Initialize conversation manager.
+
+        Args:
+            skill: FireworksChatSkill instance (creates new one if None)
+            system_prompt: Initial system prompt
+            model: Model to use
+            temperature: Temperature setting
+            max_tokens: Max tokens setting
+        """
+        self.skill = skill or FireworksChatSkill()
+        self.state = ConversationState(
+            system_prompt=system_prompt,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+
+    def send_message(self, user_input: str) -> FireworksOutput:
+        """
+        Send a message and get response while maintaining conversation history.
+
+        Args:
+            user_input: User's message
+
+        Returns:
+            FireworksOutput: Model's response
+        """
+        # Create input with current conversation state
+        input_data = FireworksInput(
+            user_input=user_input,
+            system_prompt=self.state.system_prompt,
+            conversation_history=self.state.messages,
+            model=self.state.model,
+            temperature=self.state.temperature,
+            max_tokens=self.state.max_tokens,
+        )
+
+        # Get response
+        result = self.skill.process(input_data)
+
+        # Update conversation history
+        self.state.messages.extend(
+            [
+                {"role": "user", "content": user_input},
+                {"role": "assistant", "content": result.response},
+            ]
+        )
+
+        return result
+
+    def reset_conversation(self) -> None:
+        """Reset the conversation history while maintaining other settings"""
+        self.state.messages = []
+
+    def get_conversation_history(self) -> List[Dict[str, str]]:
+        """Get the current conversation history"""
+        return self.state.messages.copy()
+
+    def update_system_prompt(self, new_prompt: str) -> None:
+        """Update the system prompt for future messages"""
+        self.state.system_prompt = new_prompt
+
+    def save_state(self, file_path: str) -> None:
+        """Save conversation state to a file"""
+        with open(file_path, "w") as f:
+            f.write(self.state.model_dump_json(indent=2))
+
+    def load_state(self, file_path: str) -> None:
+        """Load conversation state from a file"""
+        with open(file_path, "r") as f:
+            data = f.read()
+            self.state = ConversationState.model_validate_json(data)
|
|
17
17
|
default="You are a helpful assistant.",
|
18
18
|
description="System prompt to guide the model's behavior",
|
19
19
|
)
|
20
|
+
conversation_history: List[Dict[str, str]] = Field(
|
21
|
+
default_factory=list,
|
22
|
+
description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
|
23
|
+
)
|
20
24
|
model: str = Field(
|
21
25
|
default="accounts/fireworks/models/deepseek-r1",
|
22
26
|
description="Fireworks AI model to use",
|
@@ -52,16 +56,34 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
|
52
56
|
self.credentials = credentials or FireworksCredentials.from_env()
|
53
57
|
self.base_url = "https://api.fireworks.ai/inference/v1"
|
54
58
|
|
59
|
+
def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, str]]:
|
60
|
+
"""
|
61
|
+
Build messages list from input data including conversation history.
|
62
|
+
|
63
|
+
Args:
|
64
|
+
input_data: The input data containing system prompt, conversation history, and user input
|
65
|
+
|
66
|
+
Returns:
|
67
|
+
List[Dict[str, str]]: List of messages in the format required by Fireworks AI
|
68
|
+
"""
|
69
|
+
messages = [{"role": "system", "content": input_data.system_prompt}]
|
70
|
+
|
71
|
+
# Add conversation history if present
|
72
|
+
if input_data.conversation_history:
|
73
|
+
messages.extend(input_data.conversation_history)
|
74
|
+
|
75
|
+
# Add current user input
|
76
|
+
messages.append({"role": "user", "content": input_data.user_input})
|
77
|
+
|
78
|
+
return messages
|
79
|
+
|
55
80
|
def process(self, input_data: FireworksInput) -> FireworksOutput:
|
56
81
|
"""Process the input using Fireworks AI API"""
|
57
82
|
try:
|
58
83
|
logger.info(f"Processing request with model {input_data.model}")
|
59
84
|
|
60
|
-
#
|
61
|
-
messages =
|
62
|
-
{"role": "system", "content": input_data.system_prompt},
|
63
|
-
{"role": "user", "content": input_data.user_input},
|
64
|
-
]
|
85
|
+
# Build messages using the helper method
|
86
|
+
messages = self._build_messages(input_data)
|
65
87
|
|
66
88
|
# Prepare request payload
|
67
89
|
payload = {
|
airtrain/integrations/google/credentials.py
CHANGED
@@ -1,5 +1,8 @@
 from pydantic import Field, SecretStr
 from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+import google.genai as genai
+from google.cloud import storage
+import os
 
 # from google.cloud import storage
 
@@ -26,3 +29,30 @@ class GoogleCloudCredentials(BaseCredentials):
             raise CredentialValidationError(
                 f"Invalid Google Cloud credentials: {str(e)}"
             )
+
+
+class GeminiCredentials(BaseCredentials):
+    """Gemini API credentials"""
+
+    gemini_api_key: SecretStr = Field(..., description="Gemini API Key")
+
+    _required_credentials = {"gemini_api_key"}
+
+    @classmethod
+    def from_env(cls) -> "GeminiCredentials":
+        """Create credentials from environment variables"""
+        return cls(gemini_api_key=SecretStr(os.environ.get("GEMINI_API_KEY", "")))
+
+    async def validate_credentials(self) -> bool:
+        """Validate Gemini API credentials"""
+        try:
+            # Configure Gemini with API key
+            genai.configure(api_key=self.gemini_api_key.get_secret_value())
+
+            # Test API call with a simple model
+            model = genai.GenerativeModel("gemini-1.5-flash")
+            response = model.generate_content("test")
+
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid Gemini credentials: {str(e)}")
airtrain/integrations/google/skills.py
CHANGED
@@ -1,41 +1,122 @@
-from typing import Optional, Dict, Any
+from typing import List, Optional, Dict, Any
 from pydantic import Field
+import google.generativeai as genai
+from loguru import logger
+
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
-from .credentials import
+from .credentials import GeminiCredentials
+
+
+class GoogleGenerationConfig(InputSchema):
+    """Schema for Google generation config"""
+
+    temperature: float = Field(
+        default=1.0, description="Temperature for response generation", ge=0, le=1
+    )
+    top_p: float = Field(
+        default=0.95, description="Top p sampling parameter", ge=0, le=1
+    )
+    top_k: int = Field(default=40, description="Top k sampling parameter")
+    max_output_tokens: int = Field(
+        default=8192, description="Maximum tokens in response"
+    )
+    response_mime_type: str = Field(
+        default="text/plain", description="Response MIME type"
+    )
 
 
-class
-    """Schema for Google
+class GoogleInput(InputSchema):
+    """Schema for Google chat input"""
 
     user_input: str = Field(..., description="User's input text")
     system_prompt: str = Field(
         default="You are a helpful assistant.",
         description="System prompt to guide the model's behavior",
     )
-
-
-
-
+    conversation_history: List[Dict[str, str | List[Dict[str, str]]]] = Field(
+        default_factory=list,
+        description="List of conversation messages in Google's format",
+    )
+    model: str = Field(default="gemini-1.5-flash", description="Google model to use")
+    generation_config: GoogleGenerationConfig = Field(
+        default_factory=GoogleGenerationConfig, description="Generation configuration"
     )
 
 
-class
-    """Schema for
+class GoogleOutput(OutputSchema):
+    """Schema for Google chat output"""
 
     response: str = Field(..., description="Model's response text")
     used_model: str = Field(..., description="Model used for generation")
     usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
 
 
-class
-    """Skill for Google
+class GoogleChatSkill(Skill[GoogleInput, GoogleOutput]):
+    """Skill for Google chat"""
+
+    input_schema = GoogleInput
+    output_schema = GoogleOutput
+
+    def __init__(self, credentials: Optional[GeminiCredentials] = None):
+        super().__init__()
+        self.credentials = credentials or GeminiCredentials.from_env()
+        genai.configure(api_key=self.credentials.gemini_api_key.get_secret_value())
+
+    def _convert_history_format(
+        self, history: List[Dict[str, str]]
+    ) -> List[Dict[str, List[Dict[str, str]]]]:
+        """Convert standard history format to Google's format"""
+        google_history = []
+        for msg in history:
+            google_msg = {
+                "role": "user" if msg["role"] == "user" else "model",
+                "parts": [{"text": msg["content"]}],
+            }
+            google_history.append(google_msg)
+        return google_history
+
+    def process(self, input_data: GoogleInput) -> GoogleOutput:
+        try:
+            # Create generation config
+            generation_config = {
+                "temperature": input_data.generation_config.temperature,
+                "top_p": input_data.generation_config.top_p,
+                "top_k": input_data.generation_config.top_k,
+                "max_output_tokens": input_data.generation_config.max_output_tokens,
+                "response_mime_type": input_data.generation_config.response_mime_type,
+            }
+
+            # Initialize model
+            model = genai.GenerativeModel(
+                model_name=input_data.model,
+                generation_config=generation_config,
+                system_instruction=input_data.system_prompt,
+            )
+
+            # Convert history format if needed
+            history = (
+                input_data.conversation_history
+                if input_data.conversation_history
+                else self._convert_history_format([])
+            )
+
+            # Start chat session
+            chat = model.start_chat(history=history)
 
-
-
+            # Send message and get response
+            response = chat.send_message(input_data.user_input)
 
-
-
+            return GoogleOutput(
+                response=response.text,
+                used_model=input_data.model,
+                usage={
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                    "total_tokens": 0,
+                },  # Google API doesn't provide usage stats
+            )
 
-
-
+        except Exception as e:
+            logger.exception(f"Google processing failed: {str(e)}")
+            raise ProcessingError(f"Google processing failed: {str(e)}")
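Unlike the other providers, GoogleInput expects conversation_history already in Gemini's role/parts shape (roles are "user" and "model", not "assistant"). A minimal sketch with hypothetical prompts, assuming GEMINI_API_KEY is set:

from airtrain.integrations.google.skills import GoogleChatSkill, GoogleInput

skill = GoogleChatSkill()  # GeminiCredentials.from_env() reads GEMINI_API_KEY

result = skill.process(
    GoogleInput(
        user_input="And in Fahrenheit?",
        system_prompt="You convert units precisely.",
        # Gemini-style history: "model" role and a parts list per message
        conversation_history=[
            {"role": "user", "parts": [{"text": "What is 20 C in Kelvin?"}]},
            {"role": "model", "parts": [{"text": "293.15 K"}]},
        ],
    )
)
print(result.response)  # usage counters are zeroed; the API doesn't report them here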
airtrain/integrations/groq/credentials.py
CHANGED
@@ -6,14 +6,14 @@ from groq import Groq
 class GroqCredentials(BaseCredentials):
     """Groq API credentials"""
 
-
+    groq_api_key: SecretStr = Field(..., description="Groq API key")
 
-    _required_credentials = {"
+    _required_credentials = {"groq_api_key"}
 
     async def validate_credentials(self) -> bool:
         """Validate Groq credentials"""
         try:
-            client = Groq(api_key=self.
+            client = Groq(api_key=self.groq_api_key.get_secret_value())
             await client.chat.completions.create(
                 messages=[{"role": "user", "content": "Hi"}],
                 model="mixtral-8x7b-32768",
airtrain/integrations/groq/skills.py
CHANGED
@@ -1,8 +1,9 @@
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List
 from pydantic import Field
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
 from .credentials import GroqCredentials
+from groq import Groq
 
 
 class GroqInput(InputSchema):
@@ -13,6 +14,10 @@ class GroqInput(InputSchema):
         default="You are a helpful assistant.",
         description="System prompt to guide the model's behavior",
     )
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+    )
     model: str = Field(default="mixtral-8x7b", description="Groq model to use")
     max_tokens: int = Field(default=1024, description="Maximum tokens in response")
     temperature: float = Field(
@@ -29,13 +34,55 @@ class GroqOutput(OutputSchema):
 
 
 class GroqChatSkill(Skill[GroqInput, GroqOutput]):
-    """Skill for Groq
+    """Skill for Groq chat"""
 
     input_schema = GroqInput
     output_schema = GroqOutput
 
     def __init__(self, credentials: Optional[GroqCredentials] = None):
-
+        super().__init__()
+        self.credentials = credentials or GroqCredentials.from_env()
+        self.client = Groq(api_key=self.credentials.groq_api_key.get_secret_value())
+
+    def _build_messages(self, input_data: GroqInput) -> List[Dict[str, str]]:
+        """
+        Build messages list from input data including conversation history.
+
+        Args:
+            input_data: The input data containing system prompt, conversation history, and user input
+
+        Returns:
+            List[Dict[str, str]]: List of messages in the format required by Groq
+        """
+        messages = [{"role": "system", "content": input_data.system_prompt}]
+
+        # Add conversation history if present
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
+
+        # Add current user input
+        messages.append({"role": "user", "content": input_data.user_input})
+
+        return messages
 
     def process(self, input_data: GroqInput) -> GroqOutput:
-
+        try:
+            # Build messages using the helper method
+            messages = self._build_messages(input_data)
+
+            # Create chat completion
+            response = self.client.chat.completions.create(
+                model=input_data.model,
+                messages=messages,
+                temperature=input_data.temperature,
+                max_tokens=input_data.max_tokens,
+            )
+
+            return GroqOutput(
+                response=response.choices[0].message.content,
+                used_model=input_data.model,
+                usage=response.usage.model_dump(),
+            )
+
+        except Exception as e:
+            raise ProcessingError(f"Groq processing failed: {str(e)}")
airtrain/integrations/sambanova/credentials.py
CHANGED
@@ -5,10 +5,10 @@ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
 class SambanovaCredentials(BaseCredentials):
     """SambaNova credentials"""
 
-
-
+    sambanova_api_key: SecretStr = Field(..., description="SambaNova API key")
+    sambanova_endpoint_url: HttpUrl = Field(..., description="SambaNova API endpoint")
 
-    _required_credentials = {"
+    _required_credentials = {"sambanova_api_key", "sambanova_endpoint_url"}
 
     async def validate_credentials(self) -> bool:
         """Validate SambaNova credentials"""
airtrain/integrations/sambanova/skills.py
CHANGED
@@ -1,8 +1,9 @@
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, List
 from pydantic import Field
 from airtrain.core.skills import Skill, ProcessingError
 from airtrain.core.schemas import InputSchema, OutputSchema
 from .credentials import SambanovaCredentials
+import openai
 
 
 class SambanovaInput(InputSchema):
@@ -13,11 +14,20 @@ class SambanovaInput(InputSchema):
         default="You are a helpful assistant.",
         description="System prompt to guide the model's behavior",
     )
-
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+    )
+    model: str = Field(
+        default="DeepSeek-R1-Distill-Llama-70B", description="Sambanova model to use"
+    )
     max_tokens: int = Field(default=1024, description="Maximum tokens in response")
     temperature: float = Field(
         default=0.7, description="Temperature for response generation", ge=0, le=1
     )
+    top_p: float = Field(
+        default=0.1, description="Top p sampling parameter", ge=0, le=1
+    )
 
 
 class SambanovaOutput(OutputSchema):
@@ -29,13 +39,59 @@ class SambanovaOutput(OutputSchema):
 
 
 class SambanovaChatSkill(Skill[SambanovaInput, SambanovaOutput]):
-    """Skill for Sambanova
+    """Skill for Sambanova chat"""
 
     input_schema = SambanovaInput
     output_schema = SambanovaOutput
 
     def __init__(self, credentials: Optional[SambanovaCredentials] = None):
-
+        super().__init__()
+        self.credentials = credentials or SambanovaCredentials.from_env()
+        self.client = openai.OpenAI(
+            api_key=self.credentials.sambanova_api_key.get_secret_value(),
+            base_url="https://api.sambanova.ai/v1",
+        )
+
+    def _build_messages(self, input_data: SambanovaInput) -> List[Dict[str, str]]:
+        """
+        Build messages list from input data including conversation history.
+
+        Args:
+            input_data: The input data containing system prompt, conversation history, and user input
+
+        Returns:
+            List[Dict[str, str]]: List of messages in the format required by Sambanova
+        """
+        messages = [{"role": "system", "content": input_data.system_prompt}]
+
+        # Add conversation history if present
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
+
+        # Add current user input
+        messages.append({"role": "user", "content": input_data.user_input})
+
+        return messages
 
     def process(self, input_data: SambanovaInput) -> SambanovaOutput:
-
+        try:
+            # Build messages using the helper method
+            messages = self._build_messages(input_data)
+
+            # Create chat completion
+            response = self.client.chat.completions.create(
+                model=input_data.model,
+                messages=messages,
+                temperature=input_data.temperature,
+                max_tokens=input_data.max_tokens,
+                top_p=input_data.top_p,
+            )
+
+            return SambanovaOutput(
+                response=response.choices[0].message.content,
+                used_model=input_data.model,
+                usage=response.usage.model_dump(),
+            )
+
+        except Exception as e:
+            raise ProcessingError(f"Sambanova processing failed: {str(e)}")
airtrain/integrations/together/image_skill.py
CHANGED
@@ -34,12 +34,12 @@ class TogetherAIImageInput(InputSchema):
 class GeneratedImage(OutputSchema):
     """Individual generated image data"""
 
-    b64_json: str = Field(
+    b64_json: Optional[str] = Field(None, description="Base64 encoded image data")
+    url: str = Field(..., description="URL of the generated image")
     seed: Optional[int] = Field(None, description="Seed used for this image")
     finish_reason: Optional[str] = Field(
         None, description="Reason for finishing generation"
     )
-    url: Optional[str] = Field(None, description="URL of the generated image")
 
 
 class TogetherAIImageOutput(OutputSchema):
@@ -87,47 +87,19 @@ class TogetherAIImageSkill(Skill[TogetherAIImageInput, TogetherAIImageOutput]):
         # Calculate total time
         total_time = time.time() - start_time
 
-        # Debug print the response structure
-        print(f"Response type: {type(response)}")
-        print(f"Response data type: {type(response.data)}")
-        if response.data:
-            print(f"First image type: {type(response.data[0])}")
-            print(f"First image attributes: {dir(response.data[0])}")
-
         # Convert response to our output format
         generated_images = []
         for img in response.data:
-
-            b64_data = None
-            for attr in ["b64_json", "image", "base64", "data"]:
-                if hasattr(img, attr):
-                    b64_data = getattr(img, attr)
-                    if b64_data:
-                        break
-
-            if not b64_data:
-                # If no direct attribute found, try accessing as dictionary
-                try:
-                    if hasattr(img, "__dict__"):
-                        img_dict = img.__dict__
-                        for key in ["b64_json", "image", "base64", "data"]:
-                            if key in img_dict and img_dict[key]:
-                                b64_data = img_dict[key]
-                                break
-                except:
-                    pass
-
-            if not b64_data:
+            if not hasattr(img, "url"):
                 raise ProcessingError(
-                    f"No
+                    f"No URL found in API response. Response structure: {dir(img)}"
                 )
 
             generated_images.append(
                 GeneratedImage(
-
+                    url=img.url,
                     seed=getattr(img, "seed", None),
                     finish_reason=getattr(img, "finish_reason", None),
-                    url=getattr(img, "url", None),
                 )
             )
 
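With GeneratedImage now carrying a required url (b64_json became optional), callers fetch the bytes themselves. A sketch using requests; the images attribute on the output object is an assumption, not something this diff confirms:

import requests

def save_generated_images(output, directory: str = ".") -> None:
    """Download each generated image from its URL (b64_json may be absent)."""
    for i, img in enumerate(output.images):  # assumed field name on TogetherAIImageOutput
        resp = requests.get(img.url, timeout=30)
        resp.raise_for_status()
        with open(f"{directory}/image_{i}.png", "wb") as f:
            f.write(resp.content)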
airtrain/integrations/together/skills.py
CHANGED
@@ -18,8 +18,12 @@ class TogetherAIInput(InputSchema):
         default="You are a helpful assistant.",
         description="System prompt to guide the model's behavior",
     )
+    conversation_history: List[Dict[str, str]] = Field(
+        default_factory=list,
+        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
+    )
     model: str = Field(
-        default="
+        default="deepseek-ai/DeepSeek-R1", description="Together AI model to use"
     )
     max_tokens: int = Field(default=1024, description="Maximum tokens in response")
     temperature: float = Field(
@@ -36,16 +40,60 @@ class TogetherAIOutput(OutputSchema):
 
 
 class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
-    """Skill for Together AI
+    """Skill for Together AI chat"""
 
     input_schema = TogetherAIInput
     output_schema = TogetherAIOutput
 
     def __init__(self, credentials: Optional[TogetherAICredentials] = None):
-
+        super().__init__()
+        self.credentials = credentials or TogetherAICredentials.from_env()
+        self.client = Together(
+            api_key=self.credentials.together_api_key.get_secret_value()
+        )
+
+    def _build_messages(self, input_data: TogetherAIInput) -> List[Dict[str, str]]:
+        """
+        Build messages list from input data including conversation history.
+
+        Args:
+            input_data: The input data containing system prompt, conversation history, and user input
+
+        Returns:
+            List[Dict[str, str]]: List of messages in the format required by Together AI
+        """
+        messages = [{"role": "system", "content": input_data.system_prompt}]
+
+        # Add conversation history if present
+        if input_data.conversation_history:
+            messages.extend(input_data.conversation_history)
+
+        # Add current user input
+        messages.append({"role": "user", "content": input_data.user_input})
+
+        return messages
 
     def process(self, input_data: TogetherAIInput) -> TogetherAIOutput:
-
+        try:
+            # Build messages using the helper method
+            messages = self._build_messages(input_data)
+
+            # Create chat completion
+            response = self.client.chat.completions.create(
+                model=input_data.model,
+                messages=messages,
+                max_tokens=input_data.max_tokens,
+                temperature=input_data.temperature,
+            )
+
+            return TogetherAIOutput(
+                response=response.choices[0].message.content,
+                used_model=input_data.model,
+                usage=response.usage.model_dump(),
+            )
+
+        except Exception as e:
+            raise ProcessingError(f"Together AI processing failed: {str(e)}")
 
 
 class TogetherAIImageSkill(Skill[TogetherAIImageInput, TogetherAIImageOutput]):
{airtrain-0.1.14.dist-info → airtrain-0.1.17.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-airtrain/__init__.py,sha256=
+airtrain/__init__.py,sha256=xaeCzvEPJ7ZQSTLIn1ZHb_bqstM__aElFJ6dQnRNfuA,2099
 airtrain/contrib/__init__.py,sha256=pG-7mJ0pBMqp3Q86mIF9bo1PqoBOVSGlnEK1yY1U1ok,641
 airtrain/contrib/travel/__init__.py,sha256=clmBodw4nkTA-DsgjVGcXfJGPaWxIpCZDtdO-8RzL0M,811
 airtrain/contrib/travel/agents.py,sha256=tpQtZ0WUiXBuhvZtc2JlEam5TuR5l-Tndi14YyImDBM,8975
@@ -7,26 +7,27 @@ airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
 airtrain/core/credentials.py,sha256=PgQotrQc46J5djidKnkK1znUv3fyNkUFDO-m2Kn_Gzo,4006
 airtrain/core/schemas.py,sha256=MMXrDviC4gRea_QaPpbjgO--B_UKxnD7YrxqZOLJZZU,7003
 airtrain/core/skills.py,sha256=LljalzeSHK5eQPTAOEAYc5D8Qn1kVSfiz9WgziTD5UM,4688
-airtrain/integrations/__init__.py,sha256
+airtrain/integrations/__init__.py,sha256=-3Vz2bqAUNvVHEZxFGUv5BfzJZsO_7MRyLifbHsweEE,1488
 airtrain/integrations/anthropic/__init__.py,sha256=qwlWLDh1rEVizYFbW8430z-f1SxHio7_Gaw5cCTUtoo,274
 airtrain/integrations/anthropic/credentials.py,sha256=hlTSw9HX66kYNaeQUtn0JjdZQBMNkzzFOJOoLOOzvcY,1246
-airtrain/integrations/anthropic/skills.py,sha256=
+airtrain/integrations/anthropic/skills.py,sha256=3yr9-X_O_sw7Z57qiqi9A7uuo6EREmXqHFG7MAwcxDw,4645
 airtrain/integrations/aws/__init__.py,sha256=3x7v2NxpAfI-U-YgwQeH5PtsmUrNLPMfLyUGFLiBjbs,155
 airtrain/integrations/aws/credentials.py,sha256=nN-daKAl7qOb_VdRpsThG8gN5GeSUkx-ji5E_gF_vYw,1444
 airtrain/integrations/aws/skills.py,sha256=TQiMXeXRRcJ14fe8Xi7Uk20iS6_INbcznuLGtMorcKY,3870
 airtrain/integrations/cerebras/__init__.py,sha256=zAD-qV38OzHhMCz1z-NvjjqcYEhURbm8RWTOKHNqbew,174
-airtrain/integrations/cerebras/credentials.py,sha256=
-airtrain/integrations/cerebras/skills.py,sha256=
+airtrain/integrations/cerebras/credentials.py,sha256=KDEH4r8FGT68L9p34MLZWK65wq_a703pqIF3ODaSbts,694
+airtrain/integrations/cerebras/skills.py,sha256=Ksggq_s5wHWlf_xQIOO8MFoNTYV0cm9SHZ7GESOd2YE,3527
 airtrain/integrations/fireworks/__init__.py,sha256=9pJvP0u1FJbNtB0oHa09mHVJLctELf_c27LOYyDk2ZI,271
+airtrain/integrations/fireworks/conversation_manager.py,sha256=m6VEHijqpYEYawkKhuHtb8RQxw4kxGWFWdbSK6zGuro,3704
 airtrain/integrations/fireworks/credentials.py,sha256=UpcwR9V5Hbk5sJbjFDJDbHMRqc90IQSqAvrtJCOvwEo,524
 airtrain/integrations/fireworks/models.py,sha256=F-MddbLCLAsTjwRr1l6IpJxOegyY4pD7jN9ySPiypSo,593
-airtrain/integrations/fireworks/skills.py,sha256
-airtrain/integrations/google/__init__.py,sha256=
-airtrain/integrations/google/credentials.py,sha256=
-airtrain/integrations/google/skills.py,sha256=
+airtrain/integrations/fireworks/skills.py,sha256=-6zfe5eooygDYL1cyrk4PUT2TLEmAT2ZRfH2cr42DaU,4960
+airtrain/integrations/google/__init__.py,sha256=ElwgcXfbg_gGMm6zbkMXCQPFKZUb-yTJk986o19A7Cs,214
+airtrain/integrations/google/credentials.py,sha256=KSvWNqW8Mjr4MkysRvUqlrOSGdShNIe5u2OPO6vRrWY,2047
+airtrain/integrations/google/skills.py,sha256=ytsoksCY4qbfRO9Brnxhc2694fAj0ytnHX20SXS_FOM,4547
 airtrain/integrations/groq/__init__.py,sha256=B_X2fXbsJfFD6GquKeVCsEJjwd9Ygbq1uEHlV4Jy7YE,154
-airtrain/integrations/groq/credentials.py,sha256=
-airtrain/integrations/groq/skills.py,sha256=
+airtrain/integrations/groq/credentials.py,sha256=bdTHykcIeaQ7td8KZlQBPfEFAkvJuxk2f_cbTLPD_I4,844
+airtrain/integrations/groq/skills.py,sha256=oXgNLV8qRrYEY2VotKcaESQrQZAXJKzWLaHpVAp88VM,3269
 airtrain/integrations/ollama/__init__.py,sha256=zMHBsGzViVrvxAeJmfq6r-ZfSE6Dy5QcKLhe4d5fEcM,164
 airtrain/integrations/ollama/credentials.py,sha256=D7O4kUAb_VHs5s1ncUN9Ezhu5PvLfgj3RifAkB9sEZk,940
 airtrain/integrations/ollama/skills.py,sha256=M_Un8D5VJ5XtPEq9IClzqV3jCPBoFTSm2ve6EO8W2JU,1556
@@ -36,22 +37,22 @@ airtrain/integrations/openai/credentials.py,sha256=NfRyp1QgEtgm8cxt2-BOLq-6d0X-P
 airtrain/integrations/openai/models_config.py,sha256=bzosqqpDy2AJxu2vGdk2H4voqEGlv7LORR6fpJLhNic,3962
 airtrain/integrations/openai/skills.py,sha256=Olg9-6f_p2XgkVwwcB9tvjAMApmM2EK81i8LP4qVVvs,7676
 airtrain/integrations/sambanova/__init__.py,sha256=dp_263iOckM_J9pOEvyqpf3FrejD6-_x33r0edMCTe0,179
-airtrain/integrations/sambanova/credentials.py,sha256=
-airtrain/integrations/sambanova/skills.py,sha256=
+airtrain/integrations/sambanova/credentials.py,sha256=JyN8sbMCoXuXAjim46aI3LTicBijoemS7Ao0rn4yBJU,824
+airtrain/integrations/sambanova/skills.py,sha256=SDFY-ZzhOEIxQgTkQJzX9gN7UDqqnCBJdK7I2JydIoY,3625
 airtrain/integrations/together/__init__.py,sha256=we4KXn_pUs6Dxo3QcB-t40BSRraQFdKg2nXw7yi2FjM,185
 airtrain/integrations/together/audio_models_config.py,sha256=GtqfmKR1vJ5x4B3kScvEO3x4exvzwNP78vcGVTk_fBE,1004
 airtrain/integrations/together/credentials.py,sha256=cYNhyIwgsxm8LfiFfT-omBvgV3mUP6SZeRSukyzzDlI,747
 airtrain/integrations/together/embedding_models_config.py,sha256=F0ISAXCG_Pcnf-ojkvZwIXacXD8LaU8hQmGHCFzmlds,2927
 airtrain/integrations/together/image_models_config.py,sha256=JlCozrphI9zE4uYpGfj4DCWSN6GZGyr84Tb1HmjNQ28,2455
-airtrain/integrations/together/image_skill.py,sha256=
+airtrain/integrations/together/image_skill.py,sha256=wQ8wSzfL-QHpM_esYGLNXf8ciOPPsz-QJw6zSrxZT68,5214
 airtrain/integrations/together/models.py,sha256=ZW5xfEN9fU18aeltb-sB2O-Bnu5sLkDPZqvUtxgoH-U,2112
 airtrain/integrations/together/models_config.py,sha256=XMKp0Oq1nWWnMMdNAZxkFXmJaURwWrwLE18kFXsMsRw,8829
 airtrain/integrations/together/rerank_models_config.py,sha256=coCg0IOG2tU4L2uc2uPtPdoBwGjSc_zQwxENwdDuwHE,1188
 airtrain/integrations/together/rerank_skill.py,sha256=gjH24hLWCweWKPyyfKZMG3K_g9gWzm80WgiJNjkA9eg,1894
 airtrain/integrations/together/schemas.py,sha256=pBMrbX67oxPCr-sg4K8_Xqu1DWbaC4uLCloVSascROg,1210
-airtrain/integrations/together/skills.py,sha256=
+airtrain/integrations/together/skills.py,sha256=mUoHc2r5TYQi5iGzwz2aDuUeROGq7teCtNrOlNApef4,6276
 airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1FtynIa2ZVcbl37kltfoHnK8zxs,1544
-airtrain-0.1.14.dist-info/METADATA,sha256=
-airtrain-0.1.14.dist-info/WHEEL,sha256=
-airtrain-0.1.14.dist-info/top_level.txt,sha256=
-airtrain-0.1.14.dist-info/RECORD,,
+airtrain-0.1.17.dist-info/METADATA,sha256=zJfPgsD-aCTdXLEhv--wIbX-ejhGLg7C-xc72ynNNOY,4536
+airtrain-0.1.17.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+airtrain-0.1.17.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
+airtrain-0.1.17.dist-info/RECORD,,
{airtrain-0.1.14.dist-info → airtrain-0.1.17.dist-info}/WHEEL
File without changes
{airtrain-0.1.14.dist-info → airtrain-0.1.17.dist-info}/top_level.txt
File without changes