airtrain 0.1.67__py3-none-any.whl → 0.1.68__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airtrain/__init__.py +1 -1
- airtrain/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/__init__.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/schemas.cpython-313.pyc +0 -0
- airtrain/core/__pycache__/skills.cpython-313.pyc +0 -0
- airtrain/integrations/openai/skills.py +84 -20
- {airtrain-0.1.67.dist-info → airtrain-0.1.68.dist-info}/METADATA +1 -1
- {airtrain-0.1.67.dist-info → airtrain-0.1.68.dist-info}/RECORD +11 -11
- {airtrain-0.1.67.dist-info → airtrain-0.1.68.dist-info}/WHEEL +0 -0
- {airtrain-0.1.67.dist-info → airtrain-0.1.68.dist-info}/entry_points.txt +0 -0
- {airtrain-0.1.67.dist-info → airtrain-0.1.68.dist-info}/top_level.txt +0 -0
airtrain/__init__.py
CHANGED
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -16,11 +16,13 @@ import numpy as np
|
|
16
16
|
import json
|
17
17
|
from loguru import logger
|
18
18
|
|
19
|
-
from airtrain.core.skills import
|
19
|
+
from airtrain.core.skills import (
|
20
|
+
Skill,
|
21
|
+
ProcessingError,
|
22
|
+
InputValidationError as ValidationError,
|
23
|
+
)
|
20
24
|
from airtrain.core.schemas import InputSchema, OutputSchema
|
21
25
|
from .credentials import OpenAICredentials
|
22
|
-
from ...core.credentials import get_credentials
|
23
|
-
from ...core.errors import ValidationError
|
24
26
|
|
25
27
|
|
26
28
|
class OpenAIInput(InputSchema):
|
@@ -106,13 +108,23 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
|
|
106
108
|
self.validate_input(input_data)
|
107
109
|
messages = self._build_messages(input_data)
|
108
110
|
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
111
|
+
# Create the completion parameters for streaming
|
112
|
+
completion_params = {
|
113
|
+
"model": input_data.model,
|
114
|
+
"messages": messages,
|
115
|
+
"temperature": input_data.temperature,
|
116
|
+
"max_tokens": input_data.max_tokens,
|
117
|
+
"stream": True,
|
118
|
+
}
|
119
|
+
|
120
|
+
# Add tools if provided
|
121
|
+
if input_data.tools and len(input_data.tools) > 0:
|
122
|
+
logger.info(
|
123
|
+
f"Using {len(input_data.tools)} tools with OpenAI streaming"
|
124
|
+
)
|
125
|
+
completion_params["tools"] = input_data.tools
|
126
|
+
|
127
|
+
stream = self.client.chat.completions.create(**completion_params)
|
116
128
|
|
117
129
|
for chunk in stream:
|
118
130
|
if chunk.choices[0].delta.content is not None:
|
@@ -205,22 +217,56 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
|
|
205
217
|
try:
|
206
218
|
self.validate_input(input_data)
|
207
219
|
messages = self._build_messages(input_data)
|
220
|
+
|
221
|
+
# Create the completion parameters
|
222
|
+
completion_params = {
|
223
|
+
"model": input_data.model,
|
224
|
+
"messages": messages,
|
225
|
+
"temperature": input_data.temperature,
|
226
|
+
"max_tokens": input_data.max_tokens,
|
227
|
+
"stream": False,
|
228
|
+
}
|
229
|
+
|
230
|
+
# Add tools if provided
|
231
|
+
if input_data.tools and len(input_data.tools) > 0:
|
232
|
+
logger.info(f"Using {len(input_data.tools)} tools with OpenAI async")
|
233
|
+
completion_params["tools"] = input_data.tools
|
234
|
+
|
208
235
|
completion = await self.async_client.chat.completions.create(
|
209
|
-
|
210
|
-
messages=messages,
|
211
|
-
temperature=input_data.temperature,
|
212
|
-
max_tokens=input_data.max_tokens,
|
236
|
+
**completion_params
|
213
237
|
)
|
238
|
+
|
239
|
+
# Extract tool calls if present
|
240
|
+
tool_calls = None
|
241
|
+
if (
|
242
|
+
hasattr(completion.choices[0].message, "tool_calls")
|
243
|
+
and completion.choices[0].message.tool_calls
|
244
|
+
):
|
245
|
+
tool_calls = []
|
246
|
+
for tool_call in completion.choices[0].message.tool_calls:
|
247
|
+
formatted_tool_call = {
|
248
|
+
"id": tool_call.id,
|
249
|
+
"name": tool_call.function.name,
|
250
|
+
"arguments": tool_call.function.arguments,
|
251
|
+
}
|
252
|
+
tool_calls.append(formatted_tool_call)
|
253
|
+
logger.info(
|
254
|
+
f"Extracted {len(tool_calls)} tool calls from OpenAI async response"
|
255
|
+
)
|
256
|
+
|
214
257
|
return OpenAIOutput(
|
215
|
-
response=completion.choices[0].message.content,
|
258
|
+
response=completion.choices[0].message.content or "",
|
216
259
|
used_model=completion.model,
|
217
260
|
usage={
|
218
261
|
"total_tokens": completion.usage.total_tokens,
|
219
262
|
"prompt_tokens": completion.usage.prompt_tokens,
|
220
263
|
"completion_tokens": completion.usage.completion_tokens,
|
221
264
|
},
|
265
|
+
raw_response=completion,
|
266
|
+
tool_calls=tool_calls,
|
222
267
|
)
|
223
268
|
except Exception as e:
|
269
|
+
logger.error(f"OpenAI async chat failed: {str(e)}")
|
224
270
|
raise ProcessingError(f"OpenAI async chat failed: {str(e)}")
|
225
271
|
|
226
272
|
async def process_stream_async(
|
@@ -230,17 +276,32 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
|
|
230
276
|
try:
|
231
277
|
self.validate_input(input_data)
|
232
278
|
messages = self._build_messages(input_data)
|
279
|
+
|
280
|
+
# Create the completion parameters for streaming
|
281
|
+
completion_params = {
|
282
|
+
"model": input_data.model,
|
283
|
+
"messages": messages,
|
284
|
+
"temperature": input_data.temperature,
|
285
|
+
"max_tokens": input_data.max_tokens,
|
286
|
+
"stream": True,
|
287
|
+
}
|
288
|
+
|
289
|
+
# Add tools if provided
|
290
|
+
if input_data.tools and len(input_data.tools) > 0:
|
291
|
+
logger.info(
|
292
|
+
f"Using {len(input_data.tools)} tools with OpenAI async streaming"
|
293
|
+
)
|
294
|
+
completion_params["tools"] = input_data.tools
|
295
|
+
|
233
296
|
stream = await self.async_client.chat.completions.create(
|
234
|
-
|
235
|
-
messages=messages,
|
236
|
-
temperature=input_data.temperature,
|
237
|
-
max_tokens=input_data.max_tokens,
|
238
|
-
stream=True,
|
297
|
+
**completion_params
|
239
298
|
)
|
299
|
+
|
240
300
|
async for chunk in stream:
|
241
301
|
if chunk.choices[0].delta.content is not None:
|
242
302
|
yield chunk.choices[0].delta.content
|
243
303
|
except Exception as e:
|
304
|
+
logger.error(f"OpenAI async streaming failed: {str(e)}")
|
244
305
|
raise ProcessingError(f"OpenAI async streaming failed: {str(e)}")
|
245
306
|
|
246
307
|
|
@@ -306,6 +367,7 @@ class OpenAIParserSkill(Skill[OpenAIParserInput, OpenAIParserOutput]):
|
|
306
367
|
)
|
307
368
|
|
308
369
|
except Exception as e:
|
370
|
+
logger.error(f"OpenAI parsing failed: {str(e)}")
|
309
371
|
raise ProcessingError(f"OpenAI parsing failed: {str(e)}")
|
310
372
|
|
311
373
|
|
@@ -380,6 +442,7 @@ class OpenAIEmbeddingsSkill(Skill[OpenAIEmbeddingsInput, OpenAIEmbeddingsOutput]
|
|
380
442
|
tokens_used=response.usage.total_tokens,
|
381
443
|
)
|
382
444
|
except Exception as e:
|
445
|
+
logger.error(f"OpenAI embeddings generation failed: {str(e)}")
|
383
446
|
raise ProcessingError(f"OpenAI embeddings generation failed: {str(e)}")
|
384
447
|
|
385
448
|
async def process_async(
|
@@ -411,6 +474,7 @@ class OpenAIEmbeddingsSkill(Skill[OpenAIEmbeddingsInput, OpenAIEmbeddingsOutput]
|
|
411
474
|
tokens_used=response.usage.total_tokens,
|
412
475
|
)
|
413
476
|
except Exception as e:
|
477
|
+
logger.error(f"OpenAI async embeddings generation failed: {str(e)}")
|
414
478
|
raise ProcessingError(
|
415
479
|
f"OpenAI async embeddings generation failed: {str(e)}"
|
416
480
|
)
|
@@ -1,6 +1,6 @@
|
|
1
|
-
airtrain/__init__.py,sha256=
|
1
|
+
airtrain/__init__.py,sha256=4kP3qhcib7cRor8EFs8syt7KyP9FxFEvDHfnzIRpinw,3357
|
2
2
|
airtrain/__main__.py,sha256=EU8ffFmCdC1G-UcHHt0Oo3lB1PGqfC6kwzH39CnYSwU,72
|
3
|
-
airtrain/__pycache__/__init__.cpython-313.pyc,sha256=
|
3
|
+
airtrain/__pycache__/__init__.cpython-313.pyc,sha256=WFA-vgd8933joKp50Rgj4Yt04TM-gTvQS0hRR-VtEdw,2759
|
4
4
|
airtrain/agents/__init__.py,sha256=r6v5_bblxamRgiaCT8CVhyzaDdWohGM7sSjLgIUpA5s,795
|
5
5
|
airtrain/agents/example_agent.py,sha256=0dCS8QXIvUYYkxwyOEMLMdlQ4KgMAerQ56r7NcYGqTw,11681
|
6
6
|
airtrain/agents/groq_agent.py,sha256=s-f-cGSrgsR4Jlrv6xzeg2Z0LAEHByhFMYVbTvu-Bkg,10558
|
@@ -19,9 +19,9 @@ airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
|
|
19
19
|
airtrain/core/credentials.py,sha256=J1jd8vLrOfet0GhLI1J44d35o7skjriBsMgpODmXwfo,5592
|
20
20
|
airtrain/core/schemas.py,sha256=MMXrDviC4gRea_QaPpbjgO--B_UKxnD7YrxqZOLJZZU,7003
|
21
21
|
airtrain/core/skills.py,sha256=kIkI1MwIzuAwyoxdnZ5MGOb70BOB-njbVb_csJEdVvc,9244
|
22
|
-
airtrain/core/__pycache__/__init__.cpython-313.pyc,sha256
|
23
|
-
airtrain/core/__pycache__/schemas.cpython-313.pyc,sha256=
|
24
|
-
airtrain/core/__pycache__/skills.cpython-313.pyc,sha256=
|
22
|
+
airtrain/core/__pycache__/__init__.cpython-313.pyc,sha256=-xKNz5rm5DOdR4c9MtqdEkD8AQ5csdTsSDAQlNOT-N8,534
|
23
|
+
airtrain/core/__pycache__/schemas.cpython-313.pyc,sha256=R0VqNuR747I0OUwuw03U8MWpLMpC7N_4wI__v9vhUkM,8327
|
24
|
+
airtrain/core/__pycache__/skills.cpython-313.pyc,sha256=7K1QYJmfqtyXuBqRHOR_zQOV41d2WBprp3UznC1Yu4A,11839
|
25
25
|
airtrain/integrations/__init__.py,sha256=nGG3gZJ_sEaMzi_utsQCzExgnpHUjua8L9TyzofLQnw,2842
|
26
26
|
airtrain/integrations/anthropic/__init__.py,sha256=K741w3v7fWsCknTo38ARqDL0D3HPlwDIvDuuBao9Tto,800
|
27
27
|
airtrain/integrations/anthropic/credentials.py,sha256=hlTSw9HX66kYNaeQUtn0JjdZQBMNkzzFOJOoLOOzvcY,1246
|
@@ -63,7 +63,7 @@ airtrain/integrations/openai/chinese_assistant.py,sha256=F8bMeUUDly7BYG6wO648cAE
|
|
63
63
|
airtrain/integrations/openai/credentials.py,sha256=NfRyp1QgEtgm8cxt2-BOLq-6d0X-Pcm80NnfHM8p0FY,1470
|
64
64
|
airtrain/integrations/openai/list_models.py,sha256=vg8pZwLZ3F2Fx42X18WykpJOzZD9JG-2KJi49XWgSKo,4121
|
65
65
|
airtrain/integrations/openai/models_config.py,sha256=W9mu_z9tCC4ZUKHSJ6Hk4X09TRZLqEhT7TtRY5JEk5g,8007
|
66
|
-
airtrain/integrations/openai/skills.py,sha256=
|
66
|
+
airtrain/integrations/openai/skills.py,sha256=olLCkF29oKjWividPtS9JQ2L9S2nyrD6ziT3Ac1y0iE,18087
|
67
67
|
airtrain/integrations/perplexity/__init__.py,sha256=asVQs-oVXVhnLmAOZ6l9R8KoigeHmw8D5OONb_aGgnY,1189
|
68
68
|
airtrain/integrations/perplexity/credentials.py,sha256=5acl2DcF0dk7DQJWRwmFkBXPQ1zfKN4ARW31dP-4rpQ,1524
|
69
69
|
airtrain/integrations/perplexity/list_models.py,sha256=iOOxQ50BfLc226ps0gSdFl-enxa-SmStiuU2vFTUhJ0,4298
|
@@ -102,8 +102,8 @@ airtrain/tools/registry.py,sha256=K-1H5EipYcDNDx2jdpsEY9gjfV4aNCGI1pY2UsgSpC0,10
|
|
102
102
|
airtrain/tools/search.py,sha256=MJNi17g6aBPSqbF0ChV8ZgMlzz_PoKSPAIpe_dazdt8,15081
|
103
103
|
airtrain/tools/testing.py,sha256=q4ALEPRzukiadY6wFSPY7vA-T1o3XInLhXt18dsf6yY,4397
|
104
104
|
airtrain/tools/weather.py,sha256=cOP79XF2GOHD_TKnwW7OA5DzykixugB06CzCQLIyONQ,2787
|
105
|
-
airtrain-0.1.
|
106
|
-
airtrain-0.1.
|
107
|
-
airtrain-0.1.
|
108
|
-
airtrain-0.1.
|
109
|
-
airtrain-0.1.
|
105
|
+
airtrain-0.1.68.dist-info/METADATA,sha256=Ee7I0VzSRztZCPWcsh5nFcuAq9HSaGspIJBpVgsZzI4,6503
|
106
|
+
airtrain-0.1.68.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
107
|
+
airtrain-0.1.68.dist-info/entry_points.txt,sha256=rrJ36IUsyq6n1dSfTWXqVAgpQLPRWDfCqwd6_3B-G0U,52
|
108
|
+
airtrain-0.1.68.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
|
109
|
+
airtrain-0.1.68.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|