airtrain 0.1.14__tar.gz → 0.1.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136) hide show
  1. {airtrain-0.1.14/airtrain.egg-info → airtrain-0.1.18}/PKG-INFO +1 -1
  2. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/__init__.py +3 -3
  3. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/__pycache__/credentials.cpython-310.pyc +0 -0
  4. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/__init__.py +1 -1
  5. airtrain-0.1.18/airtrain/integrations/anthropic/skills.py +127 -0
  6. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/cerebras/credentials.py +3 -6
  7. airtrain-0.1.18/airtrain/integrations/cerebras/skills.py +95 -0
  8. airtrain-0.1.18/airtrain/integrations/fireworks/conversation_manager.py +109 -0
  9. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/fireworks/skills.py +27 -5
  10. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/google/__init__.py +2 -1
  11. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/google/credentials.py +30 -0
  12. airtrain-0.1.18/airtrain/integrations/google/gemini/conversation_history_test.py +83 -0
  13. airtrain-0.1.18/airtrain/integrations/google/gemini/credentials.py +27 -0
  14. airtrain-0.1.18/airtrain/integrations/google/gemini/skills.py +116 -0
  15. airtrain-0.1.18/airtrain/integrations/google/skills.py +122 -0
  16. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/groq/credentials.py +3 -3
  17. airtrain-0.1.18/airtrain/integrations/groq/skills.py +88 -0
  18. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/sambanova/credentials.py +3 -3
  19. airtrain-0.1.18/airtrain/integrations/sambanova/skills.py +97 -0
  20. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/image_skill.py +5 -33
  21. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/skills.py +52 -4
  22. {airtrain-0.1.14 → airtrain-0.1.18/airtrain.egg-info}/PKG-INFO +1 -1
  23. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain.egg-info/SOURCES.txt +23 -0
  24. airtrain-0.1.18/examples/integrations/anthropic/chat_example.py +42 -0
  25. airtrain-0.1.18/examples/integrations/anthropic/chinese_example.py +62 -0
  26. airtrain-0.1.18/examples/integrations/anthropic/conversation_history_test.py +86 -0
  27. airtrain-0.1.18/examples/integrations/anthropic/vision_example.py +47 -0
  28. airtrain-0.1.18/examples/integrations/cerebras/conversation_history_test.py +84 -0
  29. airtrain-0.1.18/examples/integrations/fireworks/chat_example.py +42 -0
  30. airtrain-0.1.18/examples/integrations/fireworks/conversation_history_test.py +86 -0
  31. airtrain-0.1.18/examples/integrations/fireworks/structured_chat_example.py +43 -0
  32. airtrain-0.1.18/examples/integrations/google/conversation_history_test.py +94 -0
  33. airtrain-0.1.18/examples/integrations/google/gemini/conversation_history_test.py +83 -0
  34. airtrain-0.1.18/examples/integrations/groq/conversation_history_test.py +84 -0
  35. airtrain-0.1.18/examples/integrations/openai/chat_example.py +42 -0
  36. airtrain-0.1.18/examples/integrations/openai/parser_example.py +62 -0
  37. airtrain-0.1.18/examples/integrations/openai/vision_example.py +46 -0
  38. airtrain-0.1.18/examples/integrations/sambanova/conversation_history_test.py +85 -0
  39. airtrain-0.1.18/examples/integrations/together/chat_example.py +42 -0
  40. airtrain-0.1.18/examples/integrations/together/conversation_history_test.py +86 -0
  41. airtrain-0.1.18/examples/integrations/together/image_generation_example.py +58 -0
  42. airtrain-0.1.18/examples/integrations/together/rerank_example.py +59 -0
  43. {airtrain-0.1.14 → airtrain-0.1.18}/requirements.txt +4 -1
  44. {airtrain-0.1.14 → airtrain-0.1.18}/scripts/release.py +28 -3
  45. {airtrain-0.1.14 → airtrain-0.1.18}/setup.py +26 -1
  46. airtrain-0.1.14/airtrain/integrations/anthropic/skills.py +0 -135
  47. airtrain-0.1.14/airtrain/integrations/cerebras/skills.py +0 -41
  48. airtrain-0.1.14/airtrain/integrations/google/skills.py +0 -41
  49. airtrain-0.1.14/airtrain/integrations/groq/skills.py +0 -41
  50. airtrain-0.1.14/airtrain/integrations/sambanova/skills.py +0 -41
  51. {airtrain-0.1.14 → airtrain-0.1.18}/.flake8 +0 -0
  52. {airtrain-0.1.14 → airtrain-0.1.18}/.github/workflows/publish.yml +0 -0
  53. {airtrain-0.1.14 → airtrain-0.1.18}/.gitignore +0 -0
  54. {airtrain-0.1.14 → airtrain-0.1.18}/.mypy.ini +0 -0
  55. {airtrain-0.1.14 → airtrain-0.1.18}/.pre-commit-config.yaml +0 -0
  56. {airtrain-0.1.14 → airtrain-0.1.18}/.vscode/extensions.json +0 -0
  57. {airtrain-0.1.14 → airtrain-0.1.18}/.vscode/launch.json +0 -0
  58. {airtrain-0.1.14 → airtrain-0.1.18}/.vscode/settings.json +0 -0
  59. {airtrain-0.1.14 → airtrain-0.1.18}/EXPERIMENTS/integrations_examples/anthropic_with_image.py +0 -0
  60. {airtrain-0.1.14 → airtrain-0.1.18}/EXPERIMENTS/schema_exps/pydantic_schemas.py +0 -0
  61. {airtrain-0.1.14 → airtrain-0.1.18}/README.md +0 -0
  62. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/agents/travel/agents.py +0 -0
  63. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/agents/travel/models.py +0 -0
  64. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/__init__.py +0 -0
  65. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/__init__.py +0 -0
  66. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/agentlib/verification_agent.py +0 -0
  67. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/agents.py +0 -0
  68. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/modellib/verification.py +0 -0
  69. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/contrib/travel/models.py +0 -0
  70. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/__init__.py +0 -0
  71. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/__pycache__/schemas.cpython-310.pyc +0 -0
  72. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/__pycache__/skills.cpython-310.pyc +0 -0
  73. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/credentials.py +0 -0
  74. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/schemas.py +0 -0
  75. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/core/skills.py +0 -0
  76. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/anthropic/__init__.py +0 -0
  77. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/anthropic/credentials.py +0 -0
  78. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/aws/__init__.py +0 -0
  79. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/aws/credentials.py +0 -0
  80. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/aws/skills.py +0 -0
  81. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/cerebras/__init__.py +0 -0
  82. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/fireworks/__init__.py +0 -0
  83. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/fireworks/credentials.py +0 -0
  84. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/fireworks/models.py +0 -0
  85. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/groq/__init__.py +0 -0
  86. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/ollama/__init__.py +0 -0
  87. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/ollama/credentials.py +0 -0
  88. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/ollama/skills.py +0 -0
  89. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/__init__.py +0 -0
  90. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/chinese_assistant.py +0 -0
  91. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/credentials.py +0 -0
  92. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/models_config.py +0 -0
  93. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/openai/skills.py +0 -0
  94. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/sambanova/__init__.py +0 -0
  95. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/__init__.py +0 -0
  96. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/audio_models_config.py +0 -0
  97. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/credentials.py +0 -0
  98. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/embedding_models_config.py +0 -0
  99. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/image_models_config.py +0 -0
  100. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/models.py +0 -0
  101. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/models_config.py +0 -0
  102. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/rerank_models_config.py +0 -0
  103. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/rerank_skill.py +0 -0
  104. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/schemas.py +0 -0
  105. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain/integrations/together/vision_models_config.py +0 -0
  106. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain.egg-info/dependency_links.txt +0 -0
  107. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain.egg-info/requires.txt +0 -0
  108. {airtrain-0.1.14 → airtrain-0.1.18}/airtrain.egg-info/top_level.txt +0 -0
  109. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/anthropic_skills_usage.py +0 -0
  110. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/chinese_anthropic_assistant.py +0 -0
  111. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/chinese_anthropic_usage.py +0 -0
  112. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/chinese_assistant_usage.py +0 -0
  113. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/fireworks_skills_usage.py +0 -0
  114. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/icon128.png +0 -0
  115. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/icon16.png +0 -0
  116. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/image1.jpg +0 -0
  117. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/image2.jpg +0 -0
  118. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/openai_skills.py +0 -0
  119. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/openai_skills_usage.py +0 -0
  120. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/openai_structured_skills.py +0 -0
  121. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/together_rerank_skills.py +0 -0
  122. {airtrain-0.1.14 → airtrain-0.1.18}/examples/creating-skills/together_rerank_skills_async.py +0 -0
  123. {airtrain-0.1.14 → airtrain-0.1.18}/examples/credentials_usage.py +0 -0
  124. {airtrain-0.1.14 → airtrain-0.1.18}/examples/images/quantum-circuit.png +0 -0
  125. {airtrain-0.1.14 → airtrain-0.1.18}/examples/schema_usage.py +0 -0
  126. {airtrain-0.1.14 → airtrain-0.1.18}/examples/skill_usage.py +0 -0
  127. {airtrain-0.1.14 → airtrain-0.1.18}/examples/together/image_generation.py +0 -0
  128. {airtrain-0.1.14 → airtrain-0.1.18}/examples/together/image_generation_example.py +0 -0
  129. {airtrain-0.1.14 → airtrain-0.1.18}/examples/travel/verification_agent_usage.py +0 -0
  130. {airtrain-0.1.14 → airtrain-0.1.18}/pyproject.toml +0 -0
  131. {airtrain-0.1.14 → airtrain-0.1.18}/scripts/build.sh +0 -0
  132. {airtrain-0.1.14 → airtrain-0.1.18}/scripts/bump_version.py +0 -0
  133. {airtrain-0.1.14 → airtrain-0.1.18}/scripts/publish.sh +0 -0
  134. {airtrain-0.1.14 → airtrain-0.1.18}/services/firebase_service.py +0 -0
  135. {airtrain-0.1.14 → airtrain-0.1.18}/services/openai_service.py +0 -0
  136. {airtrain-0.1.14 → airtrain-0.1.18}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: airtrain
3
- Version: 0.1.14
3
+ Version: 0.1.18
4
4
  Summary: A platform for building and deploying AI agents with structured skills
5
5
  Home-page: https://github.com/rosaboyle/airtrain.dev
6
6
  Author: Dheeraj Pai
@@ -1,6 +1,6 @@
1
1
  """Airtrain - A platform for building and deploying AI agents with structured skills"""
2
2
 
3
- __version__ = "0.1.14"
3
+ __version__ = "0.1.18"
4
4
 
5
5
  # Core imports
6
6
  from .core.skills import Skill, ProcessingError
@@ -22,7 +22,7 @@ from .integrations.cerebras.credentials import CerebrasCredentials
22
22
  from .integrations.openai.skills import OpenAIChatSkill, OpenAIParserSkill
23
23
  from .integrations.anthropic.skills import AnthropicChatSkill
24
24
  from .integrations.aws.skills import AWSBedrockSkill
25
- from .integrations.google.skills import VertexAISkill
25
+ from .integrations.google.skills import GoogleChatSkill
26
26
  from .integrations.groq.skills import GroqChatSkill
27
27
  from .integrations.together.skills import TogetherAIChatSkill
28
28
  from .integrations.ollama.skills import OllamaChatSkill
@@ -51,7 +51,7 @@ __all__ = [
51
51
  "OpenAIParserSkill",
52
52
  "AnthropicChatSkill",
53
53
  "AWSBedrockSkill",
54
- "VertexAISkill",
54
+ "GoogleChatSkill",
55
55
  "GroqChatSkill",
56
56
  "TogetherAIChatSkill",
57
57
  "OllamaChatSkill",
@@ -15,7 +15,7 @@ from .cerebras.credentials import CerebrasCredentials
15
15
  from .openai.skills import OpenAIChatSkill, OpenAIParserSkill
16
16
  from .anthropic.skills import AnthropicChatSkill
17
17
  from .aws.skills import AWSBedrockSkill
18
- from .google.skills import VertexAISkill
18
+ from .google.skills import GoogleChatSkill
19
19
  from .groq.skills import GroqChatSkill
20
20
  from .together.skills import TogetherAIChatSkill
21
21
  from .ollama.skills import OllamaChatSkill
@@ -0,0 +1,127 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import Field
3
+ from anthropic import Anthropic
4
+ import base64
5
+ from pathlib import Path
6
+ from loguru import logger
7
+
8
+ from airtrain.core.skills import Skill, ProcessingError
9
+ from airtrain.core.schemas import InputSchema, OutputSchema
10
+ from .credentials import AnthropicCredentials
11
+
12
+
13
class AnthropicInput(InputSchema):
    """Request payload for the Anthropic chat skill."""

    # Text of the current user turn.
    user_input: str = Field(..., description="User's input text")

    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )

    # Earlier turns, oldest first; empty when starting a new conversation.
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
        default_factory=list,
    )

    model: str = Field(
        description="Anthropic model to use",
        default="claude-3-opus-20240229",
    )

    max_tokens: Optional[int] = Field(
        description="Maximum tokens in response",
        default=1024,
    )

    # Validation restricts the value to the closed interval [0, 1].
    temperature: float = Field(
        ge=0,
        le=1,
        default=0.7,
        description="Temperature for response generation",
    )

    # Paths of image files to attach to the current user turn.
    images: List[Path] = Field(
        description="List of image paths to include in the message",
        default_factory=list,
    )
38
+
39
+
40
class AnthropicOutput(OutputSchema):
    """Result returned by the Anthropic chat skill."""

    # Text produced by the model.
    response: str = Field(..., description="Model's response text")

    # Name of the model the request was issued with.
    used_model: str = Field(..., description="Model used for generation")

    # Usage figures as reported by the API; empty dict when not supplied.
    usage: Dict[str, Any] = Field(
        description="Usage statistics from the API",
        default_factory=dict,
    )
48
+
49
+
50
class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
    """Chat skill that sends a conversation to the Anthropic Messages API.

    The system prompt is passed as a separate request parameter; the message
    list holds the prior conversation history followed by the current user
    turn, which may carry base64-encoded images ahead of its text.
    """

    input_schema = AnthropicInput
    output_schema = AnthropicOutput

    def __init__(self, credentials: Optional[AnthropicCredentials] = None):
        """Create the API client.

        Args:
            credentials: Credentials to use; loaded via
                ``AnthropicCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or AnthropicCredentials.from_env()
        self.client = Anthropic(
            api_key=self.credentials.anthropic_api_key.get_secret_value()
        )

    def _encode_image(self, image_path: Path) -> Dict[str, Any]:
        """Read one image file and return it as a base64 image content part.

        The media type is inferred from the file name so that PNG, GIF or
        WebP files are not mislabelled as JPEG. When the type cannot be
        inferred, ``image/jpeg`` is used, which is what this skill previously
        sent for every image.

        Args:
            image_path: Path of the image file to attach.

        Returns:
            Dict[str, Any]: An ``{"type": "image", "source": {...}}`` part.
        """
        # Local import: only needed when images are attached.
        import mimetypes

        media_type, _ = mimetypes.guess_type(str(image_path))
        if media_type is None or not media_type.startswith("image/"):
            media_type = "image/jpeg"

        with open(image_path, "rb") as img_file:
            base64_image = base64.b64encode(img_file.read()).decode("utf-8")

        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": media_type,
                "data": base64_image,
            },
        }

    def _build_messages(self, input_data: AnthropicInput) -> List[Dict[str, Any]]:
        """
        Build messages list from input data including conversation history.

        Args:
            input_data: The input data containing conversation history, user
                input and optional image paths

        Returns:
            List[Dict[str, Any]]: List of messages in the format required by Anthropic
        """
        # Copy so the caller's history list is never mutated.
        messages: List[Dict[str, Any]] = list(input_data.conversation_history)

        # Images (if any) come first, followed by the text of the user turn.
        content: List[Dict[str, Any]] = [
            self._encode_image(image_path) for image_path in input_data.images
        ]
        content.append({"type": "text", "text": input_data.user_input})

        messages.append({"role": "user", "content": content})
        return messages

    def process(self, input_data: AnthropicInput) -> AnthropicOutput:
        """Send the conversation to Anthropic and wrap the reply.

        Args:
            input_data: Prompt, history, model and sampling settings.

        Returns:
            AnthropicOutput: Text of the first content part of the reply, the
            requested model name and the usage figures from the response.

        Raises:
            ProcessingError: If building the request or the API call fails;
                the original exception is attached as the cause.
        """
        try:
            messages = self._build_messages(input_data)

            # The system prompt is a top-level parameter, not a message.
            response = self.client.messages.create(
                model=input_data.model,
                system=input_data.system_prompt,
                messages=messages,
                max_tokens=input_data.max_tokens,
                temperature=input_data.temperature,
            )

            return AnthropicOutput(
                response=response.content[0].text,
                used_model=input_data.model,
                usage=response.usage.model_dump(),
            )

        except Exception as e:
            logger.exception(f"Anthropic processing failed: {str(e)}")
            raise ProcessingError(f"Anthropic processing failed: {str(e)}") from e
@@ -1,16 +1,13 @@
1
- from pydantic import Field, SecretStr, HttpUrl
1
+ from pydantic import Field, SecretStr
2
2
  from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
- from typing import Optional
4
3
 
5
4
 
6
5
  class CerebrasCredentials(BaseCredentials):
7
6
  """Cerebras credentials"""
8
7
 
9
- api_key: SecretStr = Field(..., description="Cerebras API key")
10
- endpoint_url: HttpUrl = Field(..., description="Cerebras API endpoint")
11
- project_id: Optional[str] = Field(None, description="Cerebras Project ID")
8
+ cerebras_api_key: SecretStr = Field(..., description="Cerebras API key")
12
9
 
13
- _required_credentials = {"api_key", "endpoint_url"}
10
+ _required_credentials = {"cerebras_api_key"}
14
11
 
15
12
  async def validate_credentials(self) -> bool:
16
13
  """Validate Cerebras credentials"""
@@ -0,0 +1,95 @@
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import Field
3
+ from cerebras.cloud.sdk import Cerebras
4
+ from loguru import logger
5
+
6
+ from airtrain.core.skills import Skill, ProcessingError
7
+ from airtrain.core.schemas import InputSchema, OutputSchema
8
+ from .credentials import CerebrasCredentials
9
+
10
+
11
class CerebrasInput(InputSchema):
    """Request payload for the Cerebras chat skill."""

    # Text of the current user turn.
    user_input: str = Field(..., description="User's input text")

    system_prompt: str = Field(
        description="System prompt to guide the model's behavior",
        default="You are a helpful assistant.",
    )

    # Earlier turns, oldest first; empty when starting a new conversation.
    conversation_history: List[Dict[str, str]] = Field(
        description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
        default_factory=list,
    )

    model: str = Field(
        description="Cerebras model to use",
        default="llama3.1-8b",
    )

    max_tokens: Optional[int] = Field(
        description="Maximum tokens in response",
        default=1024,
    )

    # Validation restricts the value to the closed interval [0, 1].
    temperature: float = Field(
        ge=0,
        le=1,
        default=0.7,
        description="Temperature for response generation",
    )
30
+
31
+
32
class CerebrasOutput(OutputSchema):
    """Result returned by the Cerebras chat skill."""

    # Text produced by the model.
    response: str = Field(..., description="Model's response text")

    # Name of the model the request was issued with.
    used_model: str = Field(..., description="Model used for generation")

    # Usage figures as reported by the API; empty dict when not supplied.
    usage: Dict[str, Any] = Field(
        description="Usage statistics",
        default_factory=dict,
    )
38
+
39
+
40
class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
    """Chat skill that sends a conversation to the Cerebras chat API."""

    input_schema = CerebrasInput
    output_schema = CerebrasOutput

    def __init__(self, credentials: Optional[CerebrasCredentials] = None):
        """Create the API client.

        Args:
            credentials: Credentials to use; loaded via
                ``CerebrasCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or CerebrasCredentials.from_env()
        api_key = self.credentials.cerebras_api_key.get_secret_value()
        self.client = Cerebras(api_key=api_key)

    def _build_messages(self, input_data: CerebrasInput) -> List[Dict[str, str]]:
        """
        Assemble the message list for one request.

        Args:
            input_data: Supplies the system prompt, prior turns and new user text

        Returns:
            List[Dict[str, str]]: System message, then history, then the new user turn
        """
        system_turn = {"role": "system", "content": input_data.system_prompt}
        user_turn = {"role": "user", "content": input_data.user_input}
        return [system_turn, *input_data.conversation_history, user_turn]

    def process(self, input_data: CerebrasInput) -> CerebrasOutput:
        """Run one chat completion and wrap the reply.

        Args:
            input_data: Prompt, history, model and sampling settings.

        Returns:
            CerebrasOutput: Text of the first choice, the requested model name
            and the usage figures from the response.

        Raises:
            ProcessingError: If building the request or the API call fails.
        """
        try:
            completion = self.client.chat.completions.create(
                model=input_data.model,
                messages=self._build_messages(input_data),
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )
            reply_text = completion.choices[0].message.content
            usage_stats = completion.usage.model_dump()
            return CerebrasOutput(
                response=reply_text,
                used_model=input_data.model,
                usage=usage_stats,
            )
        except Exception as e:
            logger.exception(f"Cerebras processing failed: {str(e)}")
            raise ProcessingError(f"Cerebras processing failed: {str(e)}")
@@ -0,0 +1,109 @@
1
+ from typing import List, Dict, Optional
2
+ from pydantic import BaseModel, Field
3
+ from .skills import FireworksChatSkill, FireworksInput, FireworksOutput
4
+
5
+ # TODO: Test this thing.
6
+
7
+
8
class ConversationState(BaseModel):
    """Serializable snapshot of a conversation and its request settings."""

    # Turns recorded so far, oldest first.
    messages: List[Dict[str, str]] = Field(
        description="List of conversation messages",
        default_factory=list,
    )

    system_prompt: str = Field(
        description="System prompt for the conversation",
        default="You are a helpful assistant.",
    )

    model: str = Field(
        description="Model being used for the conversation",
        default="accounts/fireworks/models/deepseek-r1",
    )

    temperature: float = Field(description="Temperature setting", default=0.7)

    # None means no explicit limit is stored in the state.
    max_tokens: Optional[int] = Field(description="Max tokens setting", default=None)
24
+
25
+
26
class FireworksConversationManager:
    """Manager for handling conversation state with Fireworks AI.

    Holds a ``ConversationState`` (history plus request settings) and a
    ``FireworksChatSkill``; each ``send_message`` call sends the accumulated
    history together with the new user text and records both turns.
    """

    def __init__(
        self,
        skill: Optional[FireworksChatSkill] = None,
        system_prompt: str = "You are a helpful assistant.",
        model: str = "accounts/fireworks/models/deepseek-r1",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ):
        """
        Initialize conversation manager.

        Args:
            skill: FireworksChatSkill instance (creates new one if None)
            system_prompt: Initial system prompt
            model: Model to use
            temperature: Temperature setting
            max_tokens: Max tokens setting
        """
        self.skill = skill or FireworksChatSkill()
        self.state = ConversationState(
            system_prompt=system_prompt,
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def send_message(self, user_input: str) -> FireworksOutput:
        """
        Send a message and get response while maintaining conversation history.

        The history is only extended after the skill returns, so a failed
        request leaves the stored conversation unchanged.

        Args:
            user_input: User's message

        Returns:
            FireworksOutput: Model's response
        """
        input_data = FireworksInput(
            user_input=user_input,
            system_prompt=self.state.system_prompt,
            conversation_history=self.state.messages,
            model=self.state.model,
            temperature=self.state.temperature,
            max_tokens=self.state.max_tokens,
        )

        result = self.skill.process(input_data)

        # Record both sides of the exchange for the next request.
        self.state.messages.extend(
            [
                {"role": "user", "content": user_input},
                {"role": "assistant", "content": result.response},
            ]
        )

        return result

    def reset_conversation(self) -> None:
        """Reset the conversation history while maintaining other settings"""
        self.state.messages = []

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Return a shallow copy of the current conversation history"""
        return self.state.messages.copy()

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt for future messages"""
        self.state.system_prompt = new_prompt

    def save_state(self, file_path: str) -> None:
        """Save conversation state to a file as UTF-8 encoded JSON.

        The encoding is fixed rather than left to the platform default so
        that histories containing non-ASCII text can be written on any
        locale and read back identically on another machine.
        """
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(self.state.model_dump_json(indent=2))

    def load_state(self, file_path: str) -> None:
        """Load conversation state from a UTF-8 encoded JSON file.

        Replaces the whole current state, including history and settings.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            data = f.read()
        self.state = ConversationState.model_validate_json(data)
@@ -17,6 +17,10 @@ class FireworksInput(InputSchema):
17
17
  default="You are a helpful assistant.",
18
18
  description="System prompt to guide the model's behavior",
19
19
  )
20
+ conversation_history: List[Dict[str, str]] = Field(
21
+ default_factory=list,
22
+ description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
23
+ )
20
24
  model: str = Field(
21
25
  default="accounts/fireworks/models/deepseek-r1",
22
26
  description="Fireworks AI model to use",
@@ -52,16 +56,34 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
52
56
  self.credentials = credentials or FireworksCredentials.from_env()
53
57
  self.base_url = "https://api.fireworks.ai/inference/v1"
54
58
 
59
def _build_messages(self, input_data: "FireworksInput") -> List[Dict[str, str]]:
    """
    Assemble the message list for one request.

    Args:
        input_data: Supplies the system prompt, prior turns and new user text

    Returns:
        List[Dict[str, str]]: System message, then history, then the new user turn
    """
    system_turn = {"role": "system", "content": input_data.system_prompt}
    user_turn = {"role": "user", "content": input_data.user_input}
    return [system_turn, *input_data.conversation_history, user_turn]
79
+
55
80
  def process(self, input_data: FireworksInput) -> FireworksOutput:
56
81
  """Process the input using Fireworks AI API"""
57
82
  try:
58
83
  logger.info(f"Processing request with model {input_data.model}")
59
84
 
60
- # Prepare messages
61
- messages = [
62
- {"role": "system", "content": input_data.system_prompt},
63
- {"role": "user", "content": input_data.user_input},
64
- ]
85
+ # Build messages using the helper method
86
+ messages = self._build_messages(input_data)
65
87
 
66
88
  # Prepare request payload
67
89
  payload = {
@@ -1,6 +1,7 @@
1
1
  """Google Cloud integration module"""
2
2
 
3
3
  from .credentials import GoogleCloudCredentials
4
- from .skills import VertexAISkill
4
+ from .skills import GoogleChatSkill
5
+ # from .skills import VertexAISkill
5
6
 
6
7
# Keep in sync with the imports above: the module now imports GoogleChatSkill,
# and listing a name that is not defined here makes a star-import fail.
__all__ = ["GoogleCloudCredentials", "GoogleChatSkill"]
@@ -1,5 +1,8 @@
1
1
  from pydantic import Field, SecretStr
2
2
  from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+ import google.genai as genai
4
+ from google.cloud import storage
5
+ import os
3
6
 
4
7
  # from google.cloud import storage
5
8
 
@@ -26,3 +29,30 @@ class GoogleCloudCredentials(BaseCredentials):
26
29
  raise CredentialValidationError(
27
30
  f"Invalid Google Cloud credentials: {str(e)}"
28
31
  )
32
+
33
+
34
class GeminiCredentials(BaseCredentials):
    """Gemini API credentials"""

    gemini_api_key: SecretStr = Field(..., description="Gemini API Key")

    _required_credentials = {"gemini_api_key"}

    @classmethod
    def from_env(cls) -> "GeminiCredentials":
        """Create credentials from the ``GEMINI_API_KEY`` environment variable.

        An unset variable yields an empty key rather than an error; such
        credentials are only rejected when they are validated or used.
        """
        return cls(gemini_api_key=SecretStr(os.environ.get("GEMINI_API_KEY", "")))

    async def validate_credentials(self) -> bool:
        """Validate the API key by issuing a minimal generation request.

        This module imports the ``google.genai`` SDK, whose entry point is a
        ``Client`` object; the ``configure`` / ``GenerativeModel`` functions
        belong to the separate ``google.generativeai`` package and are not
        available on this module.

        Returns:
            bool: True when the test request succeeds.

        Raises:
            CredentialValidationError: If the request fails for any reason;
                the original exception is attached as the cause.
        """
        try:
            client = genai.Client(api_key=self.gemini_api_key.get_secret_value())

            # Only success or failure matters; the generated text is ignored.
            client.models.generate_content(model="gemini-1.5-flash", contents="test")

            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Gemini credentials: {str(e)}"
            ) from e
@@ -0,0 +1,83 @@
1
+ import sys
2
+ import os
3
+ from pathlib import Path
4
+ from dotenv import load_dotenv
5
+ from typing import List, Dict
6
+
7
+ load_dotenv()
8
+
9
+ parent_dir = os.path.abspath(
10
+ os.path.join(os.path.abspath(__file__), "..", "..", "..", "..", "..")
11
+ )
12
+ sys.path.append(parent_dir)
13
+
14
+ from airtrain.integrations.google.gemini.skills import (
15
+ Gemini2ChatSkill,
16
+ Gemini2Input,
17
+ Gemini2GenerationConfig,
18
+ )
19
+
20
+
21
def run_conversation(
    skill: Gemini2ChatSkill,
    user_input: str,
    system_prompt: str,
    conversation_history: List[Dict[str, str]],
) -> Dict[str, str]:
    """Execute one turn against the skill and return the assistant message.

    Args:
        skill: Skill instance used to issue the request.
        user_input: Text of the new user turn.
        system_prompt: System prompt sent with the request.
        conversation_history: Prior turns as role/content dicts; converted
            with the skill's own history converter before sending.

    Returns:
        Dict[str, str]: ``{"role": "assistant", "content": <reply text>}``.
    """
    request = Gemini2Input(
        user_input=user_input,
        system_prompt=system_prompt,
        conversation_history=skill._convert_history_format(conversation_history),
        model="gemini-2.0-flash",
        generation_config=Gemini2GenerationConfig(
            temperature=1.0,
            top_p=0.95,
            top_k=40,
            max_output_tokens=8192,
        ),
    )
    reply = skill.process(request)
    return {"role": "assistant", "content": reply.response}
45
+
46
+
47
def main():
    """Drive a fixed five-turn conversation and print each exchange."""
    conversation_turns = [
        "What are the best practices for password security?",
        "How can I protect my personal data online?",
        "What is two-factor authentication?",
        "Can you explain what encryption is?",
        "Can you summarize the key points about cybersecurity we discussed?",
    ]
    system_prompt = (
        "You are a helpful AI assistant with expertise in cybersecurity and privacy."
    )
    skill = Gemini2ChatSkill()
    conversation_history = []

    print("\n=== Starting Conversation Test ===\n")

    turn_number = 0
    for user_input in conversation_turns:
        turn_number += 1
        print(f"\n--- Turn {turn_number} ---")
        print(f"User: {user_input}\n")

        # The history passed in excludes the current turn; it is recorded
        # only after the reply comes back.
        assistant_response = run_conversation(
            skill, user_input, system_prompt, conversation_history
        )
        conversation_history.append({"role": "user", "content": user_input})
        conversation_history.append(assistant_response)

        print(f"Assistant: {assistant_response['content']}\n")
        print(f"Current conversation history length: {len(conversation_history)}")

    print("\n=== Conversation Test Complete ===")
81
+
82
+ if __name__ == "__main__":
83
+ main()
@@ -0,0 +1,27 @@
1
+ from pydantic import Field, SecretStr
2
+ from airtrain.core.credentials import BaseCredentials, CredentialValidationError
3
+ import google.generativeai as genai
4
+ import os
5
+
6
+
7
class Gemini2Credentials(BaseCredentials):
    """Gemini 2.0 API credentials"""

    gemini_api_key: SecretStr = Field(..., description="Gemini API Key")

    _required_credentials = {"gemini_api_key"}

    @classmethod
    def from_env(cls) -> "Gemini2Credentials":
        """Create credentials from the ``GEMINI_API_KEY`` environment variable.

        An unset variable yields an empty key rather than an error; such
        credentials are only rejected when they are validated or used.
        """
        return cls(gemini_api_key=SecretStr(os.environ.get("GEMINI_API_KEY", "")))

    async def validate_credentials(self) -> bool:
        """Validate the API key by issuing a minimal generation request.

        Returns:
            bool: True when the test request succeeds.

        Raises:
            CredentialValidationError: If the request fails for any reason;
                the original exception is attached as the cause.
        """
        try:
            genai.configure(api_key=self.gemini_api_key.get_secret_value())
            model = genai.GenerativeModel("gemini-2.0-flash")
            # Only success or failure matters; the generated text is ignored.
            model.generate_content("test")
            return True
        except Exception as e:
            raise CredentialValidationError(
                f"Invalid Gemini 2.0 credentials: {str(e)}"
            ) from e