airtrain 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/integrations/__init__.py +22 -1
- airtrain/integrations/anthropic/__init__.py +11 -0
- airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain/integrations/anthropic/skills.py +135 -0
- airtrain/integrations/aws/__init__.py +6 -0
- airtrain/integrations/aws/credentials.py +36 -0
- airtrain/integrations/aws/skills.py +98 -0
- airtrain/integrations/cerebras/__init__.py +6 -0
- airtrain/integrations/cerebras/credentials.py +22 -0
- airtrain/integrations/cerebras/skills.py +41 -0
- airtrain/integrations/google/__init__.py +6 -0
- airtrain/integrations/google/credentials.py +27 -0
- airtrain/integrations/google/skills.py +41 -0
- airtrain/integrations/groq/__init__.py +6 -0
- airtrain/integrations/groq/credentials.py +24 -0
- airtrain/integrations/groq/skills.py +41 -0
- airtrain/integrations/ollama/__init__.py +6 -0
- airtrain/integrations/ollama/credentials.py +26 -0
- airtrain/integrations/ollama/skills.py +41 -0
- airtrain/integrations/openai/__init__.py +19 -0
- airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain/integrations/openai/credentials.py +39 -0
- airtrain/integrations/openai/skills.py +208 -0
- airtrain/integrations/sambanova/__init__.py +6 -0
- airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain/integrations/sambanova/skills.py +41 -0
- airtrain/integrations/together/__init__.py +6 -0
- airtrain/integrations/together/credentials.py +22 -0
- airtrain/integrations/together/skills.py +43 -0
- {airtrain-0.1.6.dist-info → airtrain-0.1.8.dist-info}/METADATA +1 -1
- airtrain-0.1.8.dist-info/RECORD +38 -0
- airtrain-0.1.6.dist-info/RECORD +0 -10
- {airtrain-0.1.6.dist-info → airtrain-0.1.8.dist-info}/WHEEL +0 -0
- {airtrain-0.1.6.dist-info → airtrain-0.1.8.dist-info}/top_level.txt +0 -0
airtrain/integrations/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
"""Airtrain integrations package"""
|
2
2
|
|
3
|
+
# Credentials imports
|
3
4
|
from .openai.credentials import OpenAICredentials
|
4
5
|
from .aws.credentials import AWSCredentials
|
5
6
|
from .google.credentials import GoogleCloudCredentials
|
@@ -10,17 +11,37 @@ from .ollama.credentials import OllamaCredentials
|
|
10
11
|
from .sambanova.credentials import SambanovaCredentials
|
11
12
|
from .cerebras.credentials import CerebrasCredentials
|
12
13
|
|
14
|
+
# Skills imports
|
15
|
+
from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
|
13
16
|
from .anthropic.skills import AnthropicChatSkill
|
17
|
+
from .aws.skills import AWSBedrockSkill
|
18
|
+
from .google.skills import VertexAISkill
|
19
|
+
from .groq.skills import GroqChatSkill
|
20
|
+
from .together.skills import TogetherAIChatSkill
|
21
|
+
from .ollama.skills import OllamaChatSkill
|
22
|
+
from .sambanova.skills import SambanovaChatSkill
|
23
|
+
from .cerebras.skills import CerebrasChatSkill
|
14
24
|
|
15
25
|
# Public API of the integrations package: credential classes first, then skills.
__all__ = [
    # Credentials
    "OpenAICredentials",
    "AWSCredentials",
    "GoogleCloudCredentials",
    "AnthropicCredentials",
    "GroqCredentials",
    "TogetherAICredentials",
    "OllamaCredentials",
    "SambanovaCredentials",
    "CerebrasCredentials",
    # Skills
    "OpenAIChatSkill",
    "OpenAIParserSkill",
    "AnthropicChatSkill",
    "AWSBedrockSkill",
    "VertexAISkill",
    "GroqChatSkill",
    "TogetherAIChatSkill",
    "OllamaChatSkill",
    "SambanovaChatSkill",
    "CerebrasChatSkill",
]
|
@@ -0,0 +1,11 @@
|
|
1
|
+
"""Anthropic integration for Airtrain"""
|
2
|
+
|
3
|
+
from .credentials import AnthropicCredentials
|
4
|
+
from .skills import AnthropicChatSkill, AnthropicInput, AnthropicOutput
|
5
|
+
|
6
|
+
# Names re-exported by the anthropic integration package.
__all__ = [
    "AnthropicCredentials",
    "AnthropicChatSkill",
    "AnthropicInput",
    "AnthropicOutput",
]
|
@@ -0,0 +1,32 @@
|
|
1
|
+
from pydantic import Field, SecretStr, validator
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
from anthropic import Anthropic
|
4
|
+
|
5
|
+
|
6
|
+
class AnthropicCredentials(BaseCredentials):
    """Anthropic API credentials.

    Attributes:
        anthropic_api_key: Secret API key; must start with ``sk-ant-``.
        version: API version string (default ``"2023-06-01"``); not used by
            this class itself.
    """

    anthropic_api_key: SecretStr = Field(..., description="Anthropic API key")
    version: str = Field(default="2023-06-01", description="API Version")

    # Field names BaseCredentials treats as mandatory (semantics defined in
    # airtrain.core.credentials — presumably checked when loading from env).
    _required_credentials = {"anthropic_api_key"}

    @validator("anthropic_api_key")
    def validate_api_key_format(cls, v: SecretStr) -> SecretStr:
        """Reject keys that lack Anthropic's ``sk-ant-`` prefix."""
        key = v.get_secret_value()
        if not key.startswith("sk-ant-"):
            raise ValueError("Anthropic API key must start with 'sk-ant-'")
        return v

    async def validate_credentials(self) -> bool:
        """Validate the key with a minimal (1-token) live Messages request.

        The async client is used so the network round-trip does not block the
        running event loop (the sync client would).

        Returns:
            True when the request succeeds.

        Raises:
            CredentialValidationError: if client creation or the request fails.
        """
        # Local import: ships in the same `anthropic` package as `Anthropic`.
        from anthropic import AsyncAnthropic

        try:
            client = AsyncAnthropic(api_key=self.anthropic_api_key.get_secret_value())
            await client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=1,
                messages=[{"role": "user", "content": "Hi"}],
            )
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Anthropic credentials: {str(e)}"
            ) from e
|
@@ -0,0 +1,135 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from anthropic import Anthropic
|
4
|
+
import base64
|
5
|
+
from pathlib import Path
|
6
|
+
from loguru import logger
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import AnthropicCredentials
|
11
|
+
|
12
|
+
|
13
|
+
class AnthropicInput(InputSchema):
    """Input for a single-turn Claude chat request.

    Fields:
        user_input: the user's message text (required).
        system_prompt: system instruction sent with the request.
        model: Anthropic model identifier.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature, validated to lie in [0, 1].
        images: optional local image files attached to the message.
    """

    user_input: str = Field(default=..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    model: str = Field(
        description="Anthropic model to use",
        default="claude-3-opus-20240229",
    )
    max_tokens: int = Field(description="Maximum tokens in response", default=1024)
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
        ge=0,
        le=1,
    )
    images: Optional[List[Path]] = Field(
        description="Optional list of image paths to include in the message",
        default=None,
    )
|
32
|
+
|
33
|
+
|
34
|
+
class AnthropicOutput(OutputSchema):
    """Result of a Claude chat request.

    Fields:
        response: generated text.
        used_model: model identifier reported for the generation.
        usage: token-usage statistics (empty dict by default).
    """

    response: str = Field(default=..., description="Model's response text")
    used_model: str = Field(default=..., description="Model used for generation")
    usage: Dict[str, Any] = Field(
        description="Usage statistics from the API",
        default_factory=dict,
    )
|
42
|
+
|
43
|
+
|
44
|
+
class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
    """Skill for interacting with Anthropic's Claude models via the Messages API."""

    input_schema = AnthropicInput
    output_schema = AnthropicOutput

    # Media types accepted for base64 image blocks. Deriving the type from the
    # suffix alone is wrong for ".jpg" ("image/jpg" is rejected; the correct
    # type is "image/jpeg") and for upper-case suffixes.
    _MEDIA_TYPES = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: explicit credentials; when omitted they are loaded with
                ``AnthropicCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or AnthropicCredentials.from_env()
        self.client = Anthropic(
            api_key=self.credentials.anthropic_api_key.get_secret_value()
        )

    def _encode_image(self, image_path: Path) -> Dict[str, Any]:
        """Read an image file and wrap it as a base64 image content block.

        Raises:
            ProcessingError: if the file is missing or unreadable, or its
                extension is not a supported image type.
        """
        try:
            if not image_path.exists():
                raise FileNotFoundError(f"Image file not found: {image_path}")

            media_type = self._MEDIA_TYPES.get(image_path.suffix.lower())
            if media_type is None:
                raise ValueError(f"Unsupported image type: {image_path.suffix!r}")

            encoded = base64.b64encode(image_path.read_bytes()).decode("ascii")
            return {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": media_type,
                    "data": encoded,
                },
            }
        except Exception as e:
            logger.error(f"Failed to encode image {image_path}: {str(e)}")
            raise ProcessingError(f"Image encoding failed: {str(e)}") from e

    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
        """Send one user turn (text plus optional images) to Claude.

        Returns:
            AnthropicOutput with the concatenated text of every text block in
            the reply, the model reported by the API, and input/output tokens.

        Raises:
            ProcessingError: on image-encoding failures, an empty or text-less
                reply, or any error raised while calling the API.
        """
        try:
            logger.info(f"Processing request with model {input_data.model}")

            content: List[Dict[str, Any]] = [
                {"type": "text", "text": input_data.user_input}
            ]
            if input_data.images:
                logger.debug(f"Processing {len(input_data.images)} images")
                content.extend(self._encode_image(path) for path in input_data.images)

            response = self.client.messages.create(
                model=input_data.model,
                max_tokens=input_data.max_tokens,
                temperature=input_data.temperature,
                system=input_data.system_prompt,
                messages=[{"role": "user", "content": content}],
            )

            if not response.content:
                logger.error("Empty response received from Anthropic API")
                raise ProcessingError("Empty response received from Anthropic API")

            # A reply may contain several blocks (including non-text ones such
            # as tool_use); keep only the text-bearing blocks.
            texts = [block.text for block in response.content if hasattr(block, "text")]
            if not texts:
                logger.error("Response content does not contain text")
                raise ProcessingError("Response content does not contain text")

            logger.success("Successfully processed Anthropic request")
            return AnthropicOutput(
                response="".join(texts),
                used_model=response.model,
                usage={
                    "input_tokens": response.usage.input_tokens,
                    "output_tokens": response.usage.output_tokens,
                },
            )

        except ProcessingError:
            # Already descriptive; do not wrap a second time.
            raise
        except Exception as e:
            logger.exception(f"Anthropic processing failed: {str(e)}")
            raise ProcessingError(f"Anthropic processing failed: {str(e)}") from e
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
from pydantic import Field, SecretStr
|
3
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
4
|
+
import boto3
|
5
|
+
|
6
|
+
|
7
|
+
class AWSCredentials(BaseCredentials):
    """Static AWS credentials with an optional STS session token.

    Attributes:
        aws_access_key_id: access key id (secret).
        aws_secret_access_key: secret access key (secret).
        aws_region: region used for clients (default ``us-east-1``).
        aws_session_token: optional session token for temporary credentials.
    """

    aws_access_key_id: SecretStr = Field(description="AWS Access Key ID", default=...)
    aws_secret_access_key: SecretStr = Field(
        description="AWS Secret Access Key", default=...
    )
    aws_region: str = Field(description="AWS Region", default="us-east-1")
    aws_session_token: Optional[SecretStr] = Field(
        description="AWS Session Token", default=None
    )

    _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}

    async def validate_credentials(self) -> bool:
        """Check the credentials with an STS ``GetCallerIdentity`` call.

        Returns:
            True when the call succeeds.

        Raises:
            CredentialValidationError: on any failure building the session or
                calling STS.
        """
        try:
            token = self.aws_session_token
            session_kwargs = {
                "aws_access_key_id": self.aws_access_key_id.get_secret_value(),
                "aws_secret_access_key": self.aws_secret_access_key.get_secret_value(),
                "aws_session_token": token.get_secret_value() if token else None,
                "region_name": self.aws_region,
            }
            boto3.Session(**session_kwargs).client("sts").get_caller_identity()
        except Exception as e:
            raise CredentialValidationError(f"Invalid AWS credentials: {str(e)}")
        return True
|
@@ -0,0 +1,98 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
import boto3
|
4
|
+
from pathlib import Path
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
from airtrain.core.skills import Skill, ProcessingError
|
8
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
9
|
+
from .credentials import AWSCredentials
|
10
|
+
|
11
|
+
|
12
|
+
class AWSBedrockInput(InputSchema):
    """Input for a single-turn AWS Bedrock chat request.

    Fields:
        user_input: the user's message text (required).
        system_prompt: system instruction sent with the request.
        model: Bedrock model id.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature, validated to lie in [0, 1].
        images: optional local image files to attach.
    """

    user_input: str = Field(default=..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    model: str = Field(
        description="AWS Bedrock model to use",
        default="anthropic.claude-3-sonnet-20240229-v1:0",
    )
    max_tokens: int = Field(description="Maximum tokens in response", default=1024)
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
        ge=0,
        le=1,
    )
    images: Optional[List[Path]] = Field(
        description="Optional list of image paths to include in the message",
        default=None,
    )
|
32
|
+
|
33
|
+
|
34
|
+
class AWSBedrockOutput(OutputSchema):
    """Result of an AWS Bedrock chat request.

    Fields:
        response: generated text.
        used_model: model id used for the generation.
        usage: token-usage statistics (empty dict by default).
    """

    response: str = Field(default=..., description="Model's response text")
    used_model: str = Field(default=..., description="Model used for generation")
    usage: Dict[str, Any] = Field(
        description="Usage statistics from the API",
        default_factory=dict,
    )
|
42
|
+
|
43
|
+
|
44
|
+
class AWSBedrockSkill(Skill[AWSBedrockInput, AWSBedrockOutput]):
    """Skill for chatting with Anthropic Claude models hosted on AWS Bedrock.

    Only model ids containing ``"anthropic"`` are supported. Requests use the
    Bedrock Anthropic Messages format (``anthropic_version`` =
    ``"bedrock-2023-05-31"``) and are sent with ``invoke_model``.
    """

    input_schema = AWSBedrockInput
    output_schema = AWSBedrockOutput

    # Image media types accepted by Claude; ".jpg" must map to "image/jpeg".
    _MEDIA_TYPES = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }

    def __init__(self, credentials: Optional[AWSCredentials] = None):
        """Initialize the skill.

        Args:
            credentials: explicit credentials; when omitted they are loaded with
                ``AWSCredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or AWSCredentials.from_env()
        session_token = self.credentials.aws_session_token
        self.client = boto3.client(
            "bedrock-runtime",
            aws_access_key_id=self.credentials.aws_access_key_id.get_secret_value(),
            aws_secret_access_key=self.credentials.aws_secret_access_key.get_secret_value(),
            # Temporary (STS) credentials are rejected without their session token.
            aws_session_token=(
                session_token.get_secret_value() if session_token else None
            ),
            region_name=self.credentials.aws_region,
        )

    def _encode_image(self, image_path: Path) -> Dict[str, Any]:
        """Read an image file and wrap it as a base64 Claude image block.

        Raises:
            ProcessingError: if the file is missing or its extension is not a
                supported image type.
        """
        import base64

        if not image_path.exists():
            raise ProcessingError(f"Image file not found: {image_path}")
        media_type = self._MEDIA_TYPES.get(image_path.suffix.lower())
        if media_type is None:
            raise ProcessingError(f"Unsupported image type: {image_path.suffix!r}")
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": base64.b64encode(image_path.read_bytes()).decode("ascii"),
            },
        }

    def process(self, input_data: AWSBedrockInput) -> AWSBedrockOutput:
        """Send one user turn (text plus optional images) to a Bedrock Claude model.

        Returns:
            AWSBedrockOutput with the concatenated text blocks of the reply, the
            requested model id, and input/output token usage.

        Raises:
            ProcessingError: for unsupported models, bad images, text-less
                replies, or any error raised by the Bedrock client.
        """
        import json

        try:
            logger.info(f"Processing request with model {input_data.model}")

            if "anthropic" not in input_data.model:
                raise ProcessingError(f"Unsupported model: {input_data.model}")

            content: List[Dict[str, Any]] = [
                {"type": "text", "text": input_data.user_input}
            ]
            for image_path in input_data.images or []:
                content.append(self._encode_image(image_path))

            request_body = {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": input_data.max_tokens,
                "temperature": input_data.temperature,
                "system": input_data.system_prompt,
                "messages": [{"role": "user", "content": content}],
            }

            # invoke_model expects the body as serialized JSON, not a dict.
            response = self.client.invoke_model(
                modelId=input_data.model,
                body=json.dumps(request_body),
                contentType="application/json",
                accept="application/json",
            )
            # The response "body" is a streaming object: read and decode it.
            payload = json.loads(response["body"].read())

            # Messages-format replies carry a list of content blocks (there is
            # no legacy "completion" key); keep only the text blocks.
            texts = [
                block.get("text", "")
                for block in payload.get("content", [])
                if block.get("type") == "text"
            ]
            if not texts:
                raise ProcessingError("Response content does not contain text")

            usage_data = payload.get("usage", {})
            return AWSBedrockOutput(
                response="".join(texts),
                used_model=input_data.model,
                usage={
                    "input_tokens": usage_data.get("input_tokens"),
                    "output_tokens": usage_data.get("output_tokens"),
                },
            )

        except ProcessingError:
            # Already descriptive; do not wrap a second time.
            raise
        except Exception as e:
            logger.exception(f"AWS Bedrock processing failed: {str(e)}")
            raise ProcessingError(f"AWS Bedrock processing failed: {str(e)}") from e
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from pydantic import Field, SecretStr, HttpUrl
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
from typing import Optional
|
4
|
+
|
5
|
+
|
6
|
+
class CerebrasCredentials(BaseCredentials):
    """Cerebras API credentials.

    Attributes:
        api_key: secret API key.
        endpoint_url: base URL of the Cerebras API.
        project_id: optional project identifier.
    """

    api_key: SecretStr = Field(description="Cerebras API key", default=...)
    endpoint_url: HttpUrl = Field(description="Cerebras API endpoint", default=...)
    project_id: Optional[str] = Field(description="Cerebras Project ID", default=None)

    _required_credentials = {"api_key", "endpoint_url"}

    async def validate_credentials(self) -> bool:
        """Validate Cerebras credentials.

        Placeholder: no request is made and the method always returns True.
        """
        # TODO: perform a real API check once a Cerebras client is integrated.
        return True
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from airtrain.core.skills import Skill, ProcessingError
|
4
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
5
|
+
from .credentials import CerebrasCredentials
|
6
|
+
|
7
|
+
|
8
|
+
class CerebrasInput(InputSchema):
    """Input for a Cerebras chat request.

    Fields:
        user_input: the user's message text (required).
        system_prompt: system instruction.
        model: Cerebras model name.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature, validated to lie in [0, 1].
    """

    user_input: str = Field(default=..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    model: str = Field(description="Cerebras model to use", default="cerebras-gpt")
    max_tokens: int = Field(description="Maximum tokens in response", default=1024)
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
        ge=0,
        le=1,
    )
|
21
|
+
|
22
|
+
|
23
|
+
class CerebrasOutput(OutputSchema):
    """Result of a Cerebras chat request: text, model used, and usage stats."""

    response: str = Field(default=..., description="Model's response text")
    used_model: str = Field(default=..., description="Model used for generation")
    usage: Dict[str, Any] = Field(
        description="Usage statistics",
        default_factory=dict,
    )
|
29
|
+
|
30
|
+
|
31
|
+
class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
    """Placeholder skill for Cerebras models.

    Not implemented: both construction and ``process`` raise
    ``NotImplementedError``.
    """

    input_schema = CerebrasInput
    output_schema = CerebrasOutput

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        # Fail fast at construction so callers cannot hold a non-working skill.
        raise NotImplementedError("CerebrasChatSkill is not implemented yet")

    def process(self, input_data: CerebrasInput) -> CerebrasOutput:
        raise NotImplementedError("CerebrasChatSkill is not implemented yet")
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
from google.cloud import storage
|
4
|
+
|
5
|
+
|
6
|
+
class GoogleCloudCredentials(BaseCredentials):
    """Google Cloud service-account credentials.

    Attributes:
        project_id: Google Cloud project id.
        service_account_key: service-account key file contents as a JSON string.
    """

    project_id: str = Field(..., description="Google Cloud Project ID")
    service_account_key: SecretStr = Field(..., description="Service Account Key JSON")

    _required_credentials = {"project_id", "service_account_key"}

    async def validate_credentials(self) -> bool:
        """Validate the key by listing at most one Cloud Storage bucket.

        Returns:
            True when the key parses and the API call succeeds.

        Raises:
            CredentialValidationError: if the key is not valid JSON, cannot
                build a client, or the API request fails.
        """
        import json

        try:
            # from_service_account_info expects a dict, not the raw JSON string.
            info = json.loads(self.service_account_key.get_secret_value())
            storage_client = storage.Client.from_service_account_info(
                info, project=self.project_id
            )
            # list_buckets returns a lazy iterator; pull one item so that a
            # request is actually sent and auth errors surface here.
            next(iter(storage_client.list_buckets(max_results=1)), None)
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Google Cloud credentials: {str(e)}"
            ) from e
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from airtrain.core.skills import Skill, ProcessingError
|
4
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
5
|
+
from .credentials import GoogleCloudCredentials
|
6
|
+
|
7
|
+
|
8
|
+
class VertexAIInput(InputSchema):
    """Input for a Google Vertex AI chat request.

    Fields:
        user_input: the user's message text (required).
        system_prompt: system instruction.
        model: Vertex AI model name.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature, validated to lie in [0, 1].
    """

    user_input: str = Field(default=..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    model: str = Field(description="Vertex AI model to use", default="text-bison")
    max_tokens: int = Field(description="Maximum tokens in response", default=1024)
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
        ge=0,
        le=1,
    )
|
21
|
+
|
22
|
+
|
23
|
+
class VertexAIOutput(OutputSchema):
    """Result of a Vertex AI chat request: text, model used, and usage stats."""

    response: str = Field(default=..., description="Model's response text")
    used_model: str = Field(default=..., description="Model used for generation")
    usage: Dict[str, Any] = Field(
        description="Usage statistics",
        default_factory=dict,
    )
|
29
|
+
|
30
|
+
|
31
|
+
class VertexAISkill(Skill[VertexAIInput, VertexAIOutput]):
    """Placeholder skill for Google Vertex AI.

    Not implemented: both construction and ``process`` raise
    ``NotImplementedError``.
    """

    input_schema = VertexAIInput
    output_schema = VertexAIOutput

    def __init__(self, credentials: Optional[GoogleCloudCredentials] = None):
        # Fail fast at construction so callers cannot hold a non-working skill.
        raise NotImplementedError("VertexAISkill is not implemented yet")

    def process(self, input_data: VertexAIInput) -> VertexAIOutput:
        raise NotImplementedError("VertexAISkill is not implemented yet")
|
@@ -0,0 +1,24 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
from groq import Groq
|
4
|
+
|
5
|
+
|
6
|
+
class GroqCredentials(BaseCredentials):
    """Groq API credentials.

    Attributes:
        api_key: secret Groq API key.
    """

    api_key: SecretStr = Field(..., description="Groq API key")

    _required_credentials = {"api_key"}

    async def validate_credentials(self) -> bool:
        """Validate the key with a minimal (1-token) chat completion.

        The async client is required here: the sync ``Groq`` client returns a
        plain object, and awaiting it raises ``TypeError``. That made every
        validation fail, even with a good key.

        Returns:
            True when the request succeeds.

        Raises:
            CredentialValidationError: if the request fails.
        """
        # Local import: ships in the same `groq` package as `Groq`.
        from groq import AsyncGroq

        try:
            client = AsyncGroq(api_key=self.api_key.get_secret_value())
            await client.chat.completions.create(
                messages=[{"role": "user", "content": "Hi"}],
                model="mixtral-8x7b-32768",
                max_tokens=1,
            )
            return True
        except Exception as e:
            raise CredentialValidationError(f"Invalid Groq credentials: {str(e)}") from e
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from airtrain.core.skills import Skill, ProcessingError
|
4
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
5
|
+
from .credentials import GroqCredentials
|
6
|
+
|
7
|
+
|
8
|
+
class GroqInput(InputSchema):
    """Input for a Groq chat request.

    Fields:
        user_input: the user's message text (required).
        system_prompt: system instruction.
        model: Groq model id. The default matches the id used by
            ``GroqCredentials.validate_credentials``; the previous value
            "mixtral-8x7b" is not a Groq model id.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature, validated to lie in [0, 1].
    """

    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    model: str = Field(default="mixtral-8x7b-32768", description="Groq model to use")
    max_tokens: int = Field(default=1024, description="Maximum tokens in response")
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
|
21
|
+
|
22
|
+
|
23
|
+
class GroqOutput(OutputSchema):
    """Result of a Groq chat request: text, model used, and usage stats."""

    response: str = Field(default=..., description="Model's response text")
    used_model: str = Field(default=..., description="Model used for generation")
    usage: Dict[str, Any] = Field(
        description="Usage statistics",
        default_factory=dict,
    )
|
29
|
+
|
30
|
+
|
31
|
+
class GroqChatSkill(Skill[GroqInput, GroqOutput]):
    """Placeholder skill for Groq models.

    Not implemented: both construction and ``process`` raise
    ``NotImplementedError``.
    """

    input_schema = GroqInput
    output_schema = GroqOutput

    def __init__(self, credentials: Optional[GroqCredentials] = None):
        # Fail fast at construction so callers cannot hold a non-working skill.
        raise NotImplementedError("GroqChatSkill is not implemented yet")

    def process(self, input_data: GroqInput) -> GroqOutput:
        raise NotImplementedError("GroqChatSkill is not implemented yet")
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from pydantic import Field
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
from importlib.util import find_spec
|
4
|
+
|
5
|
+
|
6
|
+
class OllamaCredentials(BaseCredentials):
    """Connection settings for an Ollama server.

    Attributes:
        host: base URL of the Ollama server.
        timeout: request timeout in seconds, applied to the validation request.
    """

    host: str = Field(default="http://localhost:11434", description="Ollama host URL")
    timeout: int = Field(default=30, description="Request timeout in seconds")

    async def validate_credentials(self) -> bool:
        """Check that the ``ollama`` package is installed and the server answers.

        Uses ``AsyncClient``. The sync ``Client.list()`` is not awaitable, so
        awaiting it raised ``TypeError`` and every validation failed.

        Returns:
            True when the server responds to a model-list request.

        Raises:
            CredentialValidationError: if the package is missing or the server
                cannot be reached.
        """
        if find_spec("ollama") is None:
            raise CredentialValidationError(
                "Ollama package is not installed. Please install it using: pip install ollama"
            )

        try:
            from ollama import AsyncClient

            # Extra keyword arguments (timeout) are forwarded to the HTTP client.
            client = AsyncClient(host=self.host, timeout=self.timeout)
            await client.list()
            return True
        except Exception as e:
            raise CredentialValidationError(f"Invalid Ollama connection: {str(e)}") from e
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from airtrain.core.skills import Skill, ProcessingError
|
4
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
5
|
+
from .credentials import OllamaCredentials
|
6
|
+
|
7
|
+
|
8
|
+
class OllamaInput(InputSchema):
    """Input for an Ollama chat request.

    Fields:
        user_input: the user's message text (required).
        system_prompt: system instruction.
        model: local Ollama model name.
        max_tokens: upper bound on generated tokens.
        temperature: sampling temperature, validated to lie in [0, 1].
    """

    user_input: str = Field(default=..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    model: str = Field(description="Ollama model to use", default="llama2")
    max_tokens: int = Field(description="Maximum tokens in response", default=1024)
    temperature: float = Field(
        description="Temperature for response generation",
        default=0.7,
        ge=0,
        le=1,
    )
|
21
|
+
|
22
|
+
|
23
|
+
class OllamaOutput(OutputSchema):
    """Result of an Ollama chat request: text, model used, and usage stats."""

    response: str = Field(default=..., description="Model's response text")
    used_model: str = Field(default=..., description="Model used for generation")
    usage: Dict[str, Any] = Field(
        description="Usage statistics",
        default_factory=dict,
    )
|
29
|
+
|
30
|
+
|
31
|
+
class OllamaChatSkill(Skill[OllamaInput, OllamaOutput]):
    """Placeholder skill for Ollama models.

    Not implemented: both construction and ``process`` raise
    ``NotImplementedError``.
    """

    input_schema = OllamaInput
    output_schema = OllamaOutput

    def __init__(self, credentials: Optional[OllamaCredentials] = None):
        # Fail fast at construction so callers cannot hold a non-working skill.
        raise NotImplementedError("OllamaChatSkill is not implemented yet")

    def process(self, input_data: OllamaInput) -> OllamaOutput:
        raise NotImplementedError("OllamaChatSkill is not implemented yet")
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from .skills import (
|
2
|
+
OpenAIChatSkill,
|
3
|
+
OpenAIInput,
|
4
|
+
OpenAIParserSkill,
|
5
|
+
OpenAIOutput,
|
6
|
+
OpenAIParserInput,
|
7
|
+
OpenAIParserOutput,
|
8
|
+
)
|
9
|
+
from .credentials import OpenAICredentials
|
10
|
+
|
11
|
+
# Names re-exported by the openai integration package.
__all__ = [
    "OpenAIChatSkill",
    "OpenAIInput",
    "OpenAIParserSkill",
    "OpenAIParserInput",
    "OpenAIParserOutput",
    "OpenAICredentials",
    "OpenAIOutput",
]
|
@@ -0,0 +1,42 @@
|
|
1
|
+
from typing import Optional, TypeVar
|
2
|
+
from pydantic import Field
|
3
|
+
from .skills import OpenAIChatSkill, OpenAIInput, OpenAIOutput
|
4
|
+
from .credentials import OpenAICredentials
|
5
|
+
|
6
|
+
# Type variable for the input accepted by ChineseAssistantSkill.process:
# any OpenAIInput (or subclass such as ChineseAssistantInput).
T = TypeVar("T", bound=OpenAIInput)
|
7
|
+
|
8
|
+
|
9
|
+
class ChineseAssistantInput(OpenAIInput):
    """Input for the Chinese-language assistant.

    Same fields as OpenAIInput, with a Chinese system prompt by default. The
    prompt instructs the model to always answer in Chinese, accurately,
    politely and professionally. ``max_tokens`` defaults to 8192 to match
    OpenAIInput; the previous 8096 was a typo.
    """

    user_input: str = Field(
        ..., description="User's input text (can be in any language)"
    )
    system_prompt: str = Field(
        default="你是一个有帮助的助手。请用中文回答所有问题,即使问题是用其他语言问的。回答要准确、礼貌、专业。",
        description="System prompt in Chinese",
    )
    model: str = Field(default="gpt-4o", description="OpenAI model to use")
    max_tokens: int = Field(default=8192, description="Maximum tokens in response")
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
|
24
|
+
|
25
|
+
|
26
|
+
class ChineseAssistantSkill(OpenAIChatSkill):
    """OpenAI chat skill that makes the model answer in Chinese."""

    input_schema = ChineseAssistantInput
    output_schema = OpenAIOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize with optional credentials; delegates to OpenAIChatSkill."""
        super().__init__(credentials)

    def process(self, input_data: T) -> OpenAIOutput:
        """Run the chat request with a Chinese-language system prompt.

        If the system prompt does not contain "你是" ("you are"), it is wrapped
        with a Chinese persona prefix ("You are a Chinese assistant.") and an
        "Please answer in Chinese." suffix. The change is applied to a copy, so
        the caller's ``input_data`` is no longer mutated as a side effect.
        """
        if "你是" not in input_data.system_prompt:
            wrapped_prompt = (
                "你是一个中文助手。" + input_data.system_prompt + "请用中文回答。"
            )
            # pydantic v2 provides model_copy(); fall back to v1's copy().
            copy_with = getattr(input_data, "model_copy", None) or input_data.copy
            input_data = copy_with(update={"system_prompt": wrapped_prompt})

        return super().process(input_data)
|
@@ -0,0 +1,39 @@
|
|
1
|
+
from datetime import datetime, timedelta
|
2
|
+
from typing import Optional
|
3
|
+
from pydantic import Field, SecretStr, validator
|
4
|
+
from openai import OpenAI
|
5
|
+
|
6
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
7
|
+
|
8
|
+
|
9
|
+
class OpenAICredentials(BaseCredentials):
    """OpenAI API credentials with format and live validation.

    Attributes:
        openai_api_key: secret key; must start with ``sk-`` and be at least
            40 characters long.
        openai_organization_id: optional organization id matching
            ``org-`` followed by 24 alphanumerics.
    """

    openai_api_key: SecretStr = Field(..., description="OpenAI API key")
    openai_organization_id: Optional[str] = Field(
        None, description="OpenAI organization ID", pattern="^org-[A-Za-z0-9]{24}$"
    )

    _required_credentials = {"openai_api_key"}

    @validator("openai_api_key")
    def validate_api_key_format(cls, v: SecretStr) -> SecretStr:
        """Reject keys with the wrong prefix or an implausibly short length."""
        key = v.get_secret_value()
        if not key.startswith("sk-"):
            raise ValueError("OpenAI API key must start with 'sk-'")
        if len(key) < 40:
            raise ValueError("OpenAI API key appears to be too short")
        return v

    async def validate_credentials(self) -> bool:
        """Validate the credentials with a live model-list request.

        Two bugs are fixed here. The sync ``OpenAI`` client is not awaitable,
        and ``models.list`` takes no ``limit`` argument; together they made
        every validation fail. Awaiting the async paginator fetches one page.

        Returns:
            True when the request succeeds.

        Raises:
            CredentialValidationError: if the request fails.
        """
        # Local import: ships in the same `openai` package as `OpenAI`.
        from openai import AsyncOpenAI

        try:
            client = AsyncOpenAI(
                api_key=self.openai_api_key.get_secret_value(),
                organization=self.openai_organization_id,
            )
            await client.models.list()
            return True
        except Exception as e:
            raise CredentialValidationError(f"Invalid OpenAI credentials: {str(e)}") from e
|
@@ -0,0 +1,208 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any, TypeVar, Type
|
2
|
+
from pydantic import Field, BaseModel
|
3
|
+
from openai import OpenAI
|
4
|
+
import base64
|
5
|
+
from pathlib import Path
|
6
|
+
from loguru import logger
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import OpenAICredentials
|
11
|
+
|
12
|
+
|
13
|
+
class OpenAIInput(InputSchema):
    """Schema for OpenAI chat input.

    ``temperature`` accepts the full 0-2 range supported by the OpenAI Chat
    Completions API; the previous upper bound of 1 rejected valid requests.
    """

    # Text of the user message (required).
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    model: str = Field(default="gpt-4o", description="OpenAI model to use")
    max_tokens: int = Field(default=8192, description="Maximum tokens in response")
    # OpenAI documents sampling temperature as a value between 0 and 2.
    temperature: float = Field(
        default=0.2, description="Temperature for response generation", ge=0, le=2
    )
    # Paths of image files to attach to the user message.
    images: Optional[List[Path]] = Field(
        default=None,
        description="Optional list of image paths to include in the message",
    )
    # Function definitions for the (legacy) function-calling API.
    functions: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Optional function definitions for function calling",
    )
    function_call: Optional[str] = Field(
        default=None,
        description="Controls function calling behavior",
    )
|
38
|
+
|
39
|
+
|
40
|
+
class OpenAIOutput(OutputSchema):
    """Result of an OpenAI chat request: text, model name, usage and function call."""

    # Assistant reply text (required).
    response: str = Field(..., description="Model's response text")
    # Model identifier reported back by the API (required).
    used_model: str = Field(..., description="Model used for generation")
    # Token accounting; an empty dict when nothing was supplied.
    usage: Dict[str, Any] = Field(
        description="Usage statistics from the API", default_factory=dict
    )
    # Function call payload, or None when the model did not call a function.
    function_call: Optional[Dict[str, Any]] = Field(
        description="Function call information if applicable", default=None
    )
|
51
|
+
|
52
|
+
|
53
|
+
class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
    """Skill for interacting with OpenAI's chat models.

    Supports image attachments (sent as base64 data URLs) and the legacy
    ``functions`` / ``function_call`` function-calling parameters.
    """

    input_schema = OpenAIInput
    output_schema = OpenAIOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Explicit credentials; when omitted they are loaded
                via ``OpenAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    def _encode_image(self, image_path: Path) -> Dict[str, Any]:
        """Read an image file and wrap it as a base64 ``image_url`` content part.

        The MIME type is guessed from the file extension, falling back to
        ``image/jpeg``, so PNG/GIF/WebP files are not mislabelled as JPEG.

        Raises:
            ProcessingError: if the file is missing or cannot be read.
        """
        import mimetypes

        try:
            path = Path(image_path)
            if not path.exists():
                raise FileNotFoundError(f"Image file not found: {image_path}")

            encoded = base64.b64encode(path.read_bytes()).decode()
            mime_type, _ = mimetypes.guess_type(str(path))
            if not mime_type or not mime_type.startswith("image/"):
                mime_type = "image/jpeg"
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
            }
        except Exception as e:
            logger.error(f"Failed to encode image {image_path}: {str(e)}")
            raise ProcessingError(f"Image encoding failed: {str(e)}") from e

    def process(self, input_data: OpenAIInput) -> OpenAIOutput:
        """Send one chat completion request and wrap the result.

        Returns:
            An ``OpenAIOutput`` with the reply text (empty string if the model
            returned none), model name, token usage (empty dict when the API
            omits it) and any function call.

        Raises:
            ProcessingError: if image encoding or the API call fails.
        """
        try:
            logger.info(f"Processing request with model {input_data.model}")

            # User message: text first, then any attached images.
            content: List[Dict[str, Any]] = [
                {"type": "text", "text": input_data.user_input}
            ]
            if input_data.images:
                logger.debug(f"Processing {len(input_data.images)} images")
                content.extend(self._encode_image(p) for p in input_data.images)

            messages = [
                {"role": "system", "content": input_data.system_prompt},
                {"role": "user", "content": content},
            ]

            params: Dict[str, Any] = {
                "model": input_data.model,
                "messages": messages,
                "temperature": input_data.temperature,
                "max_tokens": input_data.max_tokens,
            }

            if input_data.functions:
                params["functions"] = input_data.functions
                # Only forward function_call when set, rather than sending an
                # explicit null; the API then applies its own default.
                if input_data.function_call is not None:
                    params["function_call"] = input_data.function_call

            response = self.client.chat.completions.create(**params)
            message = response.choices[0].message

            function_call = (
                message.function_call.model_dump() if message.function_call else None
            )

            # usage may be absent on some responses; avoid an AttributeError.
            usage = response.usage
            usage_stats: Dict[str, Any] = (
                {
                    "prompt_tokens": usage.prompt_tokens,
                    "completion_tokens": usage.completion_tokens,
                    "total_tokens": usage.total_tokens,
                }
                if usage is not None
                else {}
            )

            logger.success("Successfully processed OpenAI request")

            return OpenAIOutput(
                response=message.content or "",
                used_model=response.model,
                usage=usage_stats,
                function_call=function_call,
            )

        except ProcessingError:
            # Already logged and wrapped (e.g. by _encode_image); don't double-wrap.
            raise
        except Exception as e:
            logger.exception(f"OpenAI processing failed: {str(e)}")
            raise ProcessingError(f"OpenAI processing failed: {str(e)}") from e
|
144
|
+
|
145
|
+
|
146
|
+
# Placeholder for "some pydantic model class" used by the response_model field.
ResponseT = TypeVar("ResponseT", bound=BaseModel)


class OpenAIParserInput(InputSchema):
    """Input for structured (parsed) OpenAI completions.

    Carries the prompt, generation settings and the pydantic model class the
    response should be parsed into.
    """

    # Text of the user message (required).
    user_input: str
    system_prompt: str = Field(
        default="You are a helpful assistant that provides structured data."
    )
    model: str = Field(default="gpt-4o")
    temperature: float = Field(default=0.7)
    # None means no explicit token limit.
    max_tokens: Optional[int] = Field(default=None)
    # The model *class* (not an instance) describing the expected output.
    response_model: Type[ResponseT]

    class Config:
        # Allow field types pydantic has no built-in validator for.
        arbitrary_types_allowed = True
|
161
|
+
|
162
|
+
|
163
|
+
class OpenAIParserOutput(OutputSchema):
    """Result of a structured OpenAI completion."""

    # Instance of the requested response_model, as parsed from the reply.
    parsed_response: BaseModel = Field(...)
    # Model identifier reported back by the API.
    used_model: str = Field(...)
    # Total token count for the request.
    tokens_used: int = Field(...)
|
169
|
+
|
170
|
+
|
171
|
+
class OpenAIParserSkill(Skill[OpenAIParserInput, OpenAIParserOutput]):
    """Skill for getting structured responses from OpenAI.

    Uses the SDK's ``beta.chat.completions.parse`` helper so the reply is
    parsed directly into ``input_data.response_model``.
    """

    input_schema = OpenAIParserInput
    output_schema = OpenAIParserOutput

    def __init__(self, credentials: Optional[OpenAICredentials] = None):
        """Initialize the skill.

        Args:
            credentials: Explicit credentials; when omitted they are loaded
                via ``OpenAICredentials.from_env()``.
        """
        super().__init__()
        self.credentials = credentials or OpenAICredentials.from_env()
        self.client = OpenAI(
            api_key=self.credentials.openai_api_key.get_secret_value(),
            organization=self.credentials.openai_organization_id,
        )

    def process(self, input_data: OpenAIParserInput) -> OpenAIParserOutput:
        """Request a structured completion and return the parsed model.

        ``temperature`` and ``max_tokens`` from the input are forwarded to the
        API; previously they were declared but silently ignored.
        ``max_tokens`` is only sent when it is set.

        Raises:
            ProcessingError: if the request fails, the model refuses, or the
                reply could not be parsed into ``response_model``.
        """
        params: Dict[str, Any] = {
            "model": input_data.model,
            "messages": [
                {"role": "system", "content": input_data.system_prompt},
                {"role": "user", "content": input_data.user_input},
            ],
            "response_format": input_data.response_model,
            "temperature": input_data.temperature,
        }
        if input_data.max_tokens is not None:
            params["max_tokens"] = input_data.max_tokens

        try:
            completion = self.client.beta.chat.completions.parse(**params)
            message = completion.choices[0].message

            if message.parsed is None:
                # Surface the model's refusal text when one was returned.
                refusal = getattr(message, "refusal", None)
                reason = (
                    f"Model refused: {refusal}"
                    if refusal
                    else "Failed to parse response"
                )
                raise ProcessingError(f"OpenAI parsing failed: {reason}")

            usage = completion.usage
            return OpenAIParserOutput(
                parsed_response=message.parsed,
                used_model=completion.model,
                tokens_used=usage.total_tokens if usage is not None else 0,
            )

        except ProcessingError:
            # Already carries the "OpenAI parsing failed" prefix.
            raise
        except Exception as e:
            raise ProcessingError(f"OpenAI parsing failed: {str(e)}") from e
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from pydantic import Field, SecretStr, HttpUrl
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
|
4
|
+
|
5
|
+
class SambanovaCredentials(BaseCredentials):
    """Credentials for the SambaNova API: an API key plus an endpoint URL."""

    # Secret API key; SecretStr keeps the raw value out of reprs.
    api_key: SecretStr = Field(..., description="SambaNova API key")
    # Base URL of the SambaNova API; validated as an HTTP(S) URL.
    endpoint_url: HttpUrl = Field(..., description="SambaNova API endpoint")

    _required_credentials = {"api_key", "endpoint_url"}

    async def validate_credentials(self) -> bool:
        """Report the credentials as valid.

        NOTE(review): no remote check is performed yet; SambaNova-specific
        validation depends on their API client implementation.

        Returns:
            Always True.
        """
        return True
|
@@ -0,0 +1,41 @@
|
|
1
|
+
from typing import Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from airtrain.core.skills import Skill, ProcessingError
|
4
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
5
|
+
from .credentials import SambanovaCredentials
|
6
|
+
|
7
|
+
|
8
|
+
class SambanovaInput(InputSchema):
    """Request parameters for a Sambanova chat call."""

    # Text of the user message (required).
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    model: str = Field(description="Sambanova model to use", default="sambanova-llm")
    max_tokens: int = Field(description="Maximum tokens in response", default=1024)
    # Constrained to the closed interval [0, 1].
    temperature: float = Field(
        description="Temperature for response generation", default=0.7, le=1, ge=0
    )
|
21
|
+
|
22
|
+
|
23
|
+
class SambanovaOutput(OutputSchema):
    """Result of a Sambanova chat call."""

    # Reply text (required).
    response: str = Field(..., description="Model's response text")
    # Model identifier used for the reply (required).
    used_model: str = Field(..., description="Model used for generation")
    # Usage statistics; empty dict by default.
    usage: Dict[str, Any] = Field(description="Usage statistics", default_factory=dict)
|
29
|
+
|
30
|
+
|
31
|
+
class SambanovaChatSkill(Skill[SambanovaInput, SambanovaOutput]):
    """Placeholder skill for Sambanova; construction and processing both raise."""

    input_schema = SambanovaInput
    output_schema = SambanovaOutput

    def __init__(self, credentials: Optional[SambanovaCredentials] = None):
        """Refuse construction until the Sambanova integration exists.

        Raises:
            NotImplementedError: always.
        """
        message = "SambanovaChatSkill is not implemented yet"
        raise NotImplementedError(message)

    def process(self, input_data: SambanovaInput) -> SambanovaOutput:
        """Not available yet.

        Raises:
            NotImplementedError: always.
        """
        message = "SambanovaChatSkill is not implemented yet"
        raise NotImplementedError(message)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
import together
|
4
|
+
|
5
|
+
|
6
|
+
class TogetherAICredentials(BaseCredentials):
    """Together AI API credentials."""

    # Secret API key; SecretStr keeps the raw value out of reprs.
    api_key: SecretStr = Field(..., description="Together AI API key")

    _required_credentials = {"api_key"}

    async def validate_credentials(self) -> bool:
        """Validate credentials by listing models with an async Together client.

        The key is passed to a per-call ``together.AsyncTogether`` client
        instead of being written to the module-global ``together.api_key``.
        The previous ``await together.Models.list()`` could never succeed:
        that call is synchronous in the legacy SDK and absent in together>=1.0.

        Returns:
            True when the API accepts the key.

        Raises:
            CredentialValidationError: if the API call fails for any reason.
        """
        try:
            client = together.AsyncTogether(api_key=self.api_key.get_secret_value())
            await client.models.list()
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Together AI credentials: {str(e)}"
            ) from e
|
@@ -0,0 +1,43 @@
|
|
1
|
+
from typing import Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from airtrain.core.skills import Skill, ProcessingError
|
4
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
5
|
+
from .credentials import TogetherAICredentials
|
6
|
+
|
7
|
+
|
8
|
+
class TogetherAIInput(InputSchema):
    """Request parameters for a Together AI chat call."""

    # Text of the user message (required).
    user_input: str = Field(..., description="User's input text")
    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )
    model: str = Field(
        description="Together AI model to use", default="togethercomputer/llama-2-70b"
    )
    max_tokens: int = Field(description="Maximum tokens in response", default=1024)
    # Constrained to the closed interval [0, 1].
    temperature: float = Field(
        description="Temperature for response generation", default=0.7, le=1, ge=0
    )
|
23
|
+
|
24
|
+
|
25
|
+
class TogetherAIOutput(OutputSchema):
    """Result of a Together AI chat call."""

    # Reply text (required).
    response: str = Field(..., description="Model's response text")
    # Model identifier used for the reply (required).
    used_model: str = Field(..., description="Model used for generation")
    # Usage statistics; empty dict by default.
    usage: Dict[str, Any] = Field(description="Usage statistics", default_factory=dict)
|
31
|
+
|
32
|
+
|
33
|
+
class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
    """Placeholder skill for Together AI; construction and processing both raise."""

    input_schema = TogetherAIInput
    output_schema = TogetherAIOutput

    def __init__(self, credentials: Optional[TogetherAICredentials] = None):
        """Refuse construction until the Together AI integration exists.

        Raises:
            NotImplementedError: always.
        """
        message = "TogetherAIChatSkill is not implemented yet"
        raise NotImplementedError(message)

    def process(self, input_data: TogetherAIInput) -> TogetherAIOutput:
        """Not available yet.

        Raises:
            NotImplementedError: always.
        """
        message = "TogetherAIChatSkill is not implemented yet"
        raise NotImplementedError(message)
|
@@ -0,0 +1,38 @@
|
|
1
|
+
airtrain/__init__.py,sha256=VKqqTKi_5O7itXPoDiYPbAAgnxMd5opQeJHsRPyUNTM,312
|
2
|
+
airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
|
3
|
+
airtrain/core/credentials.py,sha256=PgQotrQc46J5djidKnkK1znUv3fyNkUFDO-m2Kn_Gzo,4006
|
4
|
+
airtrain/core/schemas.py,sha256=MMXrDviC4gRea_QaPpbjgO--B_UKxnD7YrxqZOLJZZU,7003
|
5
|
+
airtrain/core/skills.py,sha256=LljalzeSHK5eQPTAOEAYc5D8Qn1kVSfiz9WgziTD5UM,4688
|
6
|
+
airtrain/integrations/__init__.py,sha256=Y-nxbWlPGy0Txt6qkRzLKQGapfAfL_q9W6U3ISTSg_I,1486
|
7
|
+
airtrain/integrations/anthropic/__init__.py,sha256=qwlWLDh1rEVizYFbW8430z-f1SxHio7_Gaw5cCTUtoo,274
|
8
|
+
airtrain/integrations/anthropic/credentials.py,sha256=hlTSw9HX66kYNaeQUtn0JjdZQBMNkzzFOJOoLOOzvcY,1246
|
9
|
+
airtrain/integrations/anthropic/skills.py,sha256=sT7dBYPVCsICYjgBjUlyyP84A8h9OkbgkslvKHk3Tjs,5273
|
10
|
+
airtrain/integrations/aws/__init__.py,sha256=3x7v2NxpAfI-U-YgwQeH5PtsmUrNLPMfLyUGFLiBjbs,155
|
11
|
+
airtrain/integrations/aws/credentials.py,sha256=nN-daKAl7qOb_VdRpsThG8gN5GeSUkx-ji5E_gF_vYw,1444
|
12
|
+
airtrain/integrations/aws/skills.py,sha256=TQiMXeXRRcJ14fe8Xi7Uk20iS6_INbcznuLGtMorcKY,3870
|
13
|
+
airtrain/integrations/cerebras/__init__.py,sha256=zAD-qV38OzHhMCz1z-NvjjqcYEhURbm8RWTOKHNqbew,174
|
14
|
+
airtrain/integrations/cerebras/credentials.py,sha256=IFkn8LxMAaOpvEWXDpb94VQGtqcDxQ7rZHKH-tX4Nuw,884
|
15
|
+
airtrain/integrations/cerebras/skills.py,sha256=O9vwFzvv_tUOwFOVE8CszAQEac711eVYVUj_8dVMTpc,1596
|
16
|
+
airtrain/integrations/google/__init__.py,sha256=INZFNOcNebz3m-Ggk07ZjmX0kNHIbTe_St9gBlZBki8,176
|
17
|
+
airtrain/integrations/google/credentials.py,sha256=yyl-MWl06wr4SWvcvJGSpJ3hGTz21ByrRSr_3np5cbU,1030
|
18
|
+
airtrain/integrations/google/skills.py,sha256=uwmgetl5Ien7fLOA5HIZdqoL6AZnexFDyzfsrGuJ1RU,1606
|
19
|
+
airtrain/integrations/groq/__init__.py,sha256=B_X2fXbsJfFD6GquKeVCsEJjwd9Ygbq1uEHlV4Jy7YE,154
|
20
|
+
airtrain/integrations/groq/credentials.py,sha256=A8-VIyoZTkHFQb-O-lmu-UrgaLZ3hfWfzzigkYteESk,829
|
21
|
+
airtrain/integrations/groq/skills.py,sha256=Qy6SBAb19SzOFuqgcLyzdyRBp4D7jKqsEeJ6UTDaqMM,1528
|
22
|
+
airtrain/integrations/ollama/__init__.py,sha256=zMHBsGzViVrvxAeJmfq6r-ZfSE6Dy5QcKLhe4d5fEcM,164
|
23
|
+
airtrain/integrations/ollama/credentials.py,sha256=D7O4kUAb_VHs5s1ncUN9Ezhu5PvLfgj3RifAkB9sEZk,940
|
24
|
+
airtrain/integrations/ollama/skills.py,sha256=M_Un8D5VJ5XtPEq9IClzqV3jCPBoFTSm2ve6EO8W2JU,1556
|
25
|
+
airtrain/integrations/openai/__init__.py,sha256=K-NY2_T1T6SEOgkpbUA55cWvK2nr2NOJgLCqmmtaCno,371
|
26
|
+
airtrain/integrations/openai/chinese_assistant.py,sha256=MMhv4NBOoEQ0O22ZZtP255rd5ajHC9l6FPWIjpqxBOA,1581
|
27
|
+
airtrain/integrations/openai/credentials.py,sha256=NfRyp1QgEtgm8cxt2-BOLq-6d0X-Pcm80NnfHM8p0FY,1470
|
28
|
+
airtrain/integrations/openai/skills.py,sha256=Olg9-6f_p2XgkVwwcB9tvjAMApmM2EK81i8LP4qVVvs,7676
|
29
|
+
airtrain/integrations/sambanova/__init__.py,sha256=dp_263iOckM_J9pOEvyqpf3FrejD6-_x33r0edMCTe0,179
|
30
|
+
airtrain/integrations/sambanova/credentials.py,sha256=U36RAEIPNuwo-vTrt3U9kkkj2GfdqSclA1ttOYHxS-w,784
|
31
|
+
airtrain/integrations/sambanova/skills.py,sha256=Po1ur_QFwzVIugbkk2mt73WdXDz_Gr9ASlUc9Y12Kok,1614
|
32
|
+
airtrain/integrations/together/__init__.py,sha256=we4KXn_pUs6Dxo3QcB-t40BSRraQFdKg2nXw7yi2FjM,185
|
33
|
+
airtrain/integrations/together/credentials.py,sha256=y5M6ZQrfYJLJbClxEasq4HaVyZM0l5lFshwVP6jq2E4,720
|
34
|
+
airtrain/integrations/together/skills.py,sha256=YMOULyk2TX32rCjhxK29e4ehn8iIzMXpg3xmdYtuyQQ,1664
|
35
|
+
airtrain-0.1.8.dist-info/METADATA,sha256=o8NbBrkTDmyZBRZJQ2fxti7ffsVjD8U1F2R-w0aSbpw,4380
|
36
|
+
airtrain-0.1.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
37
|
+
airtrain-0.1.8.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
38
|
+
airtrain-0.1.8.dist-info/RECORD,,
|
airtrain-0.1.6.dist-info/RECORD
DELETED
@@ -1,10 +0,0 @@
|
|
1
|
-
airtrain/__init__.py,sha256=t0n2IItXdHzaLZf_1zSUmSf3buF-8Y4FJol1LAzSclk,312
|
2
|
-
airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
|
3
|
-
airtrain/core/credentials.py,sha256=PgQotrQc46J5djidKnkK1znUv3fyNkUFDO-m2Kn_Gzo,4006
|
4
|
-
airtrain/core/schemas.py,sha256=MMXrDviC4gRea_QaPpbjgO--B_UKxnD7YrxqZOLJZZU,7003
|
5
|
-
airtrain/core/skills.py,sha256=LljalzeSHK5eQPTAOEAYc5D8Qn1kVSfiz9WgziTD5UM,4688
|
6
|
-
airtrain/integrations/__init__.py,sha256=PRKI_A-KE307C4lpXgFAsZA2oFtTl1kt_4CrRUF2rpU,832
|
7
|
-
airtrain-0.1.6.dist-info/METADATA,sha256=C0hcwg_Am0cUqrv8vKfNJ_AeMdBfGhxrSEmzDbjJA7o,4380
|
8
|
-
airtrain-0.1.6.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
9
|
-
airtrain-0.1.6.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
10
|
-
airtrain-0.1.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|