airtrain 0.1.3__tar.gz → 0.1.5__tar.gz
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- airtrain-0.1.5/.gitignore +183 -0
- airtrain-0.1.5/EXPERIMENTS/integrations_examples/anthropic_with_image.py +43 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/PKG-INFO +1 -1
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain/__init__.py +1 -1
- airtrain-0.1.5/airtrain/core/__pycache__/credentials.cpython-310.pyc +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain/core/credentials.py +3 -31
- airtrain-0.1.5/airtrain/integrations/__init__.py +26 -0
- airtrain-0.1.5/airtrain/integrations/anthropic/credentials.py +32 -0
- airtrain-0.1.5/airtrain/integrations/anthropic/skills.py +135 -0
- airtrain-0.1.5/airtrain/integrations/aws/credentials.py +36 -0
- airtrain-0.1.5/airtrain/integrations/cerebras/credentials.py +22 -0
- airtrain-0.1.5/airtrain/integrations/google/credentials.py +27 -0
- airtrain-0.1.5/airtrain/integrations/groq/credentials.py +24 -0
- airtrain-0.1.5/airtrain/integrations/ollama/credentials.py +26 -0
- airtrain-0.1.5/airtrain/integrations/openai/chinese_assistant.py +42 -0
- airtrain-0.1.5/airtrain/integrations/openai/credentials.py +39 -0
- airtrain-0.1.5/airtrain/integrations/openai/skills.py +208 -0
- airtrain-0.1.5/airtrain/integrations/sambanova/credentials.py +20 -0
- airtrain-0.1.5/airtrain/integrations/together/credentials.py +22 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain.egg-info/PKG-INFO +1 -1
- airtrain-0.1.5/airtrain.egg-info/SOURCES.txt +60 -0
- airtrain-0.1.5/examples/creating-skills/anthropic_skills_usage.py +56 -0
- airtrain-0.1.5/examples/creating-skills/chinese_anthropic_assistant.py +44 -0
- airtrain-0.1.5/examples/creating-skills/chinese_anthropic_usage.py +60 -0
- airtrain-0.1.5/examples/creating-skills/chinese_assistant_usage.py +45 -0
- airtrain-0.1.5/examples/creating-skills/icon128.png +0 -0
- airtrain-0.1.5/examples/creating-skills/icon16.png +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/examples/creating-skills/openai_skills.py +6 -6
- airtrain-0.1.5/examples/creating-skills/openai_skills_usage.py +175 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/examples/credentials_usage.py +0 -1
- airtrain-0.1.5/examples/images/quantum-circuit.png +0 -0
- airtrain-0.1.5/scripts/release.py +60 -0
- airtrain-0.1.3/airtrain/core/__pycache__/credentials.cpython-310.pyc +0 -0
- airtrain-0.1.3/airtrain.egg-info/SOURCES.txt +0 -36
- {airtrain-0.1.3 → airtrain-0.1.5}/.flake8 +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/.github/workflows/publish.yml +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/.mypy.ini +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/.pre-commit-config.yaml +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/.vscode/extensions.json +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/.vscode/launch.json +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/.vscode/settings.json +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/EXPERIMENTS/schema_exps/pydantic_schemas.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/README.md +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain/core/__init__.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain/core/__pycache__/schemas.cpython-310.pyc +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain/core/__pycache__/skills.cpython-310.pyc +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain/core/schemas.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain/core/skills.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain.egg-info/dependency_links.txt +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain.egg-info/requires.txt +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/airtrain.egg-info/top_level.txt +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/examples/creating-skills/image1.jpg +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/examples/creating-skills/image2.jpg +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/examples/creating-skills/openai_structured_skills.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/examples/schema_usage.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/examples/skill_usage.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/pyproject.toml +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/scripts/build.sh +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/scripts/bump_version.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/scripts/publish.sh +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/services/firebase_service.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/services/openai_service.py +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/setup.cfg +0 -0
- {airtrain-0.1.3 → airtrain-0.1.5}/setup.py +0 -0
--- /dev/null
+++ airtrain-0.1.5/.gitignore
@@ -0,0 +1,183 @@
+package
+.env
+.mypy_cache
+firebaseadmin.json
+**pyc=
+50mb_test3.bin
+*bin
+**bin
+token.pickle
+temp_workspace
+temp*
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# PyPI configuration file
+.pypirc
--- /dev/null
+++ airtrain-0.1.5/EXPERIMENTS/integrations_examples/anthropic_with_image.py
@@ -0,0 +1,43 @@
+import anthropic
+
+client = anthropic.Anthropic(
+    # defaults to os.environ.get("ANTHROPIC_API_KEY")
+    api_key="my_api_key",
+)
+
+# Replace placeholders like {{PR_DESCRIPTION}} with real values,
+# because the SDK does not support variables.
+message = client.messages.create(
+    model="claude-3-5-sonnet-20241022",
+    max_tokens=8192,
+    temperature=0,
+    system="You are an experienced software engineer tasked with reviewing a GitHub Pull Request (PR). Your goal is to analyze the code quality and suggest improvements. Follow these steps carefully:\n\n1. Review the PR description:\n<PR_DESCRIPTION>\n{{PR_DESCRIPTION}}\n</PR_DESCRIPTION>\n\n2. Examine the code changes:\n<CODE_CHANGES>\n{{CODE_CHANGES}}\n</CODE_CHANGES>\n\n3. Consider any existing comments:\n<EXISTING_COMMENTS>\n{{EXISTING_COMMENTS}}\n</EXISTING_COMMENTS>\n\n4. Analyze the code quality:\n a. Check for adherence to coding standards and best practices\n b. Evaluate code readability and maintainability\n c. Assess performance implications\n d. Look for potential bugs or edge cases\n e. Consider security implications\n\n5. Suggest improvements:\n a. Identify areas where the code can be optimized or simplified\n b. Propose alternative approaches if applicable\n c. Recommend additional tests or error handling if needed\n\n6. Format your response as follows:\n <code_review>\n <quality_analysis>\n Provide a detailed analysis of the code quality, addressing points 4a-4e.\n </quality_analysis>\n\n <improvement_suggestions>\n List your suggestions for improvement, addressing points 5a-5c. Number each suggestion.\n </improvement_suggestions>\n\n <summary>\n Provide a brief summary of your overall assessment and key recommendations.\n </summary>\n </code_review>\n\nRemember to be constructive and specific in your feedback. Use code snippets or pseudocode to illustrate your suggestions when appropriate. If you need clarification on any part of the code or PR description, state your assumptions clearly.\n\nDo not comment on aspects unrelated to code quality or potential improvements. Focus solely on the technical aspects of the code changes presented.",
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "\nAnalyze the above examples and give me some updates. Analyze this image as well.\n\nOne more image is this. Can you image this as well.",
+                },
+                {
+                    "type": "image",
+                    "source": {
+                        "type": "base64",
+                        "media_type": "image/jpeg",
+                        "data": "<base64_encoded_image>",
+                    },
+                },
+                {
+                    "type": "image",
+                    "source": {
+                        "type": "base64",
+                        "media_type": "image/jpeg",
+                        "data": "<base64_encoded_image>",
+                    },
+                },
+            ],
+        }
+    ],
+)
+print(message.content)
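The two "data" fields in this example are placeholders. A minimal sketch of filling one from a local file with the standard library (the path here is illustrative, borrowed from the examples directory added in this release):

    import base64
    from pathlib import Path

    # Read a local JPEG and base64-encode it for the "data" field above.
    image_path = Path("examples/creating-skills/image1.jpg")
    encoded_image = base64.b64encode(image_path.read_bytes()).decode()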
airtrain-0.1.5/airtrain/core/__pycache__/credentials.cpython-310.pyc
Binary file (contents not shown)
--- airtrain-0.1.3/airtrain/core/credentials.py
+++ airtrain-0.1.5/airtrain/core/credentials.py
@@ -5,7 +5,7 @@ from pathlib import Path
 from abc import ABC, abstractmethod
 import dotenv
 from pydantic import BaseModel, Field, SecretStr
-import yaml
+import yaml
 
 
 class CredentialError(Exception):
@@ -100,7 +100,7 @@ class BaseCredentials(BaseModel):
         else:
             raise ValueError(f"Unsupported file format: {file_path.suffix}")
 
-    def validate_credentials(self) ->
+    async def validate_credentials(self) -> bool:
         """Validate that all required credentials are present"""
         missing = []
         for field_name in self._required_credentials:
@@ -114,6 +114,7 @@ class BaseCredentials(BaseModel):
             raise CredentialValidationError(
                 f"Missing required credentials: {', '.join(missing)}"
             )
+        return True
 
     def clear_from_env(self) -> None:
         """Remove credentials from environment variables"""
@@ -122,32 +123,3 @@ class BaseCredentials(BaseModel):
         if env_key in os.environ:
             del os.environ[env_key]
         self._loaded = False
-
-
-class OpenAICredentials(BaseCredentials):
-    """OpenAI API credentials"""
-
-    api_key: SecretStr = Field(..., description="OpenAI API key")
-    organization_id: Optional[str] = Field(None, description="OpenAI organization ID")
-
-    _required_credentials = {"api_key"}
-
-
-class AWSCredentials(BaseCredentials):
-    """AWS credentials"""
-
-    aws_access_key_id: SecretStr
-    aws_secret_access_key: SecretStr
-    aws_region: str = "us-east-1"
-    aws_session_token: Optional[SecretStr] = None
-
-    _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}
-
-
-class GoogleCloudCredentials(BaseCredentials):
-    """Google Cloud credentials"""
-
-    project_id: str
-    service_account_key: SecretStr
-
-    _required_credentials = {"project_id", "service_account_key"}
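Note the signature change above: validate_credentials is now an async method that returns True once all required credentials are present, so callers must await it. A minimal sketch using one of the new integration classes (assumes the relevant API key is set in the environment so from_env(), as used by the skills below, can load it):

    import asyncio

    from airtrain.integrations import AnthropicCredentials

    async def main() -> None:
        # from_env() loads fields such as the Anthropic API key from the environment
        creds = AnthropicCredentials.from_env()
        ok = await creds.validate_credentials()
        print("credentials valid:", ok)

    asyncio.run(main())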
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/__init__.py
@@ -0,0 +1,26 @@
+"""Airtrain integrations package"""
+
+from .openai.credentials import OpenAICredentials
+from .aws.credentials import AWSCredentials
+from .google.credentials import GoogleCloudCredentials
+from .anthropic.credentials import AnthropicCredentials
+from .groq.credentials import GroqCredentials
+from .together.credentials import TogetherAICredentials
+from .ollama.credentials import OllamaCredentials
+from .sambanova.credentials import SambanovaCredentials
+from .cerebras.credentials import CerebrasCredentials
+
+from .anthropic.skills import AnthropicChatSkill
+
+__all__ = [
+    "OpenAICredentials",
+    "AWSCredentials",
+    "GoogleCloudCredentials",
+    "AnthropicCredentials",
+    "AnthropicChatSkill",
+    "GroqCredentials",
+    "TogetherAICredentials",
+    "OllamaCredentials",
+    "SambanovaCredentials",
+    "CerebrasCredentials",
+]
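The provider-specific classes removed from airtrain.core.credentials now live under airtrain.integrations and are re-exported here, so a single top-level import covers them, e.g.:

    from airtrain.integrations import OpenAICredentials, AWSCredentials, AnthropicChatSkill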
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/anthropic/credentials.py
@@ -0,0 +1,32 @@
+from pydantic import Field, SecretStr, validator
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+from anthropic import Anthropic
+
+
+class AnthropicCredentials(BaseCredentials):
+    """Anthropic API credentials"""
+
+    anthropic_api_key: SecretStr = Field(..., description="Anthropic API key")
+    version: str = Field(default="2023-06-01", description="API Version")
+
+    _required_credentials = {"anthropic_api_key"}
+
+    @validator("anthropic_api_key")
+    def validate_api_key_format(cls, v: SecretStr) -> SecretStr:
+        key = v.get_secret_value()
+        if not key.startswith("sk-ant-"):
+            raise ValueError("Anthropic API key must start with 'sk-ant-'")
+        return v
+
+    async def validate_credentials(self) -> bool:
+        """Validate Anthropic credentials"""
+        try:
+            client = Anthropic(api_key=self.anthropic_api_key.get_secret_value())
+            client.messages.create(
+                model="claude-3-opus-20240229",
+                max_tokens=1,
+                messages=[{"role": "user", "content": "Hi"}],
+            )
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid Anthropic credentials: {str(e)}")
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/anthropic/skills.py
@@ -0,0 +1,135 @@
+from typing import List, Optional, Dict, Any
+from pydantic import Field
+from anthropic import Anthropic
+import base64
+from pathlib import Path
+from loguru import logger
+
+from airtrain.core.skills import Skill, ProcessingError
+from airtrain.core.schemas import InputSchema, OutputSchema
+from .credentials import AnthropicCredentials
+
+
+class AnthropicInput(InputSchema):
+    """Schema for Anthropic chat input"""
+
+    user_input: str = Field(..., description="User's input text")
+    system_prompt: str = Field(
+        default="You are a helpful assistant.",
+        description="System prompt to guide the model's behavior",
+    )
+    model: str = Field(
+        default="claude-3-opus-20240229", description="Anthropic model to use"
+    )
+    max_tokens: int = Field(default=1024, description="Maximum tokens in response")
+    temperature: float = Field(
+        default=0.7, description="Temperature for response generation", ge=0, le=1
+    )
+    images: Optional[List[Path]] = Field(
+        default=None,
+        description="Optional list of image paths to include in the message",
+    )
+
+
+class AnthropicOutput(OutputSchema):
+    """Schema for Anthropic chat output"""
+
+    response: str = Field(..., description="Model's response text")
+    used_model: str = Field(..., description="Model used for generation")
+    usage: Dict[str, Any] = Field(
+        default_factory=dict, description="Usage statistics from the API"
+    )
+
+
+class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
+    """Skill for interacting with Anthropic's Claude models"""
+
+    input_schema = AnthropicInput
+    output_schema = AnthropicOutput
+
+    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
+        """Initialize the skill with optional credentials"""
+        super().__init__()
+        self.credentials = credentials or AnthropicCredentials.from_env()
+        self.client = Anthropic(
+            api_key=self.credentials.anthropic_api_key.get_secret_value()
+        )
+
+    def _encode_image(self, image_path: Path) -> Dict[str, Any]:
+        """Convert image to base64 for API consumption"""
+        try:
+            if not image_path.exists():
+                raise FileNotFoundError(f"Image file not found: {image_path}")
+
+            with open(image_path, "rb") as img_file:
+                encoded = base64.b64encode(img_file.read()).decode()
+                return {
+                    "type": "image",
+                    "source": {
+                        "type": "base64",
+                        "media_type": f"image/{image_path.suffix[1:]}",
+                        "data": encoded,
+                    },
+                }
+        except Exception as e:
+            logger.error(f"Failed to encode image {image_path}: {str(e)}")
+            raise ProcessingError(f"Image encoding failed: {str(e)}")
+
+    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
+        """Process the input using Anthropic's API"""
+        try:
+            logger.info(f"Processing request with model {input_data.model}")
+
+            # Prepare message content
+            content = []
+
+            # Add text content
+            content.append({"type": "text", "text": input_data.user_input})
+
+            # Add images if provided
+            if input_data.images:
+                logger.debug(f"Processing {len(input_data.images)} images")
+                for image_path in input_data.images:
+                    content.append(self._encode_image(image_path))
+
+            # Create message
+            response = self.client.messages.create(
+                model=input_data.model,
+                max_tokens=input_data.max_tokens,
+                temperature=input_data.temperature,
+                system=input_data.system_prompt,
+                messages=[{"role": "user", "content": content}],
+            )
+
+            # Validate response content
+            if not response.content:
+                logger.error("Empty response received from Anthropic API")
+                raise ProcessingError("Empty response received from Anthropic API")
+
+            if not isinstance(response.content, list) or not response.content:
+                logger.error("Invalid response format from Anthropic API")
+                raise ProcessingError("Invalid response format from Anthropic API")
+
+            first_content = response.content[0]
+            if not hasattr(first_content, "text"):
+                logger.error("Response content does not contain text")
+                raise ProcessingError("Response content does not contain text")
+
+            logger.success("Successfully processed Anthropic request")
+
+            # Create output
+            return AnthropicOutput(
+                response=first_content.text,
+                used_model=response.model,
+                usage={
+                    "input_tokens": response.usage.input_tokens,
+                    "output_tokens": response.usage.output_tokens,
+                },
+            )
+
+        except ProcessingError:
+            # Re-raise ProcessingError without modification
+            raise
+        except Exception as e:
+            logger.exception(f"Anthropic processing failed: {str(e)}")
+            raise ProcessingError(f"Anthropic processing failed: {str(e)}")
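A minimal usage sketch for the new skill (assumes an Anthropic API key is available in the environment so the credential fallback in __init__ succeeds; the prompt text is illustrative):

    from airtrain.integrations import AnthropicChatSkill
    from airtrain.integrations.anthropic.skills import AnthropicInput

    # With no credentials argument, the skill falls back to AnthropicCredentials.from_env()
    skill = AnthropicChatSkill()
    result = skill.process(
        AnthropicInput(user_input="Give me a one-line summary of this package.")
    )
    print(result.response)
    print(result.usage)  # {"input_tokens": ..., "output_tokens": ...}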
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/aws/credentials.py
@@ -0,0 +1,36 @@
+from typing import Optional
+from pydantic import Field, SecretStr
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+import boto3
+
+
+class AWSCredentials(BaseCredentials):
+    """AWS credentials"""
+
+    aws_access_key_id: SecretStr = Field(..., description="AWS Access Key ID")
+    aws_secret_access_key: SecretStr = Field(..., description="AWS Secret Access Key")
+    aws_region: str = Field(default="us-east-1", description="AWS Region")
+    aws_session_token: Optional[SecretStr] = Field(
+        None, description="AWS Session Token"
+    )
+
+    _required_credentials = {"aws_access_key_id", "aws_secret_access_key"}
+
+    async def validate_credentials(self) -> bool:
+        """Validate AWS credentials by making a test API call"""
+        try:
+            session = boto3.Session(
+                aws_access_key_id=self.aws_access_key_id.get_secret_value(),
+                aws_secret_access_key=self.aws_secret_access_key.get_secret_value(),
+                aws_session_token=(
+                    self.aws_session_token.get_secret_value()
+                    if self.aws_session_token
+                    else None
+                ),
+                region_name=self.aws_region,
+            )
+            sts = session.client("sts")
+            sts.get_caller_identity()
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid AWS credentials: {str(e)}")
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/cerebras/credentials.py
@@ -0,0 +1,22 @@
+from pydantic import Field, SecretStr, HttpUrl
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+from typing import Optional
+
+
+class CerebrasCredentials(BaseCredentials):
+    """Cerebras credentials"""
+
+    api_key: SecretStr = Field(..., description="Cerebras API key")
+    endpoint_url: HttpUrl = Field(..., description="Cerebras API endpoint")
+    project_id: Optional[str] = Field(None, description="Cerebras Project ID")
+
+    _required_credentials = {"api_key", "endpoint_url"}
+
+    async def validate_credentials(self) -> bool:
+        """Validate Cerebras credentials"""
+        try:
+            # Implement Cerebras-specific validation
+            # This would depend on their API client implementation
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid Cerebras credentials: {str(e)}")
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/google/credentials.py
@@ -0,0 +1,27 @@
+from pydantic import Field, SecretStr
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+from google.cloud import storage
+
+
+class GoogleCloudCredentials(BaseCredentials):
+    """Google Cloud credentials"""
+
+    project_id: str = Field(..., description="Google Cloud Project ID")
+    service_account_key: SecretStr = Field(..., description="Service Account Key JSON")
+
+    _required_credentials = {"project_id", "service_account_key"}
+
+    async def validate_credentials(self) -> bool:
+        """Validate Google Cloud credentials"""
+        try:
+            # Initialize with service account key
+            storage_client = storage.Client.from_service_account_info(
+                self.service_account_key.get_secret_value()
+            )
+            # Test API call
+            storage_client.list_buckets(max_results=1)
+            return True
+        except Exception as e:
+            raise CredentialValidationError(
+                f"Invalid Google Cloud credentials: {str(e)}"
+            )
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/groq/credentials.py
@@ -0,0 +1,24 @@
+from pydantic import Field, SecretStr
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+from groq import Groq
+
+
+class GroqCredentials(BaseCredentials):
+    """Groq API credentials"""
+
+    api_key: SecretStr = Field(..., description="Groq API key")
+
+    _required_credentials = {"api_key"}
+
+    async def validate_credentials(self) -> bool:
+        """Validate Groq credentials"""
+        try:
+            client = Groq(api_key=self.api_key.get_secret_value())
+            await client.chat.completions.create(
+                messages=[{"role": "user", "content": "Hi"}],
+                model="mixtral-8x7b-32768",
+                max_tokens=1,
+            )
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid Groq credentials: {str(e)}")
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/ollama/credentials.py
@@ -0,0 +1,26 @@
+from pydantic import Field
+from airtrain.core.credentials import BaseCredentials, CredentialValidationError
+from importlib.util import find_spec
+
+
+class OllamaCredentials(BaseCredentials):
+    """Ollama credentials"""
+
+    host: str = Field(default="http://localhost:11434", description="Ollama host URL")
+    timeout: int = Field(default=30, description="Request timeout in seconds")
+
+    async def validate_credentials(self) -> bool:
+        """Validate Ollama credentials"""
+        if find_spec("ollama") is None:
+            raise CredentialValidationError(
+                "Ollama package is not installed. Please install it using: pip install ollama"
+            )
+
+        try:
+            from ollama import Client
+
+            client = Client(host=self.host)
+            await client.list()
+            return True
+        except Exception as e:
+            raise CredentialValidationError(f"Invalid Ollama connection: {str(e)}")
--- /dev/null
+++ airtrain-0.1.5/airtrain/integrations/openai/chinese_assistant.py
@@ -0,0 +1,42 @@
+from typing import Optional, TypeVar
+from pydantic import Field
+from .skills import OpenAIChatSkill, OpenAIInput, OpenAIOutput
+from .credentials import OpenAICredentials
+
+T = TypeVar("T", bound=OpenAIInput)
+
+
+class ChineseAssistantInput(OpenAIInput):
+    """Schema for Chinese Assistant input"""
+
+    user_input: str = Field(
+        ..., description="User's input text (can be in any language)"
+    )
+    system_prompt: str = Field(
+        default="你是一个有帮助的助手。请用中文回答所有问题,即使问题是用其他语言问的。回答要准确、礼貌、专业。",
+        description="System prompt in Chinese",
+    )
+    model: str = Field(default="gpt-4o", description="OpenAI model to use")
+    max_tokens: int = Field(default=8096, description="Maximum tokens in response")
+    temperature: float = Field(
+        default=0.7, description="Temperature for response generation", ge=0, le=1
+    )
+
+
+class ChineseAssistantSkill(OpenAIChatSkill):
+    """Skill for Chinese language assistance"""
+
+    input_schema = ChineseAssistantInput
+    output_schema = OpenAIOutput
+
+    def __init__(self, credentials: Optional[OpenAICredentials] = None):
+        super().__init__(credentials)
+
+    def process(self, input_data: T) -> OpenAIOutput:
+        # Add language check to ensure response is in Chinese
+        if "你是" not in input_data.system_prompt:
+            input_data.system_prompt = (
+                "你是一个中文助手。" + input_data.system_prompt + "请用中文回答。"
+            )
+
+        return super().process(input_data)
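A minimal usage sketch (assumes the parent OpenAIChatSkill, like the Anthropic skill above, falls back to environment credentials when none are passed; the prompt is illustrative):

    from airtrain.integrations.openai.chinese_assistant import (
        ChineseAssistantInput,
        ChineseAssistantSkill,
    )

    skill = ChineseAssistantSkill()
    # The default system prompt forces replies in Chinese regardless of input language.
    result = skill.process(ChineseAssistantInput(user_input="What is machine learning?"))
    print(result.response)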