airtrain 0.1.14__tar.gz → 0.1.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {airtrain-0.1.14/airtrain.egg-info → airtrain-0.1.18}/PKG-INFO +1 -1
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/__init__.py +3 -3
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/__pycache__/credentials.cpython-310.pyc +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/__init__.py +1 -1
- airtrain-0.1.18/airtrain/integrations/anthropic/skills.py +127 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/cerebras/credentials.py +3 -6
- airtrain-0.1.18/airtrain/integrations/cerebras/skills.py +95 -0
- airtrain-0.1.18/airtrain/integrations/fireworks/conversation_manager.py +109 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/fireworks/skills.py +27 -5
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/google/__init__.py +2 -1
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/google/credentials.py +30 -0
- airtrain-0.1.18/airtrain/integrations/google/gemini/conversation_history_test.py +83 -0
- airtrain-0.1.18/airtrain/integrations/google/gemini/credentials.py +27 -0
- airtrain-0.1.18/airtrain/integrations/google/gemini/skills.py +116 -0
- airtrain-0.1.18/airtrain/integrations/google/skills.py +122 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/groq/credentials.py +3 -3
- airtrain-0.1.18/airtrain/integrations/groq/skills.py +88 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/sambanova/credentials.py +3 -3
- airtrain-0.1.18/airtrain/integrations/sambanova/skills.py +97 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/image_skill.py +5 -33
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/skills.py +52 -4
- {airtrain-0.1.14 → airtrain-0.1.18/airtrain.egg-info}/PKG-INFO +1 -1
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain.egg-info/SOURCES.txt +23 -0
- airtrain-0.1.18/examples/integrations/anthropic/chat_example.py +42 -0
- airtrain-0.1.18/examples/integrations/anthropic/chinese_example.py +62 -0
- airtrain-0.1.18/examples/integrations/anthropic/conversation_history_test.py +86 -0
- airtrain-0.1.18/examples/integrations/anthropic/vision_example.py +47 -0
- airtrain-0.1.18/examples/integrations/cerebras/conversation_history_test.py +84 -0
- airtrain-0.1.18/examples/integrations/fireworks/chat_example.py +42 -0
- airtrain-0.1.18/examples/integrations/fireworks/conversation_history_test.py +86 -0
- airtrain-0.1.18/examples/integrations/fireworks/structured_chat_example.py +43 -0
- airtrain-0.1.18/examples/integrations/google/conversation_history_test.py +94 -0
- airtrain-0.1.18/examples/integrations/google/gemini/conversation_history_test.py +83 -0
- airtrain-0.1.18/examples/integrations/groq/conversation_history_test.py +84 -0
- airtrain-0.1.18/examples/integrations/openai/chat_example.py +42 -0
- airtrain-0.1.18/examples/integrations/openai/parser_example.py +62 -0
- airtrain-0.1.18/examples/integrations/openai/vision_example.py +46 -0
- airtrain-0.1.18/examples/integrations/sambanova/conversation_history_test.py +85 -0
- airtrain-0.1.18/examples/integrations/together/chat_example.py +42 -0
- airtrain-0.1.18/examples/integrations/together/conversation_history_test.py +86 -0
- airtrain-0.1.18/examples/integrations/together/image_generation_example.py +58 -0
- airtrain-0.1.18/examples/integrations/together/rerank_example.py +59 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/requirements.txt +4 -1
- {airtrain-0.1.14 → airtrain-0.1.18}/scripts/release.py +28 -3
- {airtrain-0.1.14 → airtrain-0.1.18}/setup.py +26 -1
- airtrain-0.1.14/airtrain/integrations/anthropic/skills.py +0 -135
- airtrain-0.1.14/airtrain/integrations/cerebras/skills.py +0 -41
- airtrain-0.1.14/airtrain/integrations/google/skills.py +0 -41
- airtrain-0.1.14/airtrain/integrations/groq/skills.py +0 -41
- airtrain-0.1.14/airtrain/integrations/sambanova/skills.py +0 -41
- {airtrain-0.1.14 → airtrain-0.1.18}/.flake8 +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/.github/workflows/publish.yml +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/.gitignore +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/.mypy.ini +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/.pre-commit-config.yaml +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/.vscode/extensions.json +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/.vscode/launch.json +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/.vscode/settings.json +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/EXPERIMENTS/integrations_examples/anthropic_with_image.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/EXPERIMENTS/schema_exps/pydantic_schemas.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/README.md +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/agents/travel/agents.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/agents/travel/models.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/agentlib/verification_agent.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/agents.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/modellib/verification.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/models.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/__pycache__/schemas.cpython-310.pyc +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/__pycache__/skills.cpython-310.pyc +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/credentials.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/schemas.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/skills.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/anthropic/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/anthropic/credentials.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/aws/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/aws/credentials.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/aws/skills.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/cerebras/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/fireworks/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/fireworks/credentials.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/fireworks/models.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/groq/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/ollama/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/ollama/credentials.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/ollama/skills.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/chinese_assistant.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/credentials.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/models_config.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/skills.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/sambanova/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/__init__.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/audio_models_config.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/credentials.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/embedding_models_config.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/image_models_config.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/models.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/models_config.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/rerank_models_config.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/rerank_skill.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/schemas.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/vision_models_config.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain.egg-info/dependency_links.txt +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain.egg-info/requires.txt +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/airtrain.egg-info/top_level.txt +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/anthropic_skills_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/chinese_anthropic_assistant.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/chinese_anthropic_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/chinese_assistant_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/fireworks_skills_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/icon128.png +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/icon16.png +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/image1.jpg +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/image2.jpg +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/openai_skills.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/openai_skills_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/openai_structured_skills.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/together_rerank_skills.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/together_rerank_skills_async.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/credentials_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/images/quantum-circuit.png +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/schema_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/skill_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/together/image_generation.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/together/image_generation_example.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/examples/travel/verification_agent_usage.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/pyproject.toml +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/scripts/build.sh +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/scripts/bump_version.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/scripts/publish.sh +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/services/firebase_service.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/services/openai_service.py +0 -0
- {airtrain-0.1.14 → airtrain-0.1.18}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
"""Airtrain - A platform for building and deploying AI agents with structured skills"""
|
2
2
|
|
3
|
-
__version__ = "0.1.
|
3
|
+
__version__ = "0.1.18"
|
4
4
|
|
5
5
|
# Core imports
|
6
6
|
from .core.skills import Skill, ProcessingError
|
@@ -22,7 +22,7 @@ from .integrations.cerebras.credentials import CerebrasCredentials
|
|
22
22
|
from .integrations.openai.skills import OpenAIChatSkill, OpenAIParserSkill
|
23
23
|
from .integrations.anthropic.skills import AnthropicChatSkill
|
24
24
|
from .integrations.aws.skills import AWSBedrockSkill
|
25
|
-
from .integrations.google.skills import
|
25
|
+
from .integrations.google.skills import GoogleChatSkill
|
26
26
|
from .integrations.groq.skills import GroqChatSkill
|
27
27
|
from .integrations.together.skills import TogetherAIChatSkill
|
28
28
|
from .integrations.ollama.skills import OllamaChatSkill
|
@@ -51,7 +51,7 @@ __all__ = [
|
|
51
51
|
"OpenAIParserSkill",
|
52
52
|
"AnthropicChatSkill",
|
53
53
|
"AWSBedrockSkill",
|
54
|
-
"
|
54
|
+
"GoogleChatSkill",
|
55
55
|
"GroqChatSkill",
|
56
56
|
"TogetherAIChatSkill",
|
57
57
|
"OllamaChatSkill",
|
Binary file
|
@@ -15,7 +15,7 @@ from .cerebras.credentials import CerebrasCredentials
|
|
15
15
|
from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
|
16
16
|
from .anthropic.skills import AnthropicChatSkill
|
17
17
|
from .aws.skills import AWSBedrockSkill
|
18
|
-
from .google.skills import
|
18
|
+
from .google.skills import GoogleChatSkill
|
19
19
|
from .groq.skills import GroqChatSkill
|
20
20
|
from .together.skills import TogetherAIChatSkill
|
21
21
|
from .ollama.skills import OllamaChatSkill
|
@@ -0,0 +1,127 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from anthropic import Anthropic
|
4
|
+
import base64
|
5
|
+
from pathlib import Path
|
6
|
+
from loguru import logger
|
7
|
+
|
8
|
+
from airtrain.core.skills import Skill, ProcessingError
|
9
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
10
|
+
from .credentials import AnthropicCredentials
|
11
|
+
|
12
|
+
|
13
|
+
class AnthropicInput(InputSchema):
    """Schema for Anthropic chat input.

    Carries the current user turn together with the system prompt, prior
    conversation turns, model selection, sampling settings and optional local
    image files to attach to the current user turn.
    """

    # Text of the current user turn.
    user_input: str = Field(..., description="User's input text")
    # Passed by the chat skill as the separate ``system`` argument of
    # ``messages.create`` rather than as a message in the list.
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prior turns; the chat skill places them verbatim before the current user turn.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(
        default="claude-3-opus-20240229", description="Anthropic model to use"
    )
    # NOTE(review): typed Optional, but the value is forwarded as ``max_tokens``
    # to ``messages.create``; the Anthropic Messages API requires an integer
    # there — confirm a None value is handled before the request is sent.
    max_tokens: Optional[int] = Field(
        default=1024, description="Maximum tokens in response"
    )
    # Validated to the closed range [0, 1].
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
    # Each file is read in binary mode and base64-encoded into an image
    # content block of the current user message.
    images: List[Path] = Field(
        default_factory=list,
        description="List of image paths to include in the message",
    )
|
38
|
+
|
39
|
+
|
40
|
+
class AnthropicOutput(OutputSchema):
    """Schema for Anthropic chat output.

    Holds the generated text, the model name that was requested, and the
    token-usage statistics reported for the call.
    """

    # Generated reply text.
    response: str = Field(..., description="Model's response text")
    # The chat skill fills this from the request's ``model`` field, not from
    # the API response.
    used_model: str = Field(..., description="Model used for generation")
    # Plain-dict form of the API's usage object (``response.usage.model_dump()``).
    usage: Dict[str, Any] = Field(
        default_factory=dict, description="Usage statistics from the API"
    )
|
48
|
+
|
49
|
+
|
50
|
+
class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
    """Skill for (multi-turn, optionally multimodal) chat with Anthropic models.

    The system prompt is sent through the dedicated ``system`` parameter of
    the Messages API; prior turns plus the current user turn (with any
    attached images) are sent as ``messages``.
    """

    input_schema = AnthropicInput
    output_schema = AnthropicOutput

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Create the skill.

        Args:
            credentials: Anthropic credentials; loaded with
                ``AnthropicCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or AnthropicCredentials.from_env()
        self.client = Anthropic(
            api_key=self.credentials.anthropic_api_key.get_secret_value()
        )

    @staticmethod
    def _image_media_type(image_path: Path) -> str:
        """Return the ``media_type`` to declare for *image_path*, based on its suffix.

        The API checks the declared type against the actual image bytes, so a
        PNG must not be labelled as JPEG. Unknown suffixes (including
        ``.jpg``/``.jpeg``) fall back to ``image/jpeg``, the previous
        hard-coded value.
        """
        suffix = Path(image_path).suffix.lower()
        return {
            ".png": "image/png",
            ".gif": "image/gif",
            ".webp": "image/webp",
        }.get(suffix, "image/jpeg")

    def _build_messages(self, input_data: AnthropicInput) -> List[Dict[str, Any]]:
        """
        Build messages list from input data including conversation history.

        The system prompt is NOT included here; ``process`` passes it separately.

        Args:
            input_data: The input data containing conversation history, user
                input and optional image paths.

        Returns:
            List[Dict[str, Any]]: History messages followed by one ``user``
            message whose content is the image blocks (if any) and then the
            text block.

        Raises:
            OSError: if an image file cannot be read.
        """
        # New list, so the caller's history list is never mutated.
        messages: List[Dict[str, Any]] = list(input_data.conversation_history)

        user_content: List[Dict[str, Any]] = []
        for image_path in input_data.images:
            with open(image_path, "rb") as img_file:
                base64_image = base64.b64encode(img_file.read()).decode("utf-8")
            user_content.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": self._image_media_type(image_path),
                        "data": base64_image,
                    },
                }
            )
        # Text goes after the images, matching the previous ordering.
        user_content.append({"type": "text", "text": input_data.user_input})
        messages.append({"role": "user", "content": user_content})

        return messages

    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
        """Send the chat request and wrap the reply.

        Args:
            input_data: Validated chat input.

        Returns:
            AnthropicOutput: Concatenated text of the reply, requested model
            name and usage statistics.

        Raises:
            ProcessingError: if building the request or the API call fails.
        """
        try:
            messages = self._build_messages(input_data)

            # The Messages API requires max_tokens, but the schema allows None;
            # fall back to the schema's declared default in that case.
            max_tokens = input_data.max_tokens
            if max_tokens is None:
                max_tokens = AnthropicInput.model_fields["max_tokens"].default

            response = self.client.messages.create(
                model=input_data.model,
                system=input_data.system_prompt,  # System prompt passed directly
                messages=messages,
                max_tokens=max_tokens,
                temperature=input_data.temperature,
            )

            # Join every text block instead of assuming content[0] exists and
            # is text (content may hold non-text blocks or be empty).
            text = "".join(
                block.text
                for block in response.content
                if getattr(block, "type", None) == "text"
            )

            return AnthropicOutput(
                response=text,
                used_model=input_data.model,
                usage=response.usage.model_dump(),
            )

        except Exception as e:
            logger.exception(f"Anthropic processing failed: {str(e)}")
            raise ProcessingError(f"Anthropic processing failed: {str(e)}") from e
|
@@ -1,16 +1,13 @@
|
|
1
|
-
from pydantic import Field, SecretStr
|
1
|
+
from pydantic import Field, SecretStr
|
2
2
|
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
-
from typing import Optional
|
4
3
|
|
5
4
|
|
6
5
|
class CerebrasCredentials(BaseCredentials):
|
7
6
|
"""Cerebras credentials"""
|
8
7
|
|
9
|
-
|
10
|
-
endpoint_url: HttpUrl = Field(..., description="Cerebras API endpoint")
|
11
|
-
project_id: Optional[str] = Field(None, description="Cerebras Project ID")
|
8
|
+
cerebras_api_key: SecretStr = Field(..., description="Cerebras API key")
|
12
9
|
|
13
|
-
_required_credentials = {"
|
10
|
+
_required_credentials = {"cerebras_api_key"}
|
14
11
|
|
15
12
|
async def validate_credentials(self) -> bool:
|
16
13
|
"""Validate Cerebras credentials"""
|
@@ -0,0 +1,95 @@
|
|
1
|
+
from typing import List, Optional, Dict, Any
|
2
|
+
from pydantic import Field
|
3
|
+
from cerebras.cloud.sdk import Cerebras
|
4
|
+
from loguru import logger
|
5
|
+
|
6
|
+
from airtrain.core.skills import Skill, ProcessingError
|
7
|
+
from airtrain.core.schemas import InputSchema, OutputSchema
|
8
|
+
from .credentials import CerebrasCredentials
|
9
|
+
|
10
|
+
|
11
|
+
class CerebrasInput(InputSchema):
    """Schema for Cerebras chat input.

    Carries the current user turn together with the system prompt, prior
    conversation turns, model name and sampling settings.
    """

    # Text of the current user turn.
    user_input: str = Field(..., description="User's input text")
    # The chat skill sends this as the first message, with role "system".
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prior turns, inserted verbatim between the system message and the current turn.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
    )
    model: str = Field(default="llama3.1-8b", description="Cerebras model to use")
    max_tokens: Optional[int] = Field(
        default=1024, description="Maximum tokens in response"
    )
    # Validated to the closed range [0, 1].
    temperature: float = Field(
        default=0.7, description="Temperature for response generation", ge=0, le=1
    )
|
30
|
+
|
31
|
+
|
32
|
+
class CerebrasOutput(OutputSchema):
    """Schema for Cerebras chat output: reply text, requested model and usage."""

    # Generated reply text.
    response: str = Field(..., description="Model's response text")
    # The chat skill fills this from the request's ``model`` field.
    used_model: str = Field(..., description="Model used for generation")
    # Plain-dict form of the SDK's usage object.
    usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
|
38
|
+
|
39
|
+
|
40
|
+
class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
    """Skill for multi-turn chat completions through the Cerebras Cloud SDK."""

    input_schema = CerebrasInput
    output_schema = CerebrasOutput

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        """Create the skill.

        Args:
            credentials: Cerebras credentials; loaded with
                ``CerebrasCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or CerebrasCredentials.from_env()
        self.client = Cerebras(
            api_key=self.credentials.cerebras_api_key.get_secret_value()
        )

    def _build_messages(self, input_data: CerebrasInput) -> List[Dict[str, str]]:
        """
        Build messages list from input data including conversation history.

        Args:
            input_data: The input data containing system prompt, conversation history, and user input

        Returns:
            List[Dict[str, str]]: System message, then the history, then the
            current user message.
        """
        messages = [{"role": "system", "content": input_data.system_prompt}]

        # Add conversation history if present
        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        # Add current user input
        messages.append({"role": "user", "content": input_data.user_input})

        return messages

    def process(self, input_data: CerebrasInput) -> CerebrasOutput:
        """Send the chat request and wrap the reply.

        Args:
            input_data: Validated chat input.

        Returns:
            CerebrasOutput: Reply text (empty string when the API returns no
            text content), requested model name and usage statistics (empty
            dict when none are reported).

        Raises:
            ProcessingError: if the request fails or the response is malformed.
        """
        try:
            messages = self._build_messages(input_data)

            response = self.client.chat.completions.create(
                model=input_data.model,
                messages=messages,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            # NOTE: chat-completion style responses may carry ``content=None``
            # or omit ``usage``; normalise both so the output schema validates.
            content = response.choices[0].message.content
            usage = response.usage

            return CerebrasOutput(
                response=content or "",
                used_model=input_data.model,
                usage=usage.model_dump() if usage is not None else {},
            )

        except Exception as e:
            logger.exception(f"Cerebras processing failed: {str(e)}")
            raise ProcessingError(f"Cerebras processing failed: {str(e)}") from e
|
@@ -0,0 +1,109 @@
|
|
1
|
+
from typing import List, Dict, Optional
|
2
|
+
from pydantic import BaseModel, Field
|
3
|
+
from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
|
4
|
+
|
5
|
+
# TODO: Test this thing.
|
6
|
+
|
7
|
+
|
8
|
+
class ConversationState(BaseModel):
    """Model to track conversation state.

    Holds the running message history plus the settings reused for every
    request; this whole object is what the conversation manager serialises to
    and from JSON.
    """

    # Prior turns as role/content dicts; the manager appends a user message and
    # the assistant reply after each successful request. The system prompt is
    # kept separately, not stored in this list.
    messages: List[Dict[str, str]] = Field(
        default_factory=list, description="List of conversation messages"
    )
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt for the conversation",
    )
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Model being used for the conversation",
    )
    # No range validation is declared on this field.
    temperature: float = Field(default=0.7, description="Temperature setting")
    max_tokens: Optional[int] = Field(default=None, description="Max tokens setting")
|
24
|
+
|
25
|
+
|
26
|
+
class FireworksConversationManager:
    """Stateful wrapper around ``FireworksChatSkill`` that keeps conversation history.

    Each successful :meth:`send_message` appends the user turn and the
    assistant reply to ``self.state.messages``. The full state (history,
    system prompt and model settings) can be saved to and restored from a
    UTF-8 JSON file.
    """

    def __init__(
        self,
        skill: Optional[FireworksChatSkill] = None,
        system_prompt: str = "You are a helpful assistant.",
        model: str = "accounts/fireworks/models/deepseek-r1",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize conversation manager.

        Args:
            skill: FireworksChatSkill instance (creates new one if None)
            system_prompt: Initial system prompt
            model: Model to use
            temperature: Temperature setting
            max_tokens: Max tokens setting
        """
        self.skill = skill or FireworksChatSkill()
        self.state = ConversationState(
            system_prompt=system_prompt,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def send_message(self, user_input: str) -> FireworksOutput:
        """
        Send a message and get response while maintaining conversation history.

        History is only updated after the skill returns, so a failed request
        leaves the conversation unchanged.

        Args:
            user_input: User's message

        Returns:
            FireworksOutput: Model's response
        """
        input_data = FireworksInput(
            user_input=user_input,
            system_prompt=self.state.system_prompt,
            # Snapshot, so the request never aliases the live history list.
            conversation_history=list(self.state.messages),
            model=self.state.model,
            temperature=self.state.temperature,
            max_tokens=self.state.max_tokens,
        )

        result = self.skill.process(input_data)

        self.state.messages.extend(
            [
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": result.response},
            ]
        )

        return result

    def reset_conversation(self) -> None:
        """Reset the conversation history while maintaining other settings"""
        self.state.messages = []

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Get a copy of the current conversation history.

        Each message dict is copied as well, so callers cannot mutate the
        manager's internal state through the returned list.
        """
        return [dict(message) for message in self.state.messages]

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt for future messages"""
        self.state.system_prompt = new_prompt

    def save_state(self, file_path: str) -> None:
        """Save conversation state to a file as UTF-8 JSON.

        ``model_dump_json`` emits non-ASCII characters unescaped, so the
        encoding is fixed instead of relying on the platform default.
        """
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(self.state.model_dump_json(indent=2))

    def load_state(self, file_path: str) -> None:
        """Load conversation state from a UTF-8 JSON file written by :meth:`save_state`.

        Raises:
            OSError: if the file cannot be read.
            pydantic.ValidationError: if the content is not a valid state.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = f.read()
        self.state = ConversationState.model_validate_json(data)
|
@@ -17,6 +17,10 @@ class FireworksInput(InputSchema):
|
|
17
17
|
default="You are a helpful assistant.",
|
18
18
|
description="System prompt to guide the model's behavior",
|
19
19
|
)
|
20
|
+
conversation_history: List[Dict[str, str]] = Field(
|
21
|
+
default_factory=list,
|
22
|
+
description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
|
23
|
+
)
|
20
24
|
model: str = Field(
|
21
25
|
default="accounts/fireworks/models/deepseek-r1",
|
22
26
|
description="Fireworks AI model to use",
|
@@ -52,16 +56,34 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
|
|
52
56
|
self.credentials = credentials or FireworksCredentials.from_env()
|
53
57
|
self.base_url = "https://api.fireworks.ai/inference/v1"
|
54
58
|
|
59
|
+
def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, str]]:
    """
    Assemble the request messages: system prompt, prior turns, new user turn.

    Args:
        input_data: Request carrying ``system_prompt``, ``conversation_history``
            and ``user_input``.

    Returns:
        List[Dict[str, str]]: Role/content dicts in the order Fireworks AI
        expects — one system message, the history as given, then the
        current user message.
    """
    system_message = {"role": "system", "content": input_data.system_prompt}
    current_turn = {"role": "user", "content": input_data.user_input}
    # An empty (or missing) history contributes nothing.
    prior_turns = input_data.conversation_history or []
    return [system_message, *prior_turns, current_turn]
|
79
|
+
|
55
80
|
def process(self, input_data: FireworksInput) -> FireworksOutput:
|
56
81
|
"""Process the input using Fireworks AI API"""
|
57
82
|
try:
|
58
83
|
logger.info(f"Processing request with model {input_data.model}")
|
59
84
|
|
60
|
-
#
|
61
|
-
messages =
|
62
|
-
{"role": "system", "content": input_data.system_prompt},
|
63
|
-
{"role": "user", "content": input_data.user_input},
|
64
|
-
]
|
85
|
+
# Build messages using the helper method
|
86
|
+
messages = self._build_messages(input_data)
|
65
87
|
|
66
88
|
# Prepare request payload
|
67
89
|
payload = {
|
@@ -1,5 +1,8 @@
|
|
1
1
|
from pydantic import Field, SecretStr
|
2
2
|
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
import google.genai as genai
|
4
|
+
from google.cloud import storage
|
5
|
+
import os
|
3
6
|
|
4
7
|
# from google.cloud import storage
|
5
8
|
|
@@ -26,3 +29,30 @@ class GoogleCloudCredentials(BaseCredentials):
|
|
26
29
|
raise CredentialValidationError(
|
27
30
|
f"Invalid Google Cloud credentials: {str(e)}"
|
28
31
|
)
|
32
|
+
|
33
|
+
|
34
|
+
class GeminiCredentials(BaseCredentials):
    """Gemini API credentials, validated through the ``google-genai`` SDK.

    This module imports ``google.genai`` as ``genai``. That SDK exposes a
    ``Client`` object, not the ``configure()`` / ``GenerativeModel`` API of the
    older ``google.generativeai`` package, so validation goes through the client.
    """

    gemini_api_key: SecretStr = Field(..., description="Gemini API Key")

    _required_credentials = {"gemini_api_key"}

    @classmethod
    def from_env(cls) -> "GeminiCredentials":
        """Create credentials from the ``GEMINI_API_KEY`` environment variable.

        A missing variable yields an empty key; it is not rejected here.
        """
        return cls(gemini_api_key=SecretStr(os.environ.get("GEMINI_API_KEY", "")))

    async def validate_credentials(self) -> bool:
        """Validate the key by issuing a minimal generation request.

        Returns:
            bool: ``True`` when the request succeeds.

        Raises:
            CredentialValidationError: if the client cannot be created or the
                request fails.
        """
        try:
            client = genai.Client(api_key=self.gemini_api_key.get_secret_value())
            # A tiny round-trip proves the key is accepted by the service.
            client.models.generate_content(model="gemini-1.5-flash", contents="test")
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Gemini credentials: {str(e)}"
            ) from e
|
@@ -0,0 +1,83 @@
|
|
1
|
+
import sys
|
2
|
+
import os
|
3
|
+
from pathlib import Path
|
4
|
+
from dotenv import load_dotenv
|
5
|
+
from typing import List, Dict
|
6
|
+
|
7
|
+
load_dotenv()

# Make the repository root importable when this script is run directly:
# start from this file's directory and climb four more levels
# (gemini -> google -> integrations -> airtrain -> repo root).
_this_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(_this_dir, *([os.pardir] * 4)))
sys.path.append(parent_dir)
|
13
|
+
|
14
|
+
from airtrain.integrations.google.gemini.skills import (
|
15
|
+
Gemini2ChatSkill,
|
16
|
+
Gemini2Input,
|
17
|
+
Gemini2GenerationConfig,
|
18
|
+
)
|
19
|
+
|
20
|
+
|
21
|
+
def run_conversation(
    skill: Gemini2ChatSkill,
    user_input: str,
    system_prompt: str,
    conversation_history: List[Dict[str, str]],
) -> Dict[str, str]:
    """Send one user turn to Gemini and return the reply as an assistant message.

    The role/content history is converted with the skill's
    ``_convert_history_format`` helper before being sent.
    """
    config = Gemini2GenerationConfig(
        temperature=1.0,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
    )

    request = Gemini2Input(
        user_input=user_input,
        system_prompt=system_prompt,
        conversation_history=skill._convert_history_format(conversation_history),
        model="gemini-2.0-flash",
        generation_config=config,
    )

    reply_text = skill.process(request).response
    return {"role": "assistant", "content": reply_text}
|
45
|
+
|
46
|
+
|
47
|
+
def main():
    """Run a scripted five-turn cybersecurity conversation, printing each exchange."""
    skill = Gemini2ChatSkill()
    system_prompt = (
        "You are a helpful AI assistant with expertise in cybersecurity and privacy."
    )
    history: List[Dict[str, str]] = []

    questions = [
        "What are the best practices for password security?",
        "How can I protect my personal data online?",
        "What is two-factor authentication?",
        "Can you explain what encryption is?",
        "Can you summarize the key points about cybersecurity we discussed?",
    ]

    print("\n=== Starting Conversation Test ===\n")

    for turn, question in enumerate(questions, start=1):
        print(f"\n--- Turn {turn} ---")
        print(f"User: {question}\n")

        reply = run_conversation(skill, question, system_prompt, history)

        # Record both sides of the turn so later questions see the context.
        history += [{"role": "user", "content": question}, reply]

        print(f"Assistant: {reply['content']}\n")
        print(f"Current conversation history length: {len(history)}")

    print("\n=== Conversation Test Complete ===")
|
80
|
+
|
81
|
+
|
82
|
+
# Run the scripted conversation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
@@ -0,0 +1,27 @@
|
|
1
|
+
from pydantic import Field, SecretStr
|
2
|
+
from airtrain.core.credentials import BaseCredentials, CredentialValidationError
|
3
|
+
import google.generativeai as genai
|
4
|
+
import os
|
5
|
+
|
6
|
+
|
7
|
+
class Gemini2Credentials(BaseCredentials):
    """Credentials for the Gemini 2.0 API, checked with ``google.generativeai``."""

    gemini_api_key: SecretStr = Field(..., description="Gemini API Key")

    _required_credentials = {"gemini_api_key"}

    @classmethod
    def from_env(cls) -> "Gemini2Credentials":
        """Build credentials from ``GEMINI_API_KEY`` (empty string when unset)."""
        raw_key = os.environ.get("GEMINI_API_KEY", "")
        return cls(gemini_api_key=SecretStr(raw_key))

    async def validate_credentials(self) -> bool:
        """Confirm the key by generating a trivial completion with ``gemini-2.0-flash``.

        Returns:
            bool: ``True`` on success.

        Raises:
            CredentialValidationError: wrapping any failure reported by the SDK.
        """
        try:
            genai.configure(api_key=self.gemini_api_key.get_secret_value())
            probe_model = genai.GenerativeModel("gemini-2.0-flash")
            probe_model.generate_content("test")
        except Exception as e:
            raise CredentialValidationError(f"Invalid Gemini 2.0 credentials: {str(e)}")
        return True
|