airtrain 0.1.30__tar.gz → 0.1.31__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. {airtrain-0.1.30/airtrain.egg-info → airtrain-0.1.31}/PKG-INFO +24 -39
  2. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/__init__.py +1 -1
  3. airtrain-0.1.31/airtrain/cli/__init__.py +0 -0
  4. airtrain-0.1.31/airtrain/cli/main.py +53 -0
  5. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/anthropic/skills.py +36 -8
  6. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/cerebras/skills.py +38 -6
  7. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/requests_skills.py +56 -1
  8. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/groq/skills.py +47 -10
  9. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/openai/skills.py +48 -6
  10. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/sambanova/skills.py +38 -6
  11. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/skills.py +39 -8
  12. {airtrain-0.1.30 → airtrain-0.1.31/airtrain.egg-info}/PKG-INFO +24 -39
  13. airtrain-0.1.31/airtrain.egg-info/SOURCES.txt +73 -0
  14. airtrain-0.1.31/airtrain.egg-info/entry_points.txt +2 -0
  15. airtrain-0.1.31/airtrain.egg-info/requires.txt +32 -0
  16. {airtrain-0.1.30 → airtrain-0.1.31}/changelog.md +6 -2
  17. airtrain-0.1.31/pyproject.toml +87 -0
  18. {airtrain-0.1.30 → airtrain-0.1.31}/requirements.txt +2 -1
  19. {airtrain-0.1.30 → airtrain-0.1.31}/setup.py +28 -1
  20. airtrain-0.1.30/.flake8 +0 -14
  21. airtrain-0.1.30/.github/workflows/publish.yml +0 -26
  22. airtrain-0.1.30/.gitignore +0 -183
  23. airtrain-0.1.30/.mypy.ini +0 -7
  24. airtrain-0.1.30/.pre-commit-config.yaml +0 -29
  25. airtrain-0.1.30/.vscode/extensions.json +0 -7
  26. airtrain-0.1.30/.vscode/launch.json +0 -27
  27. airtrain-0.1.30/.vscode/settings.json +0 -25
  28. airtrain-0.1.30/EXPERIMENTS/integrations_examples/anthropic_with_image.py +0 -43
  29. airtrain-0.1.30/EXPERIMENTS/schema_exps/pydantic_schemas.py +0 -37
  30. airtrain-0.1.30/airtrain/changelog.md +0 -16
  31. airtrain-0.1.30/airtrain/contrib/travel/agentlib/verification_agent.py +0 -96
  32. airtrain-0.1.30/airtrain/contrib/travel/agents.py +0 -243
  33. airtrain-0.1.30/airtrain/contrib/travel/modellib/verification.py +0 -32
  34. airtrain-0.1.30/airtrain/contrib/travel/models.py +0 -59
  35. airtrain-0.1.30/airtrain/core/__pycache__/credentials.cpython-310.pyc +0 -0
  36. airtrain-0.1.30/airtrain/core/__pycache__/schemas.cpython-310.pyc +0 -0
  37. airtrain-0.1.30/airtrain/core/__pycache__/skills.cpython-310.pyc +0 -0
  38. airtrain-0.1.30/airtrain/integrations/combined/groq_fireworks_skills.py +0 -126
  39. airtrain-0.1.30/airtrain/integrations/google/gemini/conversation_history_test.py +0 -83
  40. airtrain-0.1.30/airtrain/integrations/google/gemini/credentials.py +0 -27
  41. airtrain-0.1.30/airtrain/integrations/google/gemini/skills.py +0 -116
  42. airtrain-0.1.30/airtrain.egg-info/SOURCES.txt +0 -147
  43. airtrain-0.1.30/airtrain.egg-info/requires.txt +0 -11
  44. airtrain-0.1.30/examples/creating-skills/anthropic_skills_usage.py +0 -56
  45. airtrain-0.1.30/examples/creating-skills/chinese_anthropic_assistant.py +0 -56
  46. airtrain-0.1.30/examples/creating-skills/chinese_anthropic_usage.py +0 -60
  47. airtrain-0.1.30/examples/creating-skills/chinese_assistant_usage.py +0 -45
  48. airtrain-0.1.30/examples/creating-skills/fireworks_skills_usage.py +0 -69
  49. airtrain-0.1.30/examples/creating-skills/icon128.png +0 -0
  50. airtrain-0.1.30/examples/creating-skills/icon16.png +0 -0
  51. airtrain-0.1.30/examples/creating-skills/image1.jpg +0 -0
  52. airtrain-0.1.30/examples/creating-skills/image2.jpg +0 -0
  53. airtrain-0.1.30/examples/creating-skills/openai_skills.py +0 -192
  54. airtrain-0.1.30/examples/creating-skills/openai_skills_usage.py +0 -175
  55. airtrain-0.1.30/examples/creating-skills/openai_structured_skills.py +0 -206
  56. airtrain-0.1.30/examples/creating-skills/together_rerank_skills.py +0 -58
  57. airtrain-0.1.30/examples/creating-skills/together_rerank_skills_async.py +0 -1
  58. airtrain-0.1.30/examples/credentials_usage.py +0 -47
  59. airtrain-0.1.30/examples/images/quantum-circuit.png +0 -0
  60. airtrain-0.1.30/examples/integrations/anthropic/chat_example.py +0 -42
  61. airtrain-0.1.30/examples/integrations/anthropic/chinese_example.py +0 -62
  62. airtrain-0.1.30/examples/integrations/anthropic/conversation_history_test.py +0 -86
  63. airtrain-0.1.30/examples/integrations/anthropic/vision_example.py +0 -47
  64. airtrain-0.1.30/examples/integrations/cerebras/conversation_history_test.py +0 -84
  65. airtrain-0.1.30/examples/integrations/combined/groq_fireworks_example.py +0 -113
  66. airtrain-0.1.30/examples/integrations/fireworks/chat_example.py +0 -42
  67. airtrain-0.1.30/examples/integrations/fireworks/completion_example.py +0 -81
  68. airtrain-0.1.30/examples/integrations/fireworks/conversation_history_test.py +0 -86
  69. airtrain-0.1.30/examples/integrations/fireworks/parser_example.py +0 -62
  70. airtrain-0.1.30/examples/integrations/fireworks/requests_example.py +0 -83
  71. airtrain-0.1.30/examples/integrations/fireworks/streaming_chat_example.py +0 -65
  72. airtrain-0.1.30/examples/integrations/fireworks/structured_chat_example.py +0 -43
  73. airtrain-0.1.30/examples/integrations/fireworks/structured_completion_example.py +0 -115
  74. airtrain-0.1.30/examples/integrations/fireworks/structured_conversation_example.py +0 -112
  75. airtrain-0.1.30/examples/integrations/fireworks/structured_requests_example.py +0 -106
  76. airtrain-0.1.30/examples/integrations/google/conversation_history_test.py +0 -94
  77. airtrain-0.1.30/examples/integrations/google/gemini/conversation_history_test.py +0 -83
  78. airtrain-0.1.30/examples/integrations/groq/conversation_history_test.py +0 -84
  79. airtrain-0.1.30/examples/integrations/openai/chat_example.py +0 -42
  80. airtrain-0.1.30/examples/integrations/openai/parser_example.py +0 -62
  81. airtrain-0.1.30/examples/integrations/openai/streaming_chat_example.py +0 -65
  82. airtrain-0.1.30/examples/integrations/openai/vision_example.py +0 -46
  83. airtrain-0.1.30/examples/integrations/sambanova/conversation_history_test.py +0 -85
  84. airtrain-0.1.30/examples/integrations/together/chat_example.py +0 -42
  85. airtrain-0.1.30/examples/integrations/together/conversation_history_test.py +0 -86
  86. airtrain-0.1.30/examples/integrations/together/image_generation_example.py +0 -58
  87. airtrain-0.1.30/examples/integrations/together/rerank_example.py +0 -59
  88. airtrain-0.1.30/examples/schema_usage.py +0 -77
  89. airtrain-0.1.30/examples/skill_usage.py +0 -83
  90. airtrain-0.1.30/examples/together/image_generation.py +0 -64
  91. airtrain-0.1.30/examples/together/image_generation_example.py +0 -81
  92. airtrain-0.1.30/examples/travel/verification_agent_usage.py +0 -104
  93. airtrain-0.1.30/pyproject.toml +0 -11
  94. airtrain-0.1.30/scripts/build.sh +0 -10
  95. airtrain-0.1.30/scripts/bump_version.py +0 -55
  96. airtrain-0.1.30/scripts/publish.sh +0 -10
  97. airtrain-0.1.30/scripts/release.py +0 -90
  98. airtrain-0.1.30/services/firebase_service.py +0 -181
  99. airtrain-0.1.30/services/openai_service.py +0 -366
  100. {airtrain-0.1.30 → airtrain-0.1.31}/MANIFEST.in +0 -0
  101. {airtrain-0.1.30 → airtrain-0.1.31}/README.md +0 -0
  102. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/contrib/__init__.py +0 -0
  103. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/contrib/travel/__init__.py +0 -0
  104. {airtrain-0.1.30/airtrain/agents → airtrain-0.1.31/airtrain/contrib}/travel/agents.py +0 -0
  105. {airtrain-0.1.30/airtrain/agents → airtrain-0.1.31/airtrain/contrib}/travel/models.py +0 -0
  106. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/core/__init__.py +0 -0
  107. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/core/credentials.py +0 -0
  108. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/core/schemas.py +0 -0
  109. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/core/skills.py +0 -0
  110. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/__init__.py +0 -0
  111. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/anthropic/__init__.py +0 -0
  112. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/anthropic/credentials.py +0 -0
  113. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/aws/__init__.py +0 -0
  114. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/aws/credentials.py +0 -0
  115. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/aws/skills.py +0 -0
  116. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/cerebras/__init__.py +0 -0
  117. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/cerebras/credentials.py +0 -0
  118. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/__init__.py +0 -0
  119. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/completion_skills.py +0 -0
  120. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/conversation_manager.py +0 -0
  121. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/credentials.py +0 -0
  122. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/models.py +0 -0
  123. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/skills.py +0 -0
  124. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/structured_completion_skills.py +0 -0
  125. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/structured_requests_skills.py +0 -0
  126. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/fireworks/structured_skills.py +0 -0
  127. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/google/__init__.py +0 -0
  128. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/google/credentials.py +0 -0
  129. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/google/skills.py +0 -0
  130. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/groq/__init__.py +0 -0
  131. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/groq/credentials.py +0 -0
  132. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/ollama/__init__.py +0 -0
  133. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/ollama/credentials.py +0 -0
  134. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/ollama/skills.py +0 -0
  135. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/openai/__init__.py +0 -0
  136. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/openai/chinese_assistant.py +0 -0
  137. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/openai/credentials.py +0 -0
  138. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/openai/models_config.py +0 -0
  139. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/sambanova/__init__.py +0 -0
  140. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/sambanova/credentials.py +0 -0
  141. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/__init__.py +0 -0
  142. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/audio_models_config.py +0 -0
  143. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/credentials.py +0 -0
  144. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/embedding_models_config.py +0 -0
  145. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/image_models_config.py +0 -0
  146. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/image_skill.py +0 -0
  147. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/models.py +0 -0
  148. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/models_config.py +0 -0
  149. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/rerank_models_config.py +0 -0
  150. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/rerank_skill.py +0 -0
  151. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/schemas.py +0 -0
  152. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain/integrations/together/vision_models_config.py +0 -0
  153. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain.egg-info/dependency_links.txt +0 -0
  154. {airtrain-0.1.30 → airtrain-0.1.31}/airtrain.egg-info/top_level.txt +0 -0
  155. {airtrain-0.1.30 → airtrain-0.1.31}/setup.cfg +0 -0
@@ -1,10 +1,12 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: airtrain
3
- Version: 0.1.30
3
+ Version: 0.1.31
4
4
  Summary: A platform for building and deploying AI agents with structured skills
5
5
  Home-page: https://github.com/rosaboyle/airtrain.dev
6
6
  Author: Dheeraj Pai
7
- Author-email: helloworldcmu@gmail.com
7
+ Author-email: Dheeraj Pai <helloworldcmu@gmail.com>
8
+ Project-URL: Homepage, https://github.com/rosaboyle/airtrain.dev
9
+ Project-URL: Documentation, https://docs.airtrain.dev/
8
10
  Classifier: Development Status :: 3 - Alpha
9
11
  Classifier: Intended Audience :: Developers
10
12
  Classifier: License :: OSI Approved :: MIT License
@@ -26,15 +28,29 @@ Requires-Dist: boto3>=1.36.6
26
28
  Requires-Dist: together>=1.3.13
27
29
  Requires-Dist: anthropic>=0.45.0
28
30
  Requires-Dist: groq>=0.15.0
31
+ Requires-Dist: cerebras-cloud-sdk>=1.19.0
32
+ Requires-Dist: google-genai>=1.0.0
33
+ Requires-Dist: fireworks-ai>=0.15.12
34
+ Requires-Dist: google-generativeai>=0.8.4
35
+ Requires-Dist: click>=8.0.0
36
+ Requires-Dist: rich>=13.3.1
37
+ Requires-Dist: prompt-toolkit>=3.0.36
38
+ Requires-Dist: colorama>=0.4.6
39
+ Requires-Dist: typer>=0.9.0
40
+ Provides-Extra: dev
41
+ Requires-Dist: black>=24.10.0; extra == "dev"
42
+ Requires-Dist: flake8>=7.1.1; extra == "dev"
43
+ Requires-Dist: isort>=5.13.0; extra == "dev"
44
+ Requires-Dist: mypy>=1.9.0; extra == "dev"
45
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
46
+ Requires-Dist: twine>=4.0.0; extra == "dev"
47
+ Requires-Dist: build>=0.10.0; extra == "dev"
48
+ Requires-Dist: types-PyYAML>=6.0; extra == "dev"
49
+ Requires-Dist: types-requests>=2.31.0; extra == "dev"
50
+ Requires-Dist: types-Markdown>=3.5.0; extra == "dev"
29
51
  Dynamic: author
30
- Dynamic: author-email
31
- Dynamic: classifier
32
- Dynamic: description
33
- Dynamic: description-content-type
34
52
  Dynamic: home-page
35
- Dynamic: requires-dist
36
53
  Dynamic: requires-python
37
- Dynamic: summary
38
54
 
39
55
  # Airtrain
40
56
 
@@ -167,34 +183,3 @@ Contributions are welcome! Please feel free to submit a Pull Request.
167
183
  ## License
168
184
 
169
185
  This project is licensed under the MIT License - see the LICENSE file for details.
170
-
171
- ## Changelog
172
-
173
-
174
- ## 0.1.29
175
-
176
- Fixed some issues with the structured output response example.
177
-
178
- ## 0.1.28
179
-
180
- - Bug fix: reasoning to Fireworks structured output.
181
- - Added reasoning to Fireworks structured output.
182
-
183
- ## 0.1.27
184
-
185
- - Added structured completion skills for Fireworks AI
186
- - Added Completion skills for Fireworks AI.
187
- - Added Combination skill for Groq and Fireworks AI.
188
- - Add completion streaming.
189
- Added structured output streaming for Fireworks AI.
190
-
191
- ## 0.1.23
192
-
193
- Added conversation support for Deepseek, Together AI, Fireworks AI, Gemini, Groq, Cerebras and Sambanova.
194
- - Added Change Log
195
-
196
-
197
- ## Notes
198
-
199
- The changelog format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
200
- and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
@@ -1,6 +1,6 @@
1
1
  """Airtrain - A platform for building and deploying AI agents with structured skills"""
2
2
 
3
- __version__ = "0.1.30"
3
+ __version__ = "0.1.31"
4
4
 
5
5
  # Core imports
6
6
  from .core.skills import Skill, ProcessingError
File without changes
@@ -0,0 +1,53 @@
1
+ import click
2
+ from airtrain.integrations.openai.skills import OpenAIChatSkill, OpenAIInput
3
+ import os
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+
8
+
9
+ def initialize_chat():
10
+ return OpenAIChatSkill()
11
+
12
+
13
+ @click.group()
14
+ def cli():
15
+ """Airtrain CLI - Your AI Agent Building Assistant"""
16
+ pass
17
+
18
+
19
+ @cli.command()
20
+ def chat():
21
+ """Start an interactive chat session with Airtrain"""
22
+ skill = initialize_chat()
23
+ click.echo("Welcome to Airtrain! I'm here to help you build your AI Agent.")
24
+ click.echo("Type 'exit' to end the conversation.\n")
25
+
26
+ while True:
27
+ user_input = click.prompt("You", type=str)
28
+
29
+ if user_input.lower() == "exit":
30
+ click.echo("\nGoodbye! Have a great day!")
31
+ break
32
+
33
+ try:
34
+ input_data = OpenAIInput(
35
+ user_input=user_input,
36
+ system_prompt="You are an AI assistant that helps users build their own AI agents. Be helpful and provide clear explanations.",
37
+ model="gpt-4o",
38
+ temperature=0.7,
39
+ )
40
+
41
+ result = skill.process(input_data)
42
+ click.echo(f"\nAirtrain: {result.response}\n")
43
+
44
+ except Exception as e:
45
+ click.echo(f"\nError: {str(e)}\n")
46
+
47
+
48
+ def main():
49
+ cli()
50
+
51
+
52
+ if __name__ == "__main__":
53
+ main()
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Dict, Any
1
+ from typing import List, Optional, Dict, Any, Generator
2
2
  from pydantic import Field
3
3
  from anthropic import Anthropic
4
4
  import base64
@@ -35,6 +35,9 @@ class AnthropicInput(InputSchema):
35
35
  default_factory=list,
36
36
  description="List of image paths to include in the message",
37
37
  )
38
+ stream: bool = Field(
39
+ default=False, description="Whether to stream the response progressively"
40
+ )
38
41
 
39
42
 
40
43
  class AnthropicOutput(OutputSchema):
@@ -102,24 +105,49 @@ class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
102
105
 
103
106
  return messages
104
107
 
105
- def process(self, input_data: AnthropicInput) -> AnthropicOutput:
108
+ def process_stream(self, input_data: AnthropicInput) -> Generator[str, None, None]:
109
+ """Process the input and stream the response token by token."""
106
110
  try:
107
- # Build messages using the helper method
108
111
  messages = self._build_messages(input_data)
109
112
 
110
- # Create chat completion with system prompt as a separate parameter
111
- response = self.client.messages.create(
113
+ with self.client.beta.messages.stream(
112
114
  model=input_data.model,
113
- system=input_data.system_prompt, # System prompt passed directly
115
+ system=input_data.system_prompt,
114
116
  messages=messages,
115
117
  max_tokens=input_data.max_tokens,
116
118
  temperature=input_data.temperature,
117
- )
119
+ ) as stream:
120
+ for chunk in stream.text_stream:
121
+ yield chunk
122
+
123
+ except Exception as e:
124
+ logger.exception(f"Anthropic streaming failed: {str(e)}")
125
+ raise ProcessingError(f"Anthropic streaming failed: {str(e)}")
126
+
127
+ def process(self, input_data: AnthropicInput) -> AnthropicOutput:
128
+ """Process the input and return the complete response."""
129
+ try:
130
+ if input_data.stream:
131
+ response_chunks = []
132
+ for chunk in self.process_stream(input_data):
133
+ response_chunks.append(chunk)
134
+ response = "".join(response_chunks)
135
+ usage = {} # Usage stats not available in streaming
136
+ else:
137
+ messages = self._build_messages(input_data)
138
+ response = self.client.messages.create(
139
+ model=input_data.model,
140
+ system=input_data.system_prompt,
141
+ messages=messages,
142
+ max_tokens=input_data.max_tokens,
143
+ temperature=input_data.temperature,
144
+ )
145
+ usage = response.usage.model_dump() if response.usage else {}
118
146
 
119
147
  return AnthropicOutput(
120
148
  response=response.content[0].text,
121
149
  used_model=input_data.model,
122
- usage=response.usage.model_dump(),
150
+ usage=usage,
123
151
  )
124
152
 
125
153
  except Exception as e:
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Dict, Any
1
+ from typing import List, Optional, Dict, Any, Generator
2
2
  from pydantic import Field
3
3
  from cerebras.cloud.sdk import Cerebras
4
4
  from loguru import logger
@@ -27,6 +27,9 @@ class CerebrasInput(InputSchema):
27
27
  temperature: float = Field(
28
28
  default=0.7, description="Temperature for response generation", ge=0, le=1
29
29
  )
30
+ stream: bool = Field(
31
+ default=False, description="Whether to stream the response progressively"
32
+ )
30
33
 
31
34
 
32
35
  class CerebrasOutput(OutputSchema):
@@ -71,23 +74,52 @@ class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
71
74
 
72
75
  return messages
73
76
 
74
- def process(self, input_data: CerebrasInput) -> CerebrasOutput:
77
+ def process_stream(self, input_data: CerebrasInput) -> Generator[str, None, None]:
78
+ """Process the input and stream the response token by token."""
75
79
  try:
76
- # Build messages using the helper method
77
80
  messages = self._build_messages(input_data)
78
81
 
79
- # Create chat completion
80
- response = self.client.chat.completions.create(
82
+ stream = self.client.chat.completions.create(
81
83
  model=input_data.model,
82
84
  messages=messages,
83
85
  temperature=input_data.temperature,
84
86
  max_tokens=input_data.max_tokens,
87
+ stream=True,
85
88
  )
86
89
 
90
+ for chunk in stream:
91
+ if chunk.choices[0].delta.content is not None:
92
+ yield chunk.choices[0].delta.content
93
+
94
+ except Exception as e:
95
+ logger.exception(f"Cerebras streaming failed: {str(e)}")
96
+ raise ProcessingError(f"Cerebras streaming failed: {str(e)}")
97
+
98
+ def process(self, input_data: CerebrasInput) -> CerebrasOutput:
99
+ """Process the input and return the complete response."""
100
+ try:
101
+ if input_data.stream:
102
+ response_chunks = []
103
+ for chunk in self.process_stream(input_data):
104
+ response_chunks.append(chunk)
105
+ response = "".join(response_chunks)
106
+ usage = {} # Usage stats not available in streaming
107
+ else:
108
+ messages = self._build_messages(input_data)
109
+ response = self.client.chat.completions.create(
110
+ model=input_data.model,
111
+ messages=messages,
112
+ temperature=input_data.temperature,
113
+ max_tokens=input_data.max_tokens,
114
+ )
115
+ usage = (
116
+ response.usage.model_dump() if hasattr(response, "usage") else {}
117
+ )
118
+
87
119
  return CerebrasOutput(
88
120
  response=response.choices[0].message.content,
89
121
  used_model=input_data.model,
90
- usage=response.usage.model_dump(),
122
+ usage=usage,
91
123
  )
92
124
 
93
125
  except Exception as e:
@@ -1,8 +1,9 @@
1
- from typing import List, Optional, Dict, Any, Generator
1
+ from typing import List, Optional, Dict, Any, Generator, AsyncGenerator
2
2
  from pydantic import Field
3
3
  import requests
4
4
  import json
5
5
  from loguru import logger
6
+ import aiohttp
6
7
 
7
8
  from airtrain.core.skills import Skill, ProcessingError
8
9
  from airtrain.core.schemas import InputSchema, OutputSchema
@@ -69,6 +70,11 @@ class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]
69
70
  "Content-Type": "application/json",
70
71
  "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
71
72
  }
73
+ self.stream_headers = {
74
+ "Accept": "text/event-stream",
75
+ "Content-Type": "application/json",
76
+ "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
77
+ }
72
78
 
73
79
  def _build_messages(
74
80
  self, input_data: FireworksRequestInput
@@ -150,3 +156,52 @@ class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]
150
156
 
151
157
  except Exception as e:
152
158
  raise ProcessingError(f"Fireworks request failed: {str(e)}")
159
+
160
+ async def process_async(
161
+ self, input_data: FireworksRequestInput
162
+ ) -> FireworksRequestOutput:
163
+ """Async version of process method using aiohttp"""
164
+ try:
165
+ async with aiohttp.ClientSession() as session:
166
+ payload = self._build_payload(input_data)
167
+ async with session.post(
168
+ self.BASE_URL, headers=self.headers, json=payload
169
+ ) as response:
170
+ response.raise_for_status()
171
+ data = await response.json()
172
+
173
+ return FireworksRequestOutput(
174
+ response=data["choices"][0]["message"]["content"],
175
+ used_model=input_data.model,
176
+ usage=data.get("usage", {}),
177
+ )
178
+
179
+ except Exception as e:
180
+ raise ProcessingError(f"Async Fireworks request failed: {str(e)}")
181
+
182
+ async def process_stream_async(
183
+ self, input_data: FireworksRequestInput
184
+ ) -> AsyncGenerator[str, None]:
185
+ """Async version of stream processor using aiohttp"""
186
+ try:
187
+ async with aiohttp.ClientSession() as session:
188
+ payload = self._build_payload(input_data)
189
+ async with session.post(
190
+ self.BASE_URL, headers=self.stream_headers, json=payload
191
+ ) as response:
192
+ response.raise_for_status()
193
+
194
+ async for line in response.content:
195
+ if line.startswith(b"data: "):
196
+ chunk = json.loads(line[6:].strip())
197
+ if "choices" in chunk:
198
+ content = (
199
+ chunk["choices"][0]
200
+ .get("delta", {})
201
+ .get("content", "")
202
+ )
203
+ if content:
204
+ yield content
205
+
206
+ except Exception as e:
207
+ raise ProcessingError(f"Async Fireworks streaming failed: {str(e)}")
@@ -1,4 +1,4 @@
1
- from typing import Optional, Dict, Any, List
1
+ from typing import Generator, Optional, Dict, Any, List
2
2
  from pydantic import Field
3
3
  from airtrain.core.skills import Skill, ProcessingError
4
4
  from airtrain.core.schemas import InputSchema, OutputSchema
@@ -18,11 +18,16 @@ class GroqInput(InputSchema):
18
18
  default_factory=list,
19
19
  description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
20
20
  )
21
- model: str = Field(default="mixtral-8x7b", description="Groq model to use")
21
+ model: str = Field(
22
+ default="deepseek-r1-distill-llama-70b-specdec", description="Groq model to use"
23
+ )
22
24
  max_tokens: int = Field(default=1024, description="Maximum tokens in response")
23
25
  temperature: float = Field(
24
26
  default=0.7, description="Temperature for response generation", ge=0, le=1
25
27
  )
28
+ stream: bool = Field(
29
+ default=False, description="Whether to stream the response progressively"
30
+ )
26
31
 
27
32
 
28
33
  class GroqOutput(OutputSchema):
@@ -30,7 +35,9 @@ class GroqOutput(OutputSchema):
30
35
 
31
36
  response: str = Field(..., description="Model's response text")
32
37
  used_model: str = Field(..., description="Model used for generation")
33
- usage: Dict[str, Any] = Field(default_factory=dict, description="Usage statistics")
38
+ usage: Dict[str, Any] = Field(
39
+ default_factory=dict, description="Usage statistics from the API"
40
+ )
34
41
 
35
42
 
36
43
  class GroqChatSkill(Skill[GroqInput, GroqOutput]):
@@ -65,23 +72,53 @@ class GroqChatSkill(Skill[GroqInput, GroqOutput]):
65
72
 
66
73
  return messages
67
74
 
68
- def process(self, input_data: GroqInput) -> GroqOutput:
75
+ def process_stream(self, input_data: GroqInput) -> Generator[str, None, None]:
76
+ """Process the input and stream the response token by token."""
69
77
  try:
70
- # Build messages using the helper method
71
78
  messages = self._build_messages(input_data)
72
79
 
73
- # Create chat completion
74
- response = self.client.chat.completions.create(
80
+ stream = self.client.chat.completions.create(
75
81
  model=input_data.model,
76
82
  messages=messages,
77
83
  temperature=input_data.temperature,
78
84
  max_tokens=input_data.max_tokens,
85
+ stream=True,
79
86
  )
80
87
 
88
+ for chunk in stream:
89
+ if chunk.choices[0].delta.content is not None:
90
+ yield chunk.choices[0].delta.content
91
+
92
+ except Exception as e:
93
+ raise ProcessingError(f"Groq streaming failed: {str(e)}")
94
+
95
+ def process(self, input_data: GroqInput) -> GroqOutput:
96
+ """Process the input and return the complete response."""
97
+ try:
98
+ if input_data.stream:
99
+ response_chunks = []
100
+ for chunk in self.process_stream(input_data):
101
+ response_chunks.append(chunk)
102
+ response = "".join(response_chunks)
103
+ usage = {} # Usage stats not available in streaming
104
+ else:
105
+ messages = self._build_messages(input_data)
106
+ completion = self.client.chat.completions.create(
107
+ model=input_data.model,
108
+ messages=messages,
109
+ temperature=input_data.temperature,
110
+ max_tokens=input_data.max_tokens,
111
+ stream=False,
112
+ )
113
+ response = completion.choices[0].message.content
114
+ usage = {
115
+ "total_tokens": completion.usage.total_tokens,
116
+ "prompt_tokens": completion.usage.prompt_tokens,
117
+ "completion_tokens": completion.usage.completion_tokens,
118
+ }
119
+
81
120
  return GroqOutput(
82
- response=response.choices[0].message.content,
83
- used_model=input_data.model,
84
- usage=response.usage.model_dump(),
121
+ response=response, used_model=input_data.model, usage=usage
85
122
  )
86
123
 
87
124
  except Exception as e:
@@ -1,9 +1,6 @@
1
- from typing import List, Optional, Dict, Any, TypeVar, Type, Generator
1
+ from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator
2
2
  from pydantic import Field, BaseModel
3
- from openai import OpenAI
4
- import base64
5
- from pathlib import Path
6
- from loguru import logger
3
+ from openai import OpenAI, AsyncOpenAI
7
4
  from openai.types.chat import ChatCompletionChunk
8
5
 
9
6
  from airtrain.core.skills import Skill, ProcessingError
@@ -48,7 +45,7 @@ class OpenAIOutput(OutputSchema):
48
45
 
49
46
 
50
47
  class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
51
- """Skill for interacting with OpenAI models"""
48
+ """Skill for interacting with OpenAI models with async support"""
52
49
 
53
50
  input_schema = OpenAIInput
54
51
  output_schema = OpenAIOutput
@@ -61,6 +58,10 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
61
58
  api_key=self.credentials.openai_api_key.get_secret_value(),
62
59
  organization=self.credentials.openai_organization_id,
63
60
  )
61
+ self.async_client = AsyncOpenAI(
62
+ api_key=self.credentials.openai_api_key.get_secret_value(),
63
+ organization=self.credentials.openai_organization_id,
64
+ )
64
65
 
65
66
  def _build_messages(self, input_data: OpenAIInput) -> List[Dict[str, str]]:
66
67
  """Build messages list from input data including conversation history."""
@@ -126,6 +127,47 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
126
127
  except Exception as e:
127
128
  raise ProcessingError(f"OpenAI chat failed: {str(e)}")
128
129
 
130
+ async def process_async(self, input_data: OpenAIInput) -> OpenAIOutput:
131
+ """Async version of process method"""
132
+ try:
133
+ messages = self._build_messages(input_data)
134
+ completion = await self.async_client.chat.completions.create(
135
+ model=input_data.model,
136
+ messages=messages,
137
+ temperature=input_data.temperature,
138
+ max_tokens=input_data.max_tokens,
139
+ )
140
+ return OpenAIOutput(
141
+ response=completion.choices[0].message.content,
142
+ used_model=completion.model,
143
+ usage={
144
+ "total_tokens": completion.usage.total_tokens,
145
+ "prompt_tokens": completion.usage.prompt_tokens,
146
+ "completion_tokens": completion.usage.completion_tokens,
147
+ },
148
+ )
149
+ except Exception as e:
150
+ raise ProcessingError(f"OpenAI async chat failed: {str(e)}")
151
+
152
+ async def process_stream_async(
153
+ self, input_data: OpenAIInput
154
+ ) -> AsyncGenerator[str, None]:
155
+ """Async version of stream processor"""
156
+ try:
157
+ messages = self._build_messages(input_data)
158
+ stream = await self.async_client.chat.completions.create(
159
+ model=input_data.model,
160
+ messages=messages,
161
+ temperature=input_data.temperature,
162
+ max_tokens=input_data.max_tokens,
163
+ stream=True,
164
+ )
165
+ async for chunk in stream:
166
+ if chunk.choices[0].delta.content is not None:
167
+ yield chunk.choices[0].delta.content
168
+ except Exception as e:
169
+ raise ProcessingError(f"OpenAI async streaming failed: {str(e)}")
170
+
129
171
 
130
172
  ResponseT = TypeVar("ResponseT", bound=BaseModel)
131
173
 
@@ -1,4 +1,4 @@
1
- from typing import Optional, Dict, Any, List
1
+ from typing import Optional, Dict, Any, List, Generator
2
2
  from pydantic import Field
3
3
  from airtrain.core.skills import Skill, ProcessingError
4
4
  from airtrain.core.schemas import InputSchema, OutputSchema
@@ -28,6 +28,9 @@ class SambanovaInput(InputSchema):
28
28
  top_p: float = Field(
29
29
  default=0.1, description="Top p sampling parameter", ge=0, le=1
30
30
  )
31
+ stream: bool = Field(
32
+ default=False, description="Whether to stream the response progressively"
33
+ )
31
34
 
32
35
 
33
36
  class SambanovaOutput(OutputSchema):
@@ -73,24 +76,53 @@ class SambanovaChatSkill(Skill[SambanovaInput, SambanovaOutput]):
73
76
 
74
77
  return messages
75
78
 
76
- def process(self, input_data: SambanovaInput) -> SambanovaOutput:
79
+ def process_stream(self, input_data: SambanovaInput) -> Generator[str, None, None]:
80
+ """Process the input and stream the response token by token."""
77
81
  try:
78
- # Build messages using the helper method
79
82
  messages = self._build_messages(input_data)
80
83
 
81
- # Create chat completion
82
- response = self.client.chat.completions.create(
84
+ stream = self.client.chat.completions.create(
83
85
  model=input_data.model,
84
86
  messages=messages,
85
87
  temperature=input_data.temperature,
86
88
  max_tokens=input_data.max_tokens,
87
89
  top_p=input_data.top_p,
90
+ stream=True,
88
91
  )
89
92
 
93
+ for chunk in stream:
94
+ if chunk.choices[0].delta.content is not None:
95
+ yield chunk.choices[0].delta.content
96
+
97
+ except Exception as e:
98
+ raise ProcessingError(f"Sambanova streaming failed: {str(e)}")
99
+
100
+ def process(self, input_data: SambanovaInput) -> SambanovaOutput:
101
+ """Process the input and return the complete response."""
102
+ try:
103
+ if input_data.stream:
104
+ response_chunks = []
105
+ for chunk in self.process_stream(input_data):
106
+ response_chunks.append(chunk)
107
+ response = "".join(response_chunks)
108
+ usage = {} # Usage stats not available in streaming
109
+ else:
110
+ messages = self._build_messages(input_data)
111
+ response = self.client.chat.completions.create(
112
+ model=input_data.model,
113
+ messages=messages,
114
+ temperature=input_data.temperature,
115
+ max_tokens=input_data.max_tokens,
116
+ top_p=input_data.top_p,
117
+ )
118
+ usage = (
119
+ response.usage.model_dump() if hasattr(response, "usage") else {}
120
+ )
121
+
90
122
  return SambanovaOutput(
91
123
  response=response.choices[0].message.content,
92
124
  used_model=input_data.model,
93
- usage=response.usage.model_dump(),
125
+ usage=usage,
94
126
  )
95
127
 
96
128
  except Exception as e: