airtrain 0.1.25__tar.gz → 0.1.27__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (151) hide show
  1. {airtrain-0.1.25/airtrain.egg-info → airtrain-0.1.27}/PKG-INFO +10 -2
  2. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/__init__.py +1 -1
  3. airtrain-0.1.27/airtrain/integrations/combined/groq_fireworks_skills.py +126 -0
  4. airtrain-0.1.27/airtrain/integrations/fireworks/completion_skills.py +147 -0
  5. airtrain-0.1.27/airtrain/integrations/fireworks/requests_skills.py +152 -0
  6. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/fireworks/skills.py +53 -52
  7. airtrain-0.1.27/airtrain/integrations/fireworks/structured_completion_skills.py +169 -0
  8. airtrain-0.1.27/airtrain/integrations/fireworks/structured_requests_skills.py +189 -0
  9. airtrain-0.1.27/airtrain/integrations/fireworks/structured_skills.py +102 -0
  10. airtrain-0.1.27/airtrain/integrations/openai/skills.py +192 -0
  11. {airtrain-0.1.25 → airtrain-0.1.27/airtrain.egg-info}/PKG-INFO +10 -2
  12. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain.egg-info/SOURCES.txt +16 -0
  13. airtrain-0.1.27/changelog.md +24 -0
  14. airtrain-0.1.27/examples/creating-skills/openai_structured_skills.py +206 -0
  15. airtrain-0.1.27/examples/integrations/combined/groq_fireworks_example.py +113 -0
  16. airtrain-0.1.27/examples/integrations/fireworks/completion_example.py +81 -0
  17. airtrain-0.1.27/examples/integrations/fireworks/parser_example.py +62 -0
  18. airtrain-0.1.27/examples/integrations/fireworks/requests_example.py +83 -0
  19. airtrain-0.1.27/examples/integrations/fireworks/streaming_chat_example.py +65 -0
  20. airtrain-0.1.27/examples/integrations/fireworks/structured_completion_example.py +115 -0
  21. airtrain-0.1.27/examples/integrations/fireworks/structured_conversation_example.py +112 -0
  22. airtrain-0.1.27/examples/integrations/fireworks/structured_requests_example.py +96 -0
  23. airtrain-0.1.27/examples/integrations/openai/streaming_chat_example.py +65 -0
  24. {airtrain-0.1.25 → airtrain-0.1.27}/scripts/release.py +1 -3
  25. airtrain-0.1.25/airtrain/integrations/openai/skills.py +0 -208
  26. airtrain-0.1.25/examples/creating-skills/openai_structured_skills.py +0 -144
  27. {airtrain-0.1.25 → airtrain-0.1.27}/.flake8 +0 -0
  28. {airtrain-0.1.25 → airtrain-0.1.27}/.github/workflows/publish.yml +0 -0
  29. {airtrain-0.1.25 → airtrain-0.1.27}/.gitignore +0 -0
  30. {airtrain-0.1.25 → airtrain-0.1.27}/.mypy.ini +0 -0
  31. {airtrain-0.1.25 → airtrain-0.1.27}/.pre-commit-config.yaml +0 -0
  32. {airtrain-0.1.25 → airtrain-0.1.27}/.vscode/extensions.json +0 -0
  33. {airtrain-0.1.25 → airtrain-0.1.27}/.vscode/launch.json +0 -0
  34. {airtrain-0.1.25 → airtrain-0.1.27}/.vscode/settings.json +0 -0
  35. {airtrain-0.1.25 → airtrain-0.1.27}/EXPERIMENTS/integrations_examples/anthropic_with_image.py +0 -0
  36. {airtrain-0.1.25 → airtrain-0.1.27}/EXPERIMENTS/schema_exps/pydantic_schemas.py +0 -0
  37. {airtrain-0.1.25 → airtrain-0.1.27}/MANIFEST.in +0 -0
  38. {airtrain-0.1.25 → airtrain-0.1.27}/README.md +0 -0
  39. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/agents/travel/agents.py +0 -0
  40. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/agents/travel/models.py +0 -0
  41. {airtrain-0.1.25 → airtrain-0.1.27/airtrain}/changelog.md +0 -0
  42. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/contrib/__init__.py +0 -0
  43. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/contrib/travel/__init__.py +0 -0
  44. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/contrib/travel/agentlib/verification_agent.py +0 -0
  45. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/contrib/travel/agents.py +0 -0
  46. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/contrib/travel/modellib/verification.py +0 -0
  47. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/contrib/travel/models.py +0 -0
  48. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/core/__init__.py +0 -0
  49. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/core/__pycache__/credentials.cpython-310.pyc +0 -0
  50. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/core/__pycache__/schemas.cpython-310.pyc +0 -0
  51. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/core/__pycache__/skills.cpython-310.pyc +0 -0
  52. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/core/credentials.py +0 -0
  53. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/core/schemas.py +0 -0
  54. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/core/skills.py +0 -0
  55. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/__init__.py +0 -0
  56. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/anthropic/__init__.py +0 -0
  57. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/anthropic/credentials.py +0 -0
  58. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/anthropic/skills.py +0 -0
  59. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/aws/__init__.py +0 -0
  60. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/aws/credentials.py +0 -0
  61. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/aws/skills.py +0 -0
  62. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/cerebras/__init__.py +0 -0
  63. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/cerebras/credentials.py +0 -0
  64. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/cerebras/skills.py +0 -0
  65. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/fireworks/__init__.py +0 -0
  66. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/fireworks/conversation_manager.py +0 -0
  67. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/fireworks/credentials.py +0 -0
  68. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/fireworks/models.py +0 -0
  69. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/google/__init__.py +0 -0
  70. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/google/credentials.py +0 -0
  71. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/google/gemini/conversation_history_test.py +0 -0
  72. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/google/gemini/credentials.py +0 -0
  73. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/google/gemini/skills.py +0 -0
  74. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/google/skills.py +0 -0
  75. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/groq/__init__.py +0 -0
  76. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/groq/credentials.py +0 -0
  77. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/groq/skills.py +0 -0
  78. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/ollama/__init__.py +0 -0
  79. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/ollama/credentials.py +0 -0
  80. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/ollama/skills.py +0 -0
  81. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/openai/__init__.py +0 -0
  82. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/openai/chinese_assistant.py +0 -0
  83. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/openai/credentials.py +0 -0
  84. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/openai/models_config.py +0 -0
  85. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/sambanova/__init__.py +0 -0
  86. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/sambanova/credentials.py +0 -0
  87. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/sambanova/skills.py +0 -0
  88. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/__init__.py +0 -0
  89. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/audio_models_config.py +0 -0
  90. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/credentials.py +0 -0
  91. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/embedding_models_config.py +0 -0
  92. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/image_models_config.py +0 -0
  93. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/image_skill.py +0 -0
  94. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/models.py +0 -0
  95. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/models_config.py +0 -0
  96. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/rerank_models_config.py +0 -0
  97. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/rerank_skill.py +0 -0
  98. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/schemas.py +0 -0
  99. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/skills.py +0 -0
  100. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain/integrations/together/vision_models_config.py +0 -0
  101. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain.egg-info/dependency_links.txt +0 -0
  102. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain.egg-info/requires.txt +0 -0
  103. {airtrain-0.1.25 → airtrain-0.1.27}/airtrain.egg-info/top_level.txt +0 -0
  104. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/anthropic_skills_usage.py +0 -0
  105. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/chinese_anthropic_assistant.py +0 -0
  106. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/chinese_anthropic_usage.py +0 -0
  107. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/chinese_assistant_usage.py +0 -0
  108. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/fireworks_skills_usage.py +0 -0
  109. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/icon128.png +0 -0
  110. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/icon16.png +0 -0
  111. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/image1.jpg +0 -0
  112. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/image2.jpg +0 -0
  113. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/openai_skills.py +0 -0
  114. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/openai_skills_usage.py +0 -0
  115. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/together_rerank_skills.py +0 -0
  116. {airtrain-0.1.25 → airtrain-0.1.27}/examples/creating-skills/together_rerank_skills_async.py +0 -0
  117. {airtrain-0.1.25 → airtrain-0.1.27}/examples/credentials_usage.py +0 -0
  118. {airtrain-0.1.25 → airtrain-0.1.27}/examples/images/quantum-circuit.png +0 -0
  119. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/anthropic/chat_example.py +0 -0
  120. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/anthropic/chinese_example.py +0 -0
  121. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/anthropic/conversation_history_test.py +0 -0
  122. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/anthropic/vision_example.py +0 -0
  123. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/cerebras/conversation_history_test.py +0 -0
  124. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/fireworks/chat_example.py +0 -0
  125. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/fireworks/conversation_history_test.py +0 -0
  126. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/fireworks/structured_chat_example.py +0 -0
  127. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/google/conversation_history_test.py +0 -0
  128. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/google/gemini/conversation_history_test.py +0 -0
  129. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/groq/conversation_history_test.py +0 -0
  130. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/openai/chat_example.py +0 -0
  131. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/openai/parser_example.py +0 -0
  132. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/openai/vision_example.py +0 -0
  133. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/sambanova/conversation_history_test.py +0 -0
  134. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/together/chat_example.py +0 -0
  135. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/together/conversation_history_test.py +0 -0
  136. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/together/image_generation_example.py +0 -0
  137. {airtrain-0.1.25 → airtrain-0.1.27}/examples/integrations/together/rerank_example.py +0 -0
  138. {airtrain-0.1.25 → airtrain-0.1.27}/examples/schema_usage.py +0 -0
  139. {airtrain-0.1.25 → airtrain-0.1.27}/examples/skill_usage.py +0 -0
  140. {airtrain-0.1.25 → airtrain-0.1.27}/examples/together/image_generation.py +0 -0
  141. {airtrain-0.1.25 → airtrain-0.1.27}/examples/together/image_generation_example.py +0 -0
  142. {airtrain-0.1.25 → airtrain-0.1.27}/examples/travel/verification_agent_usage.py +0 -0
  143. {airtrain-0.1.25 → airtrain-0.1.27}/pyproject.toml +0 -0
  144. {airtrain-0.1.25 → airtrain-0.1.27}/requirements.txt +0 -0
  145. {airtrain-0.1.25 → airtrain-0.1.27}/scripts/build.sh +0 -0
  146. {airtrain-0.1.25 → airtrain-0.1.27}/scripts/bump_version.py +0 -0
  147. {airtrain-0.1.25 → airtrain-0.1.27}/scripts/publish.sh +0 -0
  148. {airtrain-0.1.25 → airtrain-0.1.27}/services/firebase_service.py +0 -0
  149. {airtrain-0.1.25 → airtrain-0.1.27}/services/openai_service.py +0 -0
  150. {airtrain-0.1.25 → airtrain-0.1.27}/setup.cfg +0 -0
  151. {airtrain-0.1.25 → airtrain-0.1.27}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: airtrain
3
- Version: 0.1.25
3
+ Version: 0.1.27
4
4
  Summary: A platform for building and deploying AI agents with structured skills
5
5
  Home-page: https://github.com/rosaboyle/airtrain.dev
6
6
  Author: Dheeraj Pai
@@ -171,7 +171,15 @@ This project is licensed under the MIT License - see the LICENSE file for detail
171
171
  ## Changelog
172
172
 
173
173
 
174
- ## 0.1.15
174
+ ## 0.1.27
175
+
176
+ - Added structured completion skills for Fireworks AI
177
+ - Added Completion skills for Fireworks AI.
178
+ - Added Combination skill for Groq and Fireworks AI.
179
+ - Added completion streaming.
180
+ - Added structured output streaming for Fireworks AI.
181
+
182
+ ## 0.1.23
175
183
 
176
184
  - Added conversation support for Deepseek, Together AI, Fireworks AI, Gemini, Groq, Cerebras and Sambanova.
177
185
  - Added Change Log
@@ -1,6 +1,6 @@
1
1
  """Airtrain - A platform for building and deploying AI agents with structured skills"""
2
2
 
3
- __version__ = "0.1.25"
3
+ __version__ = "0.1.27"
4
4
 
5
5
  # Core imports
6
6
  from .core.skills import Skill, ProcessingError
@@ -0,0 +1,126 @@
1
+ from typing import Optional, Dict, Any, List
2
+ from pydantic import Field
3
+ import requests
4
+ from groq import Groq
5
+
6
+ from airtrain.core.skills import Skill, ProcessingError
7
+ from airtrain.core.schemas import InputSchema, OutputSchema
8
+ from airtrain.integrations.fireworks.completion_skills import (
9
+ FireworksCompletionSkill,
10
+ FireworksCompletionInput,
11
+ )
12
+
13
+
14
class GroqFireworksInput(InputSchema):
    """Input for the Groq -> Fireworks chained skill.

    The same ``temperature`` and ``max_tokens`` values are forwarded to both
    the Groq request and the Fireworks completion request.
    """

    # Text sent to Groq; it is also embedded in the Fireworks prompt.
    user_input: str = Field(
        default=...,
        description="User's input text",
    )
    groq_model: str = Field(
        default="mixtral-8x7b-32768",
        description="Groq model to use",
    )
    fireworks_model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks model to use",
    )
    # NOTE(review): FireworksCompletionInput restricts temperature to [0, 1],
    # so larger values accepted here fail later at the Fireworks stage.
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
    )
    max_tokens: int = Field(
        default=4096,
        description="Maximum tokens in response",
    )
29
+
30
+
31
class GroqFireworksOutput(OutputSchema):
    """Result of the Groq -> Fireworks chain.

    As filled in by ``GroqFireworksSkill``, ``used_models`` and ``usage`` are
    keyed by provider name (``"groq"`` and ``"fireworks"``).
    """

    # Tagged transcript: user turn, then Groq's answer followed by the
    # Fireworks continuation.
    combined_response: str
    # Raw text returned by each provider.
    groq_response: str
    fireworks_response: str
    # Provider name -> model identifier that was called.
    used_models: Dict[str, str]
    # Provider name -> integer token counters reported by that provider.
    usage: Dict[str, Dict[str, int]]
39
+
40
+
41
class GroqFireworksSkill(Skill[GroqFireworksInput, GroqFireworksOutput]):
    """Chain a Groq chat completion into a Fireworks text completion.

    Groq answers ``user_input`` first. Its answer is then placed inside an
    open ``<ASSISTANT>`` turn of a Fireworks completion prompt so Fireworks
    continues it. Both responses, their concatenation, the models used and
    per-provider token usage are returned.
    """

    input_schema = GroqFireworksInput
    output_schema = GroqFireworksOutput

    def __init__(
        self,
        groq_api_key: Optional[str] = None,
        fireworks_skill: Optional[FireworksCompletionSkill] = None,
    ):
        """Create the Groq client and the Fireworks completion skill.

        Args:
            groq_api_key: Passed straight to ``Groq(api_key=...)``. When
                ``None``, key lookup is left to the Groq client library
                (presumably its environment variable -- confirm).
            fireworks_skill: Completion skill to reuse. A default
                ``FireworksCompletionSkill()`` is built when omitted.
        """
        super().__init__()
        self.groq_client = Groq(api_key=groq_api_key)
        self.fireworks_skill = fireworks_skill or FireworksCompletionSkill()

    @staticmethod
    def _token_counts(usage: Optional[Dict[str, Any]]) -> Dict[str, int]:
        """Keep only the numeric ``*_tokens`` counters from a usage dict.

        ``GroqFireworksOutput.usage`` is typed ``Dict[str, Dict[str, int]]``.
        Any non-integer extras in Groq's usage payload (e.g. timing values or
        ``None`` -- verify against the Groq SDK) would otherwise make output
        validation fail.
        """
        if not usage:
            return {}
        return {
            key: int(value)
            for key, value in usage.items()
            if key.endswith("_tokens")
            and isinstance(value, (int, float))
            and not isinstance(value, bool)
        }

    def _get_groq_response(self, input_data: GroqFireworksInput) -> Dict[str, Any]:
        """Ask Groq for a single-turn chat completion of ``user_input``.

        Returns:
            ``{"response": str, "usage": Dict[str, int]}``.

        Raises:
            ProcessingError: if the request or response handling fails.
        """
        try:
            completion = self.groq_client.chat.completions.create(
                model=input_data.groq_model,
                messages=[{"role": "user", "content": input_data.user_input}],
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )
            raw_usage = completion.usage.model_dump() if completion.usage else {}
            return {
                # content may be None; the output schema requires a str.
                "response": completion.choices[0].message.content or "",
                "usage": self._token_counts(raw_usage),
            }
        except Exception as e:
            raise ProcessingError(f"Groq request failed: {str(e)}") from e

    def _get_fireworks_response(
        self, groq_response: str, input_data: GroqFireworksInput
    ) -> Dict[str, Any]:
        """Have Fireworks continue Groq's answer as a raw completion.

        The prompt leaves the ``<ASSISTANT>`` tag open on purpose so the
        completion model continues the assistant turn.

        Returns:
            ``{"response": str, "usage": Dict[str, int]}``.

        Raises:
            ProcessingError: if the Fireworks completion fails.
        """
        try:
            formatted_prompt = (
                f"<USER>{input_data.user_input}</USER>\n<ASSISTANT>{groq_response}"
            )

            fireworks_input = FireworksCompletionInput(
                prompt=formatted_prompt,
                model=input_data.fireworks_model,
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
            )

            result = self.fireworks_skill.process(fireworks_input)
            return {"response": result.response, "usage": result.usage}
        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e

    def process(self, input_data: GroqFireworksInput) -> GroqFireworksOutput:
        """Run Groq, then Fireworks, and combine both results.

        Raises:
            ProcessingError: wrapping any failure from either provider or
                from output validation.
        """
        try:
            groq_result = self._get_groq_response(input_data)

            fireworks_result = self._get_fireworks_response(
                groq_result["response"], input_data
            )

            # Same tagged layout as the Fireworks prompt, with the
            # continuation appended after Groq's answer.
            combined_response = (
                f"<USER>{input_data.user_input}</USER>\n"
                f"<ASSISTANT>{groq_result['response']} {fireworks_result['response']}"
            )

            return GroqFireworksOutput(
                combined_response=combined_response,
                groq_response=groq_result["response"],
                fireworks_response=fireworks_result["response"],
                used_models={
                    "groq": input_data.groq_model,
                    "fireworks": input_data.fireworks_model,
                },
                usage={
                    "groq": groq_result["usage"],
                    "fireworks": fireworks_result["usage"],
                },
            )

        except Exception as e:
            raise ProcessingError(f"Combined processing failed: {str(e)}") from e
@@ -0,0 +1,147 @@
1
+ from typing import List, Optional, Dict, Any, Generator, Union
2
+ from pydantic import Field
3
+ import requests
4
+ import json
5
+ from loguru import logger
6
+
7
+ from airtrain.core.skills import Skill, ProcessingError
8
+ from airtrain.core.schemas import InputSchema, OutputSchema
9
+ from .credentials import FireworksCredentials
10
+
11
+
12
class FireworksCompletionInput(InputSchema):
    """Request parameters for a Fireworks AI raw-prompt completion.

    ``FireworksCompletionSkill`` copies these fields into the request payload;
    ``stop`` is only included when it is truthy.
    """

    # Raw text the model should continue.
    prompt: str = Field(default=..., description="Input prompt for completion")
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    max_tokens: int = Field(default=4096, description="Maximum tokens in response")
    # Sampling controls; the ge/le bounds are validated by pydantic.
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
    top_p: float = Field(
        default=1.0,
        description="Top p sampling parameter",
        ge=0,
        le=1,
    )
    top_k: int = Field(default=50, description="Top k sampling parameter", ge=0)
    # Repetition-related penalties.
    presence_penalty: float = Field(
        default=0.0,
        description="Presence penalty",
        ge=-2.0,
        le=2.0,
    )
    frequency_penalty: float = Field(
        default=0.0,
        description="Frequency penalty",
        ge=-2.0,
        le=2.0,
    )
    repetition_penalty: float = Field(
        default=1.0,
        description="Repetition penalty",
        ge=0.0,
    )
    # One stop string or a list of them.
    stop: Optional[Union[str, List[str]]] = Field(
        default=None,
        description="Stop sequences",
    )
    echo: bool = Field(default=False, description="Echo the prompt in the response")
    # When true, ``process`` assembles the result from streamed chunks.
    stream: bool = Field(default=False, description="Whether to stream the response")
42
+
43
+
44
class FireworksCompletionOutput(OutputSchema):
    """Result of a Fireworks AI completion request."""

    # Generated text (concatenated chunks when streaming).
    response: str
    # Model identifier taken from the request input.
    used_model: str
    # Token counters from the API; empty when the response was streamed.
    usage: Dict[str, int]
50
+
51
+
52
class FireworksCompletionSkill(
    Skill[FireworksCompletionInput, FireworksCompletionOutput]
):
    """Raw-prompt text completion against the Fireworks AI REST API.

    Uses ``requests`` directly against ``BASE_URL``. Supports a single
    blocking request (``process``) and server-sent-event streaming
    (``process_stream``).
    """

    input_schema = FireworksCompletionInput
    output_schema = FireworksCompletionOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Store credentials and precompute the JSON/auth request headers.

        Args:
            credentials: Fireworks credentials; loaded with
                ``FireworksCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_payload(self, input_data: FireworksCompletionInput) -> Dict[str, Any]:
        """Build the JSON request body; ``stop`` is added only when truthy."""
        payload = {
            "model": input_data.model,
            "prompt": input_data.prompt,
            "max_tokens": input_data.max_tokens,
            "temperature": input_data.temperature,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "repetition_penalty": input_data.repetition_penalty,
            "echo": input_data.echo,
            "stream": input_data.stream,
        }

        if input_data.stop:
            payload["stop"] = input_data.stop

        return payload

    def process_stream(
        self, input_data: FireworksCompletionInput
    ) -> Generator[str, None, None]:
        """Yield completion text chunks as the server streams them.

        The request is always sent with ``"stream": True``, because this
        method parses the ``data: ...`` event lines of a streamed response.
        The HTTP response is closed when iteration ends or the generator is
        closed early.

        Raises:
            ProcessingError: wrapping any HTTP or parsing failure.
        """
        try:
            payload = self._build_payload(input_data)
            # Force SSE mode even if the caller left input_data.stream False;
            # otherwise a plain JSON body would be parsed line by line.
            payload["stream"] = True
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line.decode("utf-8").removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Non-JSON lines such as a terminal "[DONE]" marker.
                        continue
                    if data.get("choices") and data["choices"][0].get("text"):
                        yield data["choices"][0]["text"]

        except Exception as e:
            raise ProcessingError(
                f"Fireworks completion streaming failed: {str(e)}"
            ) from e

    def process(
        self, input_data: FireworksCompletionInput
    ) -> FireworksCompletionOutput:
        """Run the completion and return the full text plus token usage.

        When ``input_data.stream`` is true the text is assembled from
        ``process_stream`` and ``usage`` is empty.

        Raises:
            ProcessingError: wrapping any failure.
        """
        try:
            if input_data.stream:
                response_text = "".join(self.process_stream(input_data))
                usage = {}  # Usage stats not available in streaming mode
            else:
                payload = self._build_payload(input_data)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["text"]
                usage = data["usage"]

            return FireworksCompletionOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except Exception as e:
            raise ProcessingError(f"Fireworks completion failed: {str(e)}") from e
@@ -0,0 +1,152 @@
1
+ from typing import List, Optional, Dict, Any, Generator
2
+ from pydantic import Field
3
+ import requests
4
+ import json
5
+ from loguru import logger
6
+
7
+ from airtrain.core.skills import Skill, ProcessingError
8
+ from airtrain.core.schemas import InputSchema, OutputSchema
9
+ from .credentials import FireworksCredentials
10
+
11
+
12
class FireworksRequestInput(InputSchema):
    """Request parameters for a Fireworks AI chat completion sent via ``requests``.

    The system prompt, any prior turns and the new user message are combined
    into the ``messages`` list by ``FireworksRequestSkill``.
    """

    # The new user message for this turn.
    user_input: str = Field(default=..., description="User's input text")
    system_prompt: str = Field(
        default="You are a helpful assistant.",
        description="System prompt to guide the model's behavior",
    )
    # Prior messages inserted between the system prompt and user_input.
    conversation_history: List[Dict[str, str]] = Field(
        default_factory=list,
        description="List of previous conversation messages",
    )
    model: str = Field(
        default="accounts/fireworks/models/deepseek-r1",
        description="Fireworks AI model to use",
    )
    # Sampling controls; the ge/le bounds are validated by pydantic.
    temperature: float = Field(
        default=0.7,
        description="Temperature for response generation",
        ge=0,
        le=1,
    )
    max_tokens: int = Field(default=4096, description="Maximum tokens in response")
    top_p: float = Field(
        default=1.0,
        description="Top p sampling parameter",
        ge=0,
        le=1,
    )
    top_k: int = Field(default=40, description="Top k sampling parameter", ge=0)
    presence_penalty: float = Field(
        default=0.0,
        description="Presence penalty",
        ge=-2.0,
        le=2.0,
    )
    frequency_penalty: float = Field(
        default=0.0,
        description="Frequency penalty",
        ge=-2.0,
        le=2.0,
    )
    # When true, ``process`` assembles the result from streamed chunks.
    stream: bool = Field(default=False, description="Whether to stream the response")
46
+
47
+
48
class FireworksRequestOutput(OutputSchema):
    """Result of a Fireworks AI chat request."""

    # Assistant reply text (concatenated chunks when streaming).
    response: str
    # Model identifier taken from the request input.
    used_model: str
    # Token counters from the API; empty when the response was streamed.
    usage: Dict[str, int]
54
+
55
+
56
class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]):
    """Chat completion against the Fireworks AI REST API using ``requests``.

    Supports a single blocking request (``process``) and server-sent-event
    streaming (``process_stream``).
    """

    input_schema = FireworksRequestInput
    output_schema = FireworksRequestOutput
    BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"

    def __init__(self, credentials: Optional[FireworksCredentials] = None):
        """Store credentials and precompute the JSON/auth request headers.

        Args:
            credentials: Fireworks credentials; loaded with
                ``FireworksCredentials.from_env()`` when omitted.
        """
        super().__init__()
        self.credentials = credentials or FireworksCredentials.from_env()
        self.headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
        }

    def _build_messages(
        self, input_data: FireworksRequestInput
    ) -> List[Dict[str, str]]:
        """Return ``[system, *history, user]`` chat messages."""
        messages = [{"role": "system", "content": input_data.system_prompt}]

        if input_data.conversation_history:
            messages.extend(input_data.conversation_history)

        messages.append({"role": "user", "content": input_data.user_input})
        return messages

    def _build_payload(self, input_data: FireworksRequestInput) -> Dict[str, Any]:
        """Build the JSON request body from ``input_data``."""
        return {
            "model": input_data.model,
            "messages": self._build_messages(input_data),
            "temperature": input_data.temperature,
            "max_tokens": input_data.max_tokens,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "presence_penalty": input_data.presence_penalty,
            "frequency_penalty": input_data.frequency_penalty,
            "stream": input_data.stream,
        }

    def process_stream(
        self, input_data: FireworksRequestInput
    ) -> Generator[str, None, None]:
        """Yield assistant text deltas as the server streams them.

        The request is always sent with ``"stream": True``, because this
        method parses the ``data: ...`` event lines of a streamed response.
        Events without choices or without delta content are skipped. The HTTP
        response is closed when iteration ends or the generator is closed
        early.

        Raises:
            ProcessingError: wrapping any HTTP or parsing failure.
        """
        try:
            payload = self._build_payload(input_data)
            # Force SSE mode even if the caller left input_data.stream False.
            payload["stream"] = True
            with requests.post(
                self.BASE_URL,
                headers=self.headers,
                data=json.dumps(payload),
                stream=True,
            ) as response:
                response.raise_for_status()

                for line in response.iter_lines():
                    if not line:
                        continue
                    try:
                        data = json.loads(line.decode("utf-8").removeprefix("data: "))
                    except json.JSONDecodeError:
                        # Non-JSON lines such as a terminal "[DONE]" marker.
                        continue
                    # Some events may carry no choices or an empty delta;
                    # indexing them directly would abort the whole stream.
                    choices = data.get("choices") or []
                    if not choices:
                        continue
                    content = (choices[0].get("delta") or {}).get("content")
                    if content:
                        yield content

        except Exception as e:
            raise ProcessingError(f"Fireworks streaming request failed: {str(e)}") from e

    def process(self, input_data: FireworksRequestInput) -> FireworksRequestOutput:
        """Run the chat request and return the full reply plus token usage.

        When ``input_data.stream`` is true the reply is assembled from
        ``process_stream`` and ``usage`` is empty.

        Raises:
            ProcessingError: wrapping any failure.
        """
        try:
            if input_data.stream:
                response_text = "".join(self.process_stream(input_data))
                usage = {}  # Usage stats not available in streaming mode
            else:
                payload = self._build_payload(input_data)
                response = requests.post(
                    self.BASE_URL, headers=self.headers, data=json.dumps(payload)
                )
                response.raise_for_status()
                data = response.json()

                response_text = data["choices"][0]["message"]["content"]
                usage = data["usage"]

            return FireworksRequestOutput(
                response=response_text, used_model=input_data.model, usage=usage
            )

        except Exception as e:
            raise ProcessingError(f"Fireworks request failed: {str(e)}") from e
@@ -1,7 +1,9 @@
1
- from typing import List, Optional, Dict, Any
1
+ from typing import List, Optional, Dict, Any, Generator
2
2
  from pydantic import Field
3
3
  import requests
4
4
  from loguru import logger
5
+ from openai import OpenAI
6
+ from openai.types.chat import ChatCompletionChunk
5
7
 
6
8
  from airtrain.core.skills import Skill, ProcessingError
7
9
  from airtrain.core.schemas import InputSchema, OutputSchema
@@ -34,10 +36,14 @@ class FireworksInput(InputSchema):
34
36
  context_length_exceeded_behavior: str = Field(
35
37
  default="truncate", description="Behavior when context length is exceeded"
36
38
  )
39
+ stream: bool = Field(
40
+ default=False,
41
+ description="Whether to stream the response token by token",
42
+ )
37
43
 
38
44
 
39
45
  class FireworksOutput(OutputSchema):
40
- """Schema for Fireworks AI output"""
46
+ """Schema for Fireworks AI chat output"""
41
47
 
42
48
  response: str = Field(..., description="Model's response text")
43
49
  used_model: str = Field(..., description="Model used for generation")
@@ -54,76 +60,71 @@ class FireworksChatSkill(Skill[FireworksInput, FireworksOutput]):
54
60
  """Initialize the skill with optional credentials"""
55
61
  super().__init__()
56
62
  self.credentials = credentials or FireworksCredentials.from_env()
57
- self.base_url = "https://api.fireworks.ai/inference/v1"
63
+ self.client = OpenAI(
64
+ base_url="https://api.fireworks.ai/inference/v1",
65
+ api_key=self.credentials.fireworks_api_key.get_secret_value(),
66
+ )
58
67
 
59
68
  def _build_messages(self, input_data: FireworksInput) -> List[Dict[str, str]]:
60
- """
61
- Build messages list from input data including conversation history.
62
-
63
- Args:
64
- input_data: The input data containing system prompt, conversation history, and user input
65
-
66
- Returns:
67
- List[Dict[str, str]]: List of messages in the format required by Fireworks AI
68
- """
69
+ """Build messages list from input data including conversation history."""
69
70
  messages = [{"role": "system", "content": input_data.system_prompt}]
70
71
 
71
- # Add conversation history if present
72
72
  if input_data.conversation_history:
73
73
  messages.extend(input_data.conversation_history)
74
74
 
75
- # Add current user input
76
75
  messages.append({"role": "user", "content": input_data.user_input})
77
-
78
76
  return messages
79
77
 
80
- def process(self, input_data: FireworksInput) -> FireworksOutput:
81
- """Process the input using Fireworks AI API"""
78
def process_stream(self, input_data: FireworksInput) -> Generator[str, None, None]:
    """Stream the chat reply through the OpenAI-compatible client.

    Yields:
        Each non-``None`` ``delta.content`` piece, in arrival order.

    Raises:
        ProcessingError: wrapping any client or iteration failure.
    """
    try:
        messages = self._build_messages(input_data)

        stream = self.client.chat.completions.create(
            model=input_data.model,
            messages=messages,
            temperature=input_data.temperature,
            max_tokens=input_data.max_tokens,
            stream=True,
        )

        for chunk in stream:
            # A chunk may arrive with no choices (e.g. a trailing usage-only
            # chunk -- provider-dependent); indexing it would raise IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content

    except Exception as e:
        raise ProcessingError(f"Fireworks streaming failed: {str(e)}") from e
114
97
 
115
- logger.success("Successfully processed Fireworks AI request")
98
def process(self, input_data: FireworksInput) -> FireworksOutput:
    """Return the complete chat reply, streamed or not.

    With ``input_data.stream`` set, the text is assembled from
    ``process_stream``. No completion object exists on that path, so
    ``usage`` is empty. Otherwise a single non-streaming request is made and
    its prompt/completion/total token counts are reported.

    Raises:
        ProcessingError: wrapping any failure.
    """
    try:
        if input_data.stream:
            response = "".join(self.process_stream(input_data))
            # Token counts are not available when streaming; `completion` is
            # never bound on this path, so it must not be read below.
            usage: Dict[str, int] = {}
        else:
            completion = self.client.chat.completions.create(
                model=input_data.model,
                messages=self._build_messages(input_data),
                temperature=input_data.temperature,
                max_tokens=input_data.max_tokens,
                stream=False,
            )
            response = completion.choices[0].message.content
            usage = {
                "total_tokens": completion.usage.total_tokens,
                "prompt_tokens": completion.usage.prompt_tokens,
                "completion_tokens": completion.usage.completion_tokens,
            }

        return FireworksOutput(
            response=response,
            used_model=input_data.model,
            usage=usage,
        )

    except Exception as e:
        raise ProcessingError(f"Fireworks chat failed: {str(e)}") from e