airtrain 0.1.29__py3-none-any.whl → 0.1.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/cli/__init__.py +0 -0
- airtrain/cli/main.py +53 -0
- airtrain/integrations/anthropic/skills.py +36 -8
- airtrain/integrations/cerebras/skills.py +38 -6
- airtrain/integrations/fireworks/requests_skills.py +56 -1
- airtrain/integrations/groq/skills.py +47 -10
- airtrain/integrations/openai/skills.py +48 -6
- airtrain/integrations/sambanova/skills.py +38 -6
- airtrain/integrations/together/skills.py +39 -8
- {airtrain-0.1.29.dist-info → airtrain-0.1.31.dist-info}/METADATA +24 -36
- {airtrain-0.1.29.dist-info → airtrain-0.1.31.dist-info}/RECORD +15 -12
- airtrain-0.1.31.dist-info/entry_points.txt +2 -0
- {airtrain-0.1.29.dist-info → airtrain-0.1.31.dist-info}/WHEEL +0 -0
- {airtrain-0.1.29.dist-info → airtrain-0.1.31.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
airtrain/cli/__init__.py
ADDED
File without changes
|
airtrain/cli/main.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
import click
|
2
|
+
from airtrain.integrations.openai.skills import OpenAIChatSkill, OpenAIInput
|
3
|
+
import os
|
4
|
+
from dotenv import load_dotenv
|
5
|
+
|
6
|
+
load_dotenv()
|
7
|
+
|
8
|
+
|
9
|
+
def initialize_chat():
|
10
|
+
return OpenAIChatSkill()
|
11
|
+
|
12
|
+
|
13
|
+
@click.group()
|
14
|
+
def cli():
|
15
|
+
"""Airtrain CLI - Your AI Agent Building Assistant"""
|
16
|
+
pass
|
17
|
+
|
18
|
+
|
19
|
+
@cli.command()
|
20
|
+
def chat():
|
21
|
+
"""Start an interactive chat session with Airtrain"""
|
22
|
+
skill = initialize_chat()
|
23
|
+
click.echo("Welcome to Airtrain! I'm here to help you build your AI Agent.")
|
24
|
+
click.echo("Type 'exit' to end the conversation.\n")
|
25
|
+
|
26
|
+
while True:
|
27
|
+
user_input = click.prompt("You", type=str)
|
28
|
+
|
29
|
+
if user_input.lower() == "exit":
|
30
|
+
click.echo("\nGoodbye! Have a great day!")
|
31
|
+
break
|
32
|
+
|
33
|
+
try:
|
34
|
+
input_data = OpenAIInput(
|
35
|
+
user_input=user_input,
|
36
|
+
system_prompt="You are an AI assistant that helps users build their own AI agents. Be helpful and provide clear explanations.",
|
37
|
+
model="gpt-4o",
|
38
|
+
temperature=0.7,
|
39
|
+
)
|
40
|
+
|
41
|
+
result = skill.process(input_data)
|
42
|
+
click.echo(f"\nAirtrain: {result.response}\n")
|
43
|
+
|
44
|
+
except Exception as e:
|
45
|
+
click.echo(f"\nError: {str(e)}\n")
|
46
|
+
|
47
|
+
|
48
|
+
def main():
|
49
|
+
cli()
|
50
|
+
|
51
|
+
|
52
|
+
if __name__ == "__main__":
|
53
|
+
main()
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import List, Optional, Dict, Any
|
1
|
+
from typing import List, Optional, Dict, Any, Generator
|
2
2
|
from pydantic import Field
|
3
3
|
from anthropic import Anthropic
|
4
4
|
import base64
|
@@ -35,6 +35,9 @@ class AnthropicInput(InputSchema):
|
|
35
35
|
default_factory=list,
|
36
36
|
description="List of image paths to include in the message",
|
37
37
|
)
|
38
|
+
stream: bool = Field(
|
39
|
+
default=False, description="Whether to stream the response progressively"
|
40
|
+
)
|
38
41
|
|
39
42
|
|
40
43
|
class AnthropicOutput(OutputSchema):
|
@@ -102,24 +105,49 @@ class AnthropicChatSkill(Skill[AnthropicInput, AnthropicOutput]):
|
|
102
105
|
|
103
106
|
return messages
|
104
107
|
|
105
|
-
def
|
108
|
+
def process_stream(self, input_data: AnthropicInput) -> Generator[str, None, None]:
|
109
|
+
"""Process the input and stream the response token by token."""
|
106
110
|
try:
|
107
|
-
# Build messages using the helper method
|
108
111
|
messages = self._build_messages(input_data)
|
109
112
|
|
110
|
-
|
111
|
-
response = self.client.messages.create(
|
113
|
+
with self.client.beta.messages.stream(
|
112
114
|
model=input_data.model,
|
113
|
-
system=input_data.system_prompt,
|
115
|
+
system=input_data.system_prompt,
|
114
116
|
messages=messages,
|
115
117
|
max_tokens=input_data.max_tokens,
|
116
118
|
temperature=input_data.temperature,
|
117
|
-
)
|
119
|
+
) as stream:
|
120
|
+
for chunk in stream.text_stream:
|
121
|
+
yield chunk
|
122
|
+
|
123
|
+
except Exception as e:
|
124
|
+
logger.exception(f"Anthropic streaming failed: {str(e)}")
|
125
|
+
raise ProcessingError(f"Anthropic streaming failed: {str(e)}")
|
126
|
+
|
127
|
+
def process(self, input_data: AnthropicInput) -> AnthropicOutput:
|
128
|
+
"""Process the input and return the complete response."""
|
129
|
+
try:
|
130
|
+
if input_data.stream:
|
131
|
+
response_chunks = []
|
132
|
+
for chunk in self.process_stream(input_data):
|
133
|
+
response_chunks.append(chunk)
|
134
|
+
response = "".join(response_chunks)
|
135
|
+
usage = {} # Usage stats not available in streaming
|
136
|
+
else:
|
137
|
+
messages = self._build_messages(input_data)
|
138
|
+
response = self.client.messages.create(
|
139
|
+
model=input_data.model,
|
140
|
+
system=input_data.system_prompt,
|
141
|
+
messages=messages,
|
142
|
+
max_tokens=input_data.max_tokens,
|
143
|
+
temperature=input_data.temperature,
|
144
|
+
)
|
145
|
+
usage = response.usage.model_dump() if response.usage else {}
|
118
146
|
|
119
147
|
return AnthropicOutput(
|
120
148
|
response=response.content[0].text,
|
121
149
|
used_model=input_data.model,
|
122
|
-
usage=
|
150
|
+
usage=usage,
|
123
151
|
)
|
124
152
|
|
125
153
|
except Exception as e:
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import List, Optional, Dict, Any
|
1
|
+
from typing import List, Optional, Dict, Any, Generator
|
2
2
|
from pydantic import Field
|
3
3
|
from cerebras.cloud.sdk import Cerebras
|
4
4
|
from loguru import logger
|
@@ -27,6 +27,9 @@ class CerebrasInput(InputSchema):
|
|
27
27
|
temperature: float = Field(
|
28
28
|
default=0.7, description="Temperature for response generation", ge=0, le=1
|
29
29
|
)
|
30
|
+
stream: bool = Field(
|
31
|
+
default=False, description="Whether to stream the response progressively"
|
32
|
+
)
|
30
33
|
|
31
34
|
|
32
35
|
class CerebrasOutput(OutputSchema):
|
@@ -71,23 +74,52 @@ class CerebrasChatSkill(Skill[CerebrasInput, CerebrasOutput]):
|
|
71
74
|
|
72
75
|
return messages
|
73
76
|
|
74
|
-
def
|
77
|
+
def process_stream(self, input_data: CerebrasInput) -> Generator[str, None, None]:
|
78
|
+
"""Process the input and stream the response token by token."""
|
75
79
|
try:
|
76
|
-
# Build messages using the helper method
|
77
80
|
messages = self._build_messages(input_data)
|
78
81
|
|
79
|
-
|
80
|
-
response = self.client.chat.completions.create(
|
82
|
+
stream = self.client.chat.completions.create(
|
81
83
|
model=input_data.model,
|
82
84
|
messages=messages,
|
83
85
|
temperature=input_data.temperature,
|
84
86
|
max_tokens=input_data.max_tokens,
|
87
|
+
stream=True,
|
85
88
|
)
|
86
89
|
|
90
|
+
for chunk in stream:
|
91
|
+
if chunk.choices[0].delta.content is not None:
|
92
|
+
yield chunk.choices[0].delta.content
|
93
|
+
|
94
|
+
except Exception as e:
|
95
|
+
logger.exception(f"Cerebras streaming failed: {str(e)}")
|
96
|
+
raise ProcessingError(f"Cerebras streaming failed: {str(e)}")
|
97
|
+
|
98
|
+
def process(self, input_data: CerebrasInput) -> CerebrasOutput:
|
99
|
+
"""Process the input and return the complete response."""
|
100
|
+
try:
|
101
|
+
if input_data.stream:
|
102
|
+
response_chunks = []
|
103
|
+
for chunk in self.process_stream(input_data):
|
104
|
+
response_chunks.append(chunk)
|
105
|
+
response = "".join(response_chunks)
|
106
|
+
usage = {} # Usage stats not available in streaming
|
107
|
+
else:
|
108
|
+
messages = self._build_messages(input_data)
|
109
|
+
response = self.client.chat.completions.create(
|
110
|
+
model=input_data.model,
|
111
|
+
messages=messages,
|
112
|
+
temperature=input_data.temperature,
|
113
|
+
max_tokens=input_data.max_tokens,
|
114
|
+
)
|
115
|
+
usage = (
|
116
|
+
response.usage.model_dump() if hasattr(response, "usage") else {}
|
117
|
+
)
|
118
|
+
|
87
119
|
return CerebrasOutput(
|
88
120
|
response=response.choices[0].message.content,
|
89
121
|
used_model=input_data.model,
|
90
|
-
usage=
|
122
|
+
usage=usage,
|
91
123
|
)
|
92
124
|
|
93
125
|
except Exception as e:
|
@@ -1,8 +1,9 @@
|
|
1
|
-
from typing import List, Optional, Dict, Any, Generator
|
1
|
+
from typing import List, Optional, Dict, Any, Generator, AsyncGenerator
|
2
2
|
from pydantic import Field
|
3
3
|
import requests
|
4
4
|
import json
|
5
5
|
from loguru import logger
|
6
|
+
import aiohttp
|
6
7
|
|
7
8
|
from airtrain.core.skills import Skill, ProcessingError
|
8
9
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
@@ -69,6 +70,11 @@ class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]
|
|
69
70
|
"Content-Type": "application/json",
|
70
71
|
"Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
|
71
72
|
}
|
73
|
+
self.stream_headers = {
|
74
|
+
"Accept": "text/event-stream",
|
75
|
+
"Content-Type": "application/json",
|
76
|
+
"Authorization": f"Bearer {self.credentials.fireworks_api_key.get_secret_value()}",
|
77
|
+
}
|
72
78
|
|
73
79
|
def _build_messages(
|
74
80
|
self, input_data: FireworksRequestInput
|
@@ -150,3 +156,52 @@ class FireworksRequestSkill(Skill[FireworksRequestInput, FireworksRequestOutput]
|
|
150
156
|
|
151
157
|
except Exception as e:
|
152
158
|
raise ProcessingError(f"Fireworks request failed: {str(e)}")
|
159
|
+
|
160
|
+
async def process_async(
|
161
|
+
self, input_data: FireworksRequestInput
|
162
|
+
) -> FireworksRequestOutput:
|
163
|
+
"""Async version of process method using aiohttp"""
|
164
|
+
try:
|
165
|
+
async with aiohttp.ClientSession() as session:
|
166
|
+
payload = self._build_payload(input_data)
|
167
|
+
async with session.post(
|
168
|
+
self.BASE_URL, headers=self.headers, json=payload
|
169
|
+
) as response:
|
170
|
+
response.raise_for_status()
|
171
|
+
data = await response.json()
|
172
|
+
|
173
|
+
return FireworksRequestOutput(
|
174
|
+
response=data["choices"][0]["message"]["content"],
|
175
|
+
used_model=input_data.model,
|
176
|
+
usage=data.get("usage", {}),
|
177
|
+
)
|
178
|
+
|
179
|
+
except Exception as e:
|
180
|
+
raise ProcessingError(f"Async Fireworks request failed: {str(e)}")
|
181
|
+
|
182
|
+
async def process_stream_async(
|
183
|
+
self, input_data: FireworksRequestInput
|
184
|
+
) -> AsyncGenerator[str, None]:
|
185
|
+
"""Async version of stream processor using aiohttp"""
|
186
|
+
try:
|
187
|
+
async with aiohttp.ClientSession() as session:
|
188
|
+
payload = self._build_payload(input_data)
|
189
|
+
async with session.post(
|
190
|
+
self.BASE_URL, headers=self.stream_headers, json=payload
|
191
|
+
) as response:
|
192
|
+
response.raise_for_status()
|
193
|
+
|
194
|
+
async for line in response.content:
|
195
|
+
if line.startswith(b"data: "):
|
196
|
+
chunk = json.loads(line[6:].strip())
|
197
|
+
if "choices" in chunk:
|
198
|
+
content = (
|
199
|
+
chunk["choices"][0]
|
200
|
+
.get("delta", {})
|
201
|
+
.get("content", "")
|
202
|
+
)
|
203
|
+
if content:
|
204
|
+
yield content
|
205
|
+
|
206
|
+
except Exception as e:
|
207
|
+
raise ProcessingError(f"Async Fireworks streaming failed: {str(e)}")
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Optional, Dict, Any, List
|
1
|
+
from typing import Generator, Optional, Dict, Any, List
|
2
2
|
from pydantic import Field
|
3
3
|
from airtrain.core.skills import Skill, ProcessingError
|
4
4
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
@@ -18,11 +18,16 @@ class GroqInput(InputSchema):
|
|
18
18
|
default_factory=list,
|
19
19
|
description="List of previous conversation messages in [{'role': 'user|assistant', 'content': 'message'}] format",
|
20
20
|
)
|
21
|
-
model: str = Field(
|
21
|
+
model: str = Field(
|
22
|
+
default="deepseek-r1-distill-llama-70b-specdec", description="Groq model to use"
|
23
|
+
)
|
22
24
|
max_tokens: int = Field(default=1024, description="Maximum tokens in response")
|
23
25
|
temperature: float = Field(
|
24
26
|
default=0.7, description="Temperature for response generation", ge=0, le=1
|
25
27
|
)
|
28
|
+
stream: bool = Field(
|
29
|
+
default=False, description="Whether to stream the response progressively"
|
30
|
+
)
|
26
31
|
|
27
32
|
|
28
33
|
class GroqOutput(OutputSchema):
|
@@ -30,7 +35,9 @@ class GroqOutput(OutputSchema):
|
|
30
35
|
|
31
36
|
response: str = Field(..., description="Model's response text")
|
32
37
|
used_model: str = Field(..., description="Model used for generation")
|
33
|
-
usage: Dict[str, Any] = Field(
|
38
|
+
usage: Dict[str, Any] = Field(
|
39
|
+
default_factory=dict, description="Usage statistics from the API"
|
40
|
+
)
|
34
41
|
|
35
42
|
|
36
43
|
class GroqChatSkill(Skill[GroqInput, GroqOutput]):
|
@@ -65,23 +72,53 @@ class GroqChatSkill(Skill[GroqInput, GroqOutput]):
|
|
65
72
|
|
66
73
|
return messages
|
67
74
|
|
68
|
-
def
|
75
|
+
def process_stream(self, input_data: GroqInput) -> Generator[str, None, None]:
|
76
|
+
"""Process the input and stream the response token by token."""
|
69
77
|
try:
|
70
|
-
# Build messages using the helper method
|
71
78
|
messages = self._build_messages(input_data)
|
72
79
|
|
73
|
-
|
74
|
-
response = self.client.chat.completions.create(
|
80
|
+
stream = self.client.chat.completions.create(
|
75
81
|
model=input_data.model,
|
76
82
|
messages=messages,
|
77
83
|
temperature=input_data.temperature,
|
78
84
|
max_tokens=input_data.max_tokens,
|
85
|
+
stream=True,
|
79
86
|
)
|
80
87
|
|
88
|
+
for chunk in stream:
|
89
|
+
if chunk.choices[0].delta.content is not None:
|
90
|
+
yield chunk.choices[0].delta.content
|
91
|
+
|
92
|
+
except Exception as e:
|
93
|
+
raise ProcessingError(f"Groq streaming failed: {str(e)}")
|
94
|
+
|
95
|
+
def process(self, input_data: GroqInput) -> GroqOutput:
|
96
|
+
"""Process the input and return the complete response."""
|
97
|
+
try:
|
98
|
+
if input_data.stream:
|
99
|
+
response_chunks = []
|
100
|
+
for chunk in self.process_stream(input_data):
|
101
|
+
response_chunks.append(chunk)
|
102
|
+
response = "".join(response_chunks)
|
103
|
+
usage = {} # Usage stats not available in streaming
|
104
|
+
else:
|
105
|
+
messages = self._build_messages(input_data)
|
106
|
+
completion = self.client.chat.completions.create(
|
107
|
+
model=input_data.model,
|
108
|
+
messages=messages,
|
109
|
+
temperature=input_data.temperature,
|
110
|
+
max_tokens=input_data.max_tokens,
|
111
|
+
stream=False,
|
112
|
+
)
|
113
|
+
response = completion.choices[0].message.content
|
114
|
+
usage = {
|
115
|
+
"total_tokens": completion.usage.total_tokens,
|
116
|
+
"prompt_tokens": completion.usage.prompt_tokens,
|
117
|
+
"completion_tokens": completion.usage.completion_tokens,
|
118
|
+
}
|
119
|
+
|
81
120
|
return GroqOutput(
|
82
|
-
response=response.
|
83
|
-
used_model=input_data.model,
|
84
|
-
usage=response.usage.model_dump(),
|
121
|
+
response=response, used_model=input_data.model, usage=usage
|
85
122
|
)
|
86
123
|
|
87
124
|
except Exception as e:
|
@@ -1,9 +1,6 @@
|
|
1
|
-
from typing import List, Optional, Dict,
|
1
|
+
from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator
|
2
2
|
from pydantic import Field, BaseModel
|
3
|
-
from openai import OpenAI
|
4
|
-
import base64
|
5
|
-
from pathlib import Path
|
6
|
-
from loguru import logger
|
3
|
+
from openai import OpenAI, AsyncOpenAI
|
7
4
|
from openai.types.chat import ChatCompletionChunk
|
8
5
|
|
9
6
|
from airtrain.core.skills import Skill, ProcessingError
|
@@ -48,7 +45,7 @@ class OpenAIOutput(OutputSchema):
|
|
48
45
|
|
49
46
|
|
50
47
|
class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
|
51
|
-
"""Skill for interacting with OpenAI models"""
|
48
|
+
"""Skill for interacting with OpenAI models with async support"""
|
52
49
|
|
53
50
|
input_schema = OpenAIInput
|
54
51
|
output_schema = OpenAIOutput
|
@@ -61,6 +58,10 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
|
|
61
58
|
api_key=self.credentials.openai_api_key.get_secret_value(),
|
62
59
|
organization=self.credentials.openai_organization_id,
|
63
60
|
)
|
61
|
+
self.async_client = AsyncOpenAI(
|
62
|
+
api_key=self.credentials.openai_api_key.get_secret_value(),
|
63
|
+
organization=self.credentials.openai_organization_id,
|
64
|
+
)
|
64
65
|
|
65
66
|
def _build_messages(self, input_data: OpenAIInput) -> List[Dict[str, str]]:
|
66
67
|
"""Build messages list from input data including conversation history."""
|
@@ -126,6 +127,47 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
|
|
126
127
|
except Exception as e:
|
127
128
|
raise ProcessingError(f"OpenAI chat failed: {str(e)}")
|
128
129
|
|
130
|
+
async def process_async(self, input_data: OpenAIInput) -> OpenAIOutput:
|
131
|
+
"""Async version of process method"""
|
132
|
+
try:
|
133
|
+
messages = self._build_messages(input_data)
|
134
|
+
completion = await self.async_client.chat.completions.create(
|
135
|
+
model=input_data.model,
|
136
|
+
messages=messages,
|
137
|
+
temperature=input_data.temperature,
|
138
|
+
max_tokens=input_data.max_tokens,
|
139
|
+
)
|
140
|
+
return OpenAIOutput(
|
141
|
+
response=completion.choices[0].message.content,
|
142
|
+
used_model=completion.model,
|
143
|
+
usage={
|
144
|
+
"total_tokens": completion.usage.total_tokens,
|
145
|
+
"prompt_tokens": completion.usage.prompt_tokens,
|
146
|
+
"completion_tokens": completion.usage.completion_tokens,
|
147
|
+
},
|
148
|
+
)
|
149
|
+
except Exception as e:
|
150
|
+
raise ProcessingError(f"OpenAI async chat failed: {str(e)}")
|
151
|
+
|
152
|
+
async def process_stream_async(
|
153
|
+
self, input_data: OpenAIInput
|
154
|
+
) -> AsyncGenerator[str, None]:
|
155
|
+
"""Async version of stream processor"""
|
156
|
+
try:
|
157
|
+
messages = self._build_messages(input_data)
|
158
|
+
stream = await self.async_client.chat.completions.create(
|
159
|
+
model=input_data.model,
|
160
|
+
messages=messages,
|
161
|
+
temperature=input_data.temperature,
|
162
|
+
max_tokens=input_data.max_tokens,
|
163
|
+
stream=True,
|
164
|
+
)
|
165
|
+
async for chunk in stream:
|
166
|
+
if chunk.choices[0].delta.content is not None:
|
167
|
+
yield chunk.choices[0].delta.content
|
168
|
+
except Exception as e:
|
169
|
+
raise ProcessingError(f"OpenAI async streaming failed: {str(e)}")
|
170
|
+
|
129
171
|
|
130
172
|
ResponseT = TypeVar("ResponseT", bound=BaseModel)
|
131
173
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Optional, Dict, Any, List
|
1
|
+
from typing import Optional, Dict, Any, List, Generator
|
2
2
|
from pydantic import Field
|
3
3
|
from airtrain.core.skills import Skill, ProcessingError
|
4
4
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
@@ -28,6 +28,9 @@ class SambanovaInput(InputSchema):
|
|
28
28
|
top_p: float = Field(
|
29
29
|
default=0.1, description="Top p sampling parameter", ge=0, le=1
|
30
30
|
)
|
31
|
+
stream: bool = Field(
|
32
|
+
default=False, description="Whether to stream the response progressively"
|
33
|
+
)
|
31
34
|
|
32
35
|
|
33
36
|
class SambanovaOutput(OutputSchema):
|
@@ -73,24 +76,53 @@ class SambanovaChatSkill(Skill[SambanovaInput, SambanovaOutput]):
|
|
73
76
|
|
74
77
|
return messages
|
75
78
|
|
76
|
-
def
|
79
|
+
def process_stream(self, input_data: SambanovaInput) -> Generator[str, None, None]:
|
80
|
+
"""Process the input and stream the response token by token."""
|
77
81
|
try:
|
78
|
-
# Build messages using the helper method
|
79
82
|
messages = self._build_messages(input_data)
|
80
83
|
|
81
|
-
|
82
|
-
response = self.client.chat.completions.create(
|
84
|
+
stream = self.client.chat.completions.create(
|
83
85
|
model=input_data.model,
|
84
86
|
messages=messages,
|
85
87
|
temperature=input_data.temperature,
|
86
88
|
max_tokens=input_data.max_tokens,
|
87
89
|
top_p=input_data.top_p,
|
90
|
+
stream=True,
|
88
91
|
)
|
89
92
|
|
93
|
+
for chunk in stream:
|
94
|
+
if chunk.choices[0].delta.content is not None:
|
95
|
+
yield chunk.choices[0].delta.content
|
96
|
+
|
97
|
+
except Exception as e:
|
98
|
+
raise ProcessingError(f"Sambanova streaming failed: {str(e)}")
|
99
|
+
|
100
|
+
def process(self, input_data: SambanovaInput) -> SambanovaOutput:
|
101
|
+
"""Process the input and return the complete response."""
|
102
|
+
try:
|
103
|
+
if input_data.stream:
|
104
|
+
response_chunks = []
|
105
|
+
for chunk in self.process_stream(input_data):
|
106
|
+
response_chunks.append(chunk)
|
107
|
+
response = "".join(response_chunks)
|
108
|
+
usage = {} # Usage stats not available in streaming
|
109
|
+
else:
|
110
|
+
messages = self._build_messages(input_data)
|
111
|
+
response = self.client.chat.completions.create(
|
112
|
+
model=input_data.model,
|
113
|
+
messages=messages,
|
114
|
+
temperature=input_data.temperature,
|
115
|
+
max_tokens=input_data.max_tokens,
|
116
|
+
top_p=input_data.top_p,
|
117
|
+
)
|
118
|
+
usage = (
|
119
|
+
response.usage.model_dump() if hasattr(response, "usage") else {}
|
120
|
+
)
|
121
|
+
|
90
122
|
return SambanovaOutput(
|
91
123
|
response=response.choices[0].message.content,
|
92
124
|
used_model=input_data.model,
|
93
|
-
usage=
|
125
|
+
usage=usage,
|
94
126
|
)
|
95
127
|
|
96
128
|
except Exception as e:
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Optional, Dict, Any, List
|
1
|
+
from typing import Optional, Dict, Any, List, Generator
|
2
2
|
from pydantic import Field
|
3
3
|
from airtrain.core.skills import Skill, ProcessingError
|
4
4
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
@@ -29,6 +29,7 @@ class TogetherAIInput(InputSchema):
|
|
29
29
|
temperature: float = Field(
|
30
30
|
default=0.7, description="Temperature for response generation", ge=0, le=1
|
31
31
|
)
|
32
|
+
stream: bool = Field(default=False, description="Whether to stream the response")
|
32
33
|
|
33
34
|
|
34
35
|
class TogetherAIOutput(OutputSchema):
|
@@ -73,23 +74,53 @@ class TogetherAIChatSkill(Skill[TogetherAIInput, TogetherAIOutput]):
|
|
73
74
|
|
74
75
|
return messages
|
75
76
|
|
76
|
-
def
|
77
|
+
def process_stream(self, input_data: TogetherAIInput) -> Generator[str, None, None]:
|
78
|
+
"""Process the input and stream the response token by token."""
|
77
79
|
try:
|
78
|
-
# Build messages using the helper method
|
79
80
|
messages = self._build_messages(input_data)
|
80
81
|
|
81
|
-
|
82
|
-
response = self.client.chat.completions.create(
|
82
|
+
stream = self.client.chat.completions.create(
|
83
83
|
model=input_data.model,
|
84
84
|
messages=messages,
|
85
|
-
max_tokens=input_data.max_tokens,
|
86
85
|
temperature=input_data.temperature,
|
86
|
+
max_tokens=input_data.max_tokens,
|
87
|
+
stream=True,
|
87
88
|
)
|
88
89
|
|
90
|
+
for chunk in stream:
|
91
|
+
if chunk.choices[0].delta.content is not None:
|
92
|
+
yield chunk.choices[0].delta.content
|
93
|
+
|
94
|
+
except Exception as e:
|
95
|
+
raise ProcessingError(f"Together AI streaming failed: {str(e)}")
|
96
|
+
|
97
|
+
def process(self, input_data: TogetherAIInput) -> TogetherAIOutput:
|
98
|
+
"""Process the input and return the complete response."""
|
99
|
+
try:
|
100
|
+
if input_data.stream:
|
101
|
+
response_chunks = []
|
102
|
+
for chunk in self.process_stream(input_data):
|
103
|
+
response_chunks.append(chunk)
|
104
|
+
response = "".join(response_chunks)
|
105
|
+
else:
|
106
|
+
messages = self._build_messages(input_data)
|
107
|
+
completion = self.client.chat.completions.create(
|
108
|
+
model=input_data.model,
|
109
|
+
messages=messages,
|
110
|
+
temperature=input_data.temperature,
|
111
|
+
max_tokens=input_data.max_tokens,
|
112
|
+
stream=False,
|
113
|
+
)
|
114
|
+
response = completion.choices[0].message.content
|
115
|
+
|
89
116
|
return TogetherAIOutput(
|
90
|
-
response=response
|
117
|
+
response=response,
|
91
118
|
used_model=input_data.model,
|
92
|
-
usage=
|
119
|
+
usage=(
|
120
|
+
completion.usage.model_dump()
|
121
|
+
if hasattr(completion, "usage")
|
122
|
+
else {}
|
123
|
+
),
|
93
124
|
)
|
94
125
|
|
95
126
|
except Exception as e:
|
@@ -1,10 +1,12 @@
|
|
1
1
|
Metadata-Version: 2.2
|
2
2
|
Name: airtrain
|
3
|
-
Version: 0.1.
|
3
|
+
Version: 0.1.31
|
4
4
|
Summary: A platform for building and deploying AI agents with structured skills
|
5
5
|
Home-page: https://github.com/rosaboyle/airtrain.dev
|
6
6
|
Author: Dheeraj Pai
|
7
|
-
Author-email: helloworldcmu@gmail.com
|
7
|
+
Author-email: Dheeraj Pai <helloworldcmu@gmail.com>
|
8
|
+
Project-URL: Homepage, https://github.com/rosaboyle/airtrain.dev
|
9
|
+
Project-URL: Documentation, https://docs.airtrain.dev/
|
8
10
|
Classifier: Development Status :: 3 - Alpha
|
9
11
|
Classifier: Intended Audience :: Developers
|
10
12
|
Classifier: License :: OSI Approved :: MIT License
|
@@ -26,15 +28,29 @@ Requires-Dist: boto3>=1.36.6
|
|
26
28
|
Requires-Dist: together>=1.3.13
|
27
29
|
Requires-Dist: anthropic>=0.45.0
|
28
30
|
Requires-Dist: groq>=0.15.0
|
31
|
+
Requires-Dist: cerebras-cloud-sdk>=1.19.0
|
32
|
+
Requires-Dist: google-genai>=1.0.0
|
33
|
+
Requires-Dist: fireworks-ai>=0.15.12
|
34
|
+
Requires-Dist: google-generativeai>=0.8.4
|
35
|
+
Requires-Dist: click>=8.0.0
|
36
|
+
Requires-Dist: rich>=13.3.1
|
37
|
+
Requires-Dist: prompt-toolkit>=3.0.36
|
38
|
+
Requires-Dist: colorama>=0.4.6
|
39
|
+
Requires-Dist: typer>=0.9.0
|
40
|
+
Provides-Extra: dev
|
41
|
+
Requires-Dist: black>=24.10.0; extra == "dev"
|
42
|
+
Requires-Dist: flake8>=7.1.1; extra == "dev"
|
43
|
+
Requires-Dist: isort>=5.13.0; extra == "dev"
|
44
|
+
Requires-Dist: mypy>=1.9.0; extra == "dev"
|
45
|
+
Requires-Dist: pytest>=7.0.0; extra == "dev"
|
46
|
+
Requires-Dist: twine>=4.0.0; extra == "dev"
|
47
|
+
Requires-Dist: build>=0.10.0; extra == "dev"
|
48
|
+
Requires-Dist: types-PyYAML>=6.0; extra == "dev"
|
49
|
+
Requires-Dist: types-requests>=2.31.0; extra == "dev"
|
50
|
+
Requires-Dist: types-Markdown>=3.5.0; extra == "dev"
|
29
51
|
Dynamic: author
|
30
|
-
Dynamic: author-email
|
31
|
-
Dynamic: classifier
|
32
|
-
Dynamic: description
|
33
|
-
Dynamic: description-content-type
|
34
52
|
Dynamic: home-page
|
35
|
-
Dynamic: requires-dist
|
36
53
|
Dynamic: requires-python
|
37
|
-
Dynamic: summary
|
38
54
|
|
39
55
|
# Airtrain
|
40
56
|
|
@@ -167,31 +183,3 @@ Contributions are welcome! Please feel free to submit a Pull Request.
|
|
167
183
|
## License
|
168
184
|
|
169
185
|
This project is licensed under the MIT License - see the LICENSE file for details.
|
170
|
-
|
171
|
-
## Changelog
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
## 0.1.28
|
176
|
-
|
177
|
-
- Bug fix: reasoning to Fireworks structured output.
|
178
|
-
- Added reasoning to Fireworks structured output.
|
179
|
-
|
180
|
-
## 0.1.27
|
181
|
-
|
182
|
-
- Added structured completion skills for Fireworks AI
|
183
|
-
- Added Completion skills for Fireworks AI.
|
184
|
-
- Added Combination skill for Groq and Fireworks AI.
|
185
|
-
- Add completion streaming.
|
186
|
-
- Added strcutured output streaming for Fireworks AI.
|
187
|
-
|
188
|
-
## 0.1.23
|
189
|
-
|
190
|
-
- Added conversation support for Deepseek, Togehter AI, Fireworks AI, Gemini, Groq, Cerebras and Sambanova.
|
191
|
-
- Added Change Log
|
192
|
-
|
193
|
-
|
194
|
-
## Notes
|
195
|
-
|
196
|
-
The changelog format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
197
|
-
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
@@ -1,4 +1,6 @@
|
|
1
|
-
airtrain/__init__.py,sha256=
|
1
|
+
airtrain/__init__.py,sha256=pF1FDnJhZuGJNN6mpZWRCaLVjfwV1uXcp5e8m55u-e0,2099
|
2
|
+
airtrain/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
airtrain/cli/main.py,sha256=5tOBEEJILcIjIBB8JL3djgKpVqImaniJBuJyh0DG0Hg,1297
|
2
4
|
airtrain/contrib/__init__.py,sha256=pG-7mJ0pBMqp3Q86mIF9bo1PqoBOVSGlnEK1yY1U1ok,641
|
3
5
|
airtrain/contrib/travel/__init__.py,sha256=clmBodw4nkTA-DsgjVGcXfJGPaWxIpCZDtdO-8RzL0M,811
|
4
6
|
airtrain/contrib/travel/agents.py,sha256=tpQtZ0WUiXBuhvZtc2JlEam5TuR5l-Tndi14YyImDBM,8975
|
@@ -10,19 +12,19 @@ airtrain/core/skills.py,sha256=LljalzeSHK5eQPTAOEAYc5D8Qn1kVSfiz9WgziTD5UM,4688
|
|
10
12
|
airtrain/integrations/__init__.py,sha256=-3Vz2bqAUNvVHEZxFGUv5BfzJZsO_7MRyLifbHsweEE,1488
|
11
13
|
airtrain/integrations/anthropic/__init__.py,sha256=qwlWLDh1rEVizYFbW8430z-f1SxHio7_Gaw5cCTUtoo,274
|
12
14
|
airtrain/integrations/anthropic/credentials.py,sha256=hlTSw9HX66kYNaeQUtn0JjdZQBMNkzzFOJOoLOOzvcY,1246
|
13
|
-
airtrain/integrations/anthropic/skills.py,sha256=
|
15
|
+
airtrain/integrations/anthropic/skills.py,sha256=WV-9254H2VqUAq_7Zr1xG5IhejeC_gQSqyH0hwW1_tY,5870
|
14
16
|
airtrain/integrations/aws/__init__.py,sha256=3x7v2NxpAfI-U-YgwQeH5PtsmUrNLPMfLyUGFLiBjbs,155
|
15
17
|
airtrain/integrations/aws/credentials.py,sha256=nN-daKAl7qOb_VdRpsThG8gN5GeSUkx-ji5E_gF_vYw,1444
|
16
18
|
airtrain/integrations/aws/skills.py,sha256=TQiMXeXRRcJ14fe8Xi7Uk20iS6_INbcznuLGtMorcKY,3870
|
17
19
|
airtrain/integrations/cerebras/__init__.py,sha256=zAD-qV38OzHhMCz1z-NvjjqcYEhURbm8RWTOKHNqbew,174
|
18
20
|
airtrain/integrations/cerebras/credentials.py,sha256=KDEH4r8FGT68L9p34MLZWK65wq_a703pqIF3ODaSbts,694
|
19
|
-
airtrain/integrations/cerebras/skills.py,sha256=
|
21
|
+
airtrain/integrations/cerebras/skills.py,sha256=hmqcnF-nkFk5YJVf8f-TiKBfb8kYCfnC30W67VZ7CKU,4922
|
20
22
|
airtrain/integrations/fireworks/__init__.py,sha256=9pJvP0u1FJbNtB0oHa09mHVJLctELf_c27LOYyDk2ZI,271
|
21
23
|
airtrain/integrations/fireworks/completion_skills.py,sha256=G657xWd7izLOxXq7RmqdupBF4DHqXQgXuhQ-MW7mtqc,5613
|
22
24
|
airtrain/integrations/fireworks/conversation_manager.py,sha256=m6VEHijqpYEYawkKhuHtb8RQxw4kxGWFWdbSK6zGuro,3704
|
23
25
|
airtrain/integrations/fireworks/credentials.py,sha256=UpcwR9V5Hbk5sJbjFDJDbHMRqc90IQSqAvrtJCOvwEo,524
|
24
26
|
airtrain/integrations/fireworks/models.py,sha256=F-MddbLCLAsTjwRr1l6IpJxOegyY4pD7jN9ySPiypSo,593
|
25
|
-
airtrain/integrations/fireworks/requests_skills.py,sha256=
|
27
|
+
airtrain/integrations/fireworks/requests_skills.py,sha256=c84Vy_4EcBrwJfp3jqizzlcja_LsEtvWh59qiaIjukg,8233
|
26
28
|
airtrain/integrations/fireworks/skills.py,sha256=OB4epD4CSTxExUCC1oMJ_8rHLOoftlxf0FUoIVrd4mA,5163
|
27
29
|
airtrain/integrations/fireworks/structured_completion_skills.py,sha256=IXG4gsZDSfuscrmKIHfnyHkBaCV7zlPInaWXb95iC5k,6428
|
28
30
|
airtrain/integrations/fireworks/structured_requests_skills.py,sha256=oRpbKMOcKgY2js16uNkIx866UEwEYSNgFPNYn9cLO3U,8409
|
@@ -32,7 +34,7 @@ airtrain/integrations/google/credentials.py,sha256=KSvWNqW8Mjr4MkysRvUqlrOSGdShN
|
|
32
34
|
airtrain/integrations/google/skills.py,sha256=ytsoksCY4qbfRO9Brnxhc2694fAj0ytnHX20SXS_FOM,4547
|
33
35
|
airtrain/integrations/groq/__init__.py,sha256=B_X2fXbsJfFD6GquKeVCsEJjwd9Ygbq1uEHlV4Jy7YE,154
|
34
36
|
airtrain/integrations/groq/credentials.py,sha256=bdTHykcIeaQ7td8KZlQBPfEFAkvJuxk2f_cbTLPD_I4,844
|
35
|
-
airtrain/integrations/groq/skills.py,sha256=
|
37
|
+
airtrain/integrations/groq/skills.py,sha256=qFyxC_2xZYnByAPo5p2aHbrqhdHYCoIdvDRAauSfnjk,4821
|
36
38
|
airtrain/integrations/ollama/__init__.py,sha256=zMHBsGzViVrvxAeJmfq6r-ZfSE6Dy5QcKLhe4d5fEcM,164
|
37
39
|
airtrain/integrations/ollama/credentials.py,sha256=D7O4kUAb_VHs5s1ncUN9Ezhu5PvLfgj3RifAkB9sEZk,940
|
38
40
|
airtrain/integrations/ollama/skills.py,sha256=M_Un8D5VJ5XtPEq9IClzqV3jCPBoFTSm2ve6EO8W2JU,1556
|
@@ -40,10 +42,10 @@ airtrain/integrations/openai/__init__.py,sha256=K-NY2_T1T6SEOgkpbUA55cWvK2nr2NOJ
|
|
40
42
|
airtrain/integrations/openai/chinese_assistant.py,sha256=MMhv4NBOoEQ0O22ZZtP255rd5ajHC9l6FPWIjpqxBOA,1581
|
41
43
|
airtrain/integrations/openai/credentials.py,sha256=NfRyp1QgEtgm8cxt2-BOLq-6d0X-Pcm80NnfHM8p0FY,1470
|
42
44
|
airtrain/integrations/openai/models_config.py,sha256=bzosqqpDy2AJxu2vGdk2H4voqEGlv7LORR6fpJLhNic,3962
|
43
|
-
airtrain/integrations/openai/skills.py,sha256=
|
45
|
+
airtrain/integrations/openai/skills.py,sha256=gikb9RBH1ggSSUwDDE7t6cg3LZrCAPooXg04MgqAJ-0,8862
|
44
46
|
airtrain/integrations/sambanova/__init__.py,sha256=dp_263iOckM_J9pOEvyqpf3FrejD6-_x33r0edMCTe0,179
|
45
47
|
airtrain/integrations/sambanova/credentials.py,sha256=JyN8sbMCoXuXAjim46aI3LTicBijoemS7Ao0rn4yBJU,824
|
46
|
-
airtrain/integrations/sambanova/skills.py,sha256=
|
48
|
+
airtrain/integrations/sambanova/skills.py,sha256=SZ_GAimMiOCILiNkzyhNflyRR6bdC5r0Tnog19K8geU,4997
|
47
49
|
airtrain/integrations/together/__init__.py,sha256=we4KXn_pUs6Dxo3QcB-t40BSRraQFdKg2nXw7yi2FjM,185
|
48
50
|
airtrain/integrations/together/audio_models_config.py,sha256=GtqfmKR1vJ5x4B3kScvEO3x4exvzwNP78vcGVTk_fBE,1004
|
49
51
|
airtrain/integrations/together/credentials.py,sha256=cYNhyIwgsxm8LfiFfT-omBvgV3mUP6SZeRSukyzzDlI,747
|
@@ -55,9 +57,10 @@ airtrain/integrations/together/models_config.py,sha256=XMKp0Oq1nWWnMMdNAZxkFXmJa
|
|
55
57
|
airtrain/integrations/together/rerank_models_config.py,sha256=coCg0IOG2tU4L2uc2uPtPdoBwGjSc_zQwxENwdDuwHE,1188
|
56
58
|
airtrain/integrations/together/rerank_skill.py,sha256=gjH24hLWCweWKPyyfKZMG3K_g9gWzm80WgiJNjkA9eg,1894
|
57
59
|
airtrain/integrations/together/schemas.py,sha256=pBMrbX67oxPCr-sg4K8_Xqu1DWbaC4uLCloVSascROg,1210
|
58
|
-
airtrain/integrations/together/skills.py,sha256=
|
60
|
+
airtrain/integrations/together/skills.py,sha256=8DwkexMJu1Gm6QmNDfNasYStQ31QsXBbFP99zR-YCf0,7598
|
59
61
|
airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1FtynIa2ZVcbl37kltfoHnK8zxs,1544
|
60
|
-
airtrain-0.1.
|
61
|
-
airtrain-0.1.
|
62
|
-
airtrain-0.1.
|
63
|
-
airtrain-0.1.
|
62
|
+
airtrain-0.1.31.dist-info/METADATA,sha256=ijqjVYKg0ECRdU5v3k4kWQJsEjkd_oMjfJFCxaDoQEw,5331
|
63
|
+
airtrain-0.1.31.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
64
|
+
airtrain-0.1.31.dist-info/entry_points.txt,sha256=rrJ36IUsyq6n1dSfTWXqVAgpQLPRWDfCqwd6_3B-G0U,52
|
65
|
+
airtrain-0.1.31.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
66
|
+
airtrain-0.1.31.dist-info/RECORD,,
|
File without changes
|
File without changes
|