airtrain 0.1.62__py3-none-any.whl → 0.1.67__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
airtrain/__init__.py CHANGED
@@ -5,7 +5,7 @@ This library provides a flexible framework for building AI agents
5
5
  that can complete complex tasks using AI models, skills, and tools.
6
6
  """
7
7
 
8
- __version__ = "0.1.62"
8
+ __version__ = "0.1.67"
9
9
 
10
10
  import sys
11
11
 
@@ -161,6 +161,9 @@ class FireworksStructuredCompletionSkill(
161
161
  data = response.json()
162
162
 
163
163
  response_text = data["choices"][0]["text"]
164
+
165
+ response_text = response_text.split("</think>")[-1]
166
+
164
167
  parsed_response = input_data.response_model.model_validate_json(
165
168
  response_text
166
169
  )
@@ -1,12 +1,26 @@
1
- from typing import AsyncGenerator, List, Optional, Dict, TypeVar, Type, Generator, Union
1
+ from typing import (
2
+ AsyncGenerator,
3
+ List,
4
+ Optional,
5
+ Dict,
6
+ TypeVar,
7
+ Type,
8
+ Generator,
9
+ Union,
10
+ Any,
11
+ )
2
12
  from pydantic import Field, BaseModel
3
13
  from openai import OpenAI, AsyncOpenAI
4
14
  from openai.types.chat import ChatCompletionChunk
5
15
  import numpy as np
16
+ import json
17
+ from loguru import logger
6
18
 
7
19
  from airtrain.core.skills import Skill, ProcessingError
8
20
  from airtrain.core.schemas import InputSchema, OutputSchema
9
21
  from .credentials import OpenAICredentials
22
+ from ...core.credentials import get_credentials
23
+ from ...core.errors import ValidationError
10
24
 
11
25
 
12
26
  class OpenAIInput(InputSchema):
@@ -35,6 +49,7 @@ class OpenAIInput(InputSchema):
35
49
  default=False,
36
50
  description="Whether to stream the response token by token",
37
51
  )
52
+ tools: Optional[List[Dict[str, Any]]] = None
38
53
 
39
54
 
40
55
  class OpenAIOutput(OutputSchema):
@@ -43,6 +58,8 @@ class OpenAIOutput(OutputSchema):
43
58
  response: str
44
59
  used_model: str
45
60
  usage: Dict[str, int]
61
+ raw_response: Optional[Any] = None
62
+ tool_calls: Optional[List[Dict[str, Any]]] = None
46
63
 
47
64
 
48
65
  class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
@@ -64,6 +81,15 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
64
81
  organization=self.credentials.openai_organization_id,
65
82
  )
66
83
 
84
+ def validate_input(self, input_data: OpenAIInput) -> None:
85
+ """Validate the input before processing."""
86
+ if not input_data.model:
87
+ raise ValidationError("Model name is required")
88
+ if not input_data.user_input and not input_data.conversation_history:
89
+ raise ValidationError(
90
+ "Either user input or conversation history is required"
91
+ )
92
+
67
93
  def _build_messages(self, input_data: OpenAIInput) -> List[Dict[str, str]]:
68
94
  """Build messages list from input data including conversation history."""
69
95
  messages = [{"role": "system", "content": input_data.system_prompt}]
@@ -77,6 +103,7 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
77
103
  def process_stream(self, input_data: OpenAIInput) -> Generator[str, None, None]:
78
104
  """Process the input and stream the response token by token."""
79
105
  try:
106
+ self.validate_input(input_data)
80
107
  messages = self._build_messages(input_data)
81
108
 
82
109
  stream = self.client.chat.completions.create(
@@ -92,45 +119,91 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
92
119
  yield chunk.choices[0].delta.content
93
120
 
94
121
  except Exception as e:
122
+ logger.error(f"OpenAI streaming failed: {str(e)}")
95
123
  raise ProcessingError(f"OpenAI streaming failed: {str(e)}")
96
124
 
97
125
  def process(self, input_data: OpenAIInput) -> OpenAIOutput:
98
126
  """Process the input and return the complete response."""
99
127
  try:
128
+ self.validate_input(input_data)
129
+
100
130
  if input_data.stream:
101
131
  # For streaming, collect the entire response
102
132
  response_chunks = []
103
133
  for chunk in self.process_stream(input_data):
104
134
  response_chunks.append(chunk)
105
135
  response = "".join(response_chunks)
136
+ return OpenAIOutput(
137
+ response=response,
138
+ used_model=input_data.model,
139
+ usage={
140
+ "total_tokens": 0,
141
+ "prompt_tokens": 0,
142
+ "completion_tokens": 0,
143
+ },
144
+ )
106
145
  else:
107
146
  # For non-streaming, use regular completion
108
147
  messages = self._build_messages(input_data)
109
- completion = self.client.chat.completions.create(
110
- model=input_data.model,
111
- messages=messages,
112
- temperature=input_data.temperature,
113
- max_tokens=input_data.max_tokens,
114
- stream=False,
115
- )
116
- response = completion.choices[0].message.content
117
148
 
118
- return OpenAIOutput(
119
- response=response,
120
- used_model=input_data.model,
121
- usage={
122
- "total_tokens": completion.usage.total_tokens,
123
- "prompt_tokens": completion.usage.prompt_tokens,
124
- "completion_tokens": completion.usage.completion_tokens,
125
- },
126
- )
149
+ # Create the completion parameters
150
+ completion_params = {
151
+ "model": input_data.model,
152
+ "messages": messages,
153
+ "temperature": input_data.temperature,
154
+ "max_tokens": input_data.max_tokens,
155
+ "stream": False,
156
+ }
157
+
158
+ # Add tools if provided
159
+ if input_data.tools and len(input_data.tools) > 0:
160
+ logger.info(f"Using {len(input_data.tools)} tools with OpenAI")
161
+ completion_params["tools"] = input_data.tools
162
+
163
+ # Make the API call
164
+ completion = self.client.chat.completions.create(**completion_params)
165
+
166
+ # Extract response content
167
+ response = completion.choices[0].message.content or ""
168
+
169
+ # Extract tool calls if present
170
+ tool_calls = None
171
+ if (
172
+ hasattr(completion.choices[0].message, "tool_calls")
173
+ and completion.choices[0].message.tool_calls
174
+ ):
175
+ tool_calls = []
176
+ for tool_call in completion.choices[0].message.tool_calls:
177
+ formatted_tool_call = {
178
+ "id": tool_call.id,
179
+ "name": tool_call.function.name,
180
+ "arguments": tool_call.function.arguments,
181
+ }
182
+ tool_calls.append(formatted_tool_call)
183
+ logger.info(
184
+ f"Extracted {len(tool_calls)} tool calls from OpenAI response"
185
+ )
186
+
187
+ return OpenAIOutput(
188
+ response=response,
189
+ used_model=input_data.model,
190
+ usage={
191
+ "total_tokens": completion.usage.total_tokens,
192
+ "prompt_tokens": completion.usage.prompt_tokens,
193
+ "completion_tokens": completion.usage.completion_tokens,
194
+ },
195
+ raw_response=completion,
196
+ tool_calls=tool_calls,
197
+ )
127
198
 
128
199
  except Exception as e:
200
+ logger.error(f"OpenAI chat failed: {str(e)}")
129
201
  raise ProcessingError(f"OpenAI chat failed: {str(e)}")
130
202
 
131
203
  async def process_async(self, input_data: OpenAIInput) -> OpenAIOutput:
132
204
  """Async version of process method"""
133
205
  try:
206
+ self.validate_input(input_data)
134
207
  messages = self._build_messages(input_data)
135
208
  completion = await self.async_client.chat.completions.create(
136
209
  model=input_data.model,
@@ -155,6 +228,7 @@ class OpenAIChatSkill(Skill[OpenAIInput, OpenAIOutput]):
155
228
  ) -> AsyncGenerator[str, None]:
156
229
  """Async version of stream processor"""
157
230
  try:
231
+ self.validate_input(input_data)
158
232
  messages = self._build_messages(input_data)
159
233
  stream = await self.async_client.chat.completions.create(
160
234
  model=input_data.model,
@@ -21,6 +21,7 @@ from .network import ApiCallTool
21
21
  from .command import ExecuteCommandTool, FindFilesTool, TerminalNavigationTool
22
22
  from .search import SearchTermTool, WebSearchTool
23
23
  from .testing import RunPytestTool
24
+ from .weather import WeatherTool
24
25
 
25
26
  __all__ = [
26
27
  # Base classes
@@ -42,4 +43,5 @@ __all__ = [
42
43
  "SearchTermTool",
43
44
  "WebSearchTool",
44
45
  "RunPytestTool",
46
+ "WeatherTool",
45
47
  ]
@@ -0,0 +1,88 @@
1
+ """
2
+ Weather tool for AirTrain.
3
+
4
+ This module provides a tool for fetching weather data.
5
+ For demonstration purposes, this returns mock data regardless of the location.
6
+ """
7
+
8
+ import random
9
+ from typing import Dict, Any
10
+ from datetime import datetime
11
+ from loguru import logger
12
+
13
+ from .registry import StatelessTool, register_tool
14
+
15
+
16
+ @register_tool("weather")
17
+ class WeatherTool(StatelessTool):
18
+ """Tool for retrieving weather information for a location."""
19
+
20
+ def __init__(self):
21
+ self.name = "weather"
22
+ self.description = "Get current weather information for a location"
23
+ self.parameters = {
24
+ "type": "object",
25
+ "properties": {
26
+ "location": {
27
+ "type": "string",
28
+ "description": "City and country (e.g., 'New York, USA', 'Paris, France')",
29
+ }
30
+ },
31
+ "required": ["location"],
32
+ "additionalProperties": False,
33
+ }
34
+
35
+ def __call__(self, location: str) -> Dict[str, Any]:
36
+ """
37
+ Execute the weather tool to get mock weather data.
38
+
39
+ Args:
40
+ location: The city and country to get weather for
41
+
42
+ Returns:
43
+ A dictionary containing mock weather data
44
+ """
45
+ logger.info(f"Weather tool called with location: {location}")
46
+
47
+ # For demonstration purposes, return mock data regardless of input
48
+ temp_celsius = random.uniform(15.0, 30.0)
49
+ temp_fahrenheit = (temp_celsius * 9 / 5) + 32
50
+
51
+ conditions = random.choice(
52
+ ["Sunny", "Partly Cloudy", "Cloudy", "Light Rain", "Thunderstorms", "Clear"]
53
+ )
54
+
55
+ humidity = random.randint(30, 90)
56
+ wind_speed = random.uniform(0, 20)
57
+
58
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
59
+
60
+ weather_data = {
61
+ "status": "success",
62
+ "location": location,
63
+ "timestamp": current_time,
64
+ "weather": {
65
+ "temperature": {
66
+ "celsius": round(temp_celsius, 1),
67
+ "fahrenheit": round(temp_fahrenheit, 1),
68
+ },
69
+ "conditions": conditions,
70
+ "humidity": humidity,
71
+ "wind_speed": round(wind_speed, 1),
72
+ "forecast": "This is mock data for demonstration purposes",
73
+ },
74
+ }
75
+
76
+ logger.info(f"Returning mock weather data for {location}")
77
+ return weather_data
78
+
79
+ def to_dict(self) -> Dict[str, Any]:
80
+ """Convert tool to dictionary format for LLM function calling."""
81
+ return {
82
+ "type": "function",
83
+ "function": {
84
+ "name": self.name,
85
+ "description": self.description,
86
+ "parameters": self.parameters,
87
+ },
88
+ }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: airtrain
3
- Version: 0.1.62
3
+ Version: 0.1.67
4
4
  Summary: A platform for building and deploying AI agents with structured skills
5
5
  Home-page: https://github.com/rosaboyle/airtrain.dev
6
6
  Author: Dheeraj Pai
@@ -1,6 +1,6 @@
1
- airtrain/__init__.py,sha256=gTO8GEuN1gIKjo5HlBOURfVnRtEuv61zF-HMxTXQYQk,3357
1
+ airtrain/__init__.py,sha256=Wt0Bf2-pkLkX1tN2gepDmegHlmF7Xq80x3imMBmhEv4,3357
2
2
  airtrain/__main__.py,sha256=EU8ffFmCdC1G-UcHHt0Oo3lB1PGqfC6kwzH39CnYSwU,72
3
- airtrain/__pycache__/__init__.cpython-313.pyc,sha256=gNzuKeI25qKWqhRq6MRzf8Otl1qJEekuX_QU0HAx7bE,2759
3
+ airtrain/__pycache__/__init__.cpython-313.pyc,sha256=SV1Rq3CsgY9QTiEgul7-YxCkdIEMymkD8KE9TVk5yxs,2759
4
4
  airtrain/agents/__init__.py,sha256=r6v5_bblxamRgiaCT8CVhyzaDdWohGM7sSjLgIUpA5s,795
5
5
  airtrain/agents/example_agent.py,sha256=0dCS8QXIvUYYkxwyOEMLMdlQ4KgMAerQ56r7NcYGqTw,11681
6
6
  airtrain/agents/groq_agent.py,sha256=s-f-cGSrgsR4Jlrv6xzeg2Z0LAEHByhFMYVbTvu-Bkg,10558
@@ -19,9 +19,9 @@ airtrain/core/__init__.py,sha256=9h7iKwTzZocCPc9bU6j8bA02BokteWIOcO1uaqGMcrk,254
19
19
  airtrain/core/credentials.py,sha256=J1jd8vLrOfet0GhLI1J44d35o7skjriBsMgpODmXwfo,5592
20
20
  airtrain/core/schemas.py,sha256=MMXrDviC4gRea_QaPpbjgO--B_UKxnD7YrxqZOLJZZU,7003
21
21
  airtrain/core/skills.py,sha256=kIkI1MwIzuAwyoxdnZ5MGOb70BOB-njbVb_csJEdVvc,9244
22
- airtrain/core/__pycache__/__init__.cpython-313.pyc,sha256=aHwrmX5Hf3s9pSa8V0PH3ytqldRjB_X9AWwA6IDpzyk,534
23
- airtrain/core/__pycache__/schemas.cpython-313.pyc,sha256=WpKRuARzTVGt_NDMApl8GNIQgF_jZcHV34z-KYs3g5g,8327
24
- airtrain/core/__pycache__/skills.cpython-313.pyc,sha256=4EppdsB6rISIJNF8CbyCPHXlMCX1JKETeKxOS8hENLw,11839
22
+ airtrain/core/__pycache__/__init__.cpython-313.pyc,sha256=TQKjYutr-_x05pI-yyJShDLaNnCQMH-ZsGwSWBH_wi4,534
23
+ airtrain/core/__pycache__/schemas.cpython-313.pyc,sha256=9OaxWCuPLgfeedvDnkKoEwNuLrxyRENjAOqnGAXpTK4,8327
24
+ airtrain/core/__pycache__/skills.cpython-313.pyc,sha256=dUPAFMzeilfh7RYyQbZPaFnI8Tg0GTfI7CD4vGXOBtQ,11839
25
25
  airtrain/integrations/__init__.py,sha256=nGG3gZJ_sEaMzi_utsQCzExgnpHUjua8L9TyzofLQnw,2842
26
26
  airtrain/integrations/anthropic/__init__.py,sha256=K741w3v7fWsCknTo38ARqDL0D3HPlwDIvDuuBao9Tto,800
27
27
  airtrain/integrations/anthropic/credentials.py,sha256=hlTSw9HX66kYNaeQUtn0JjdZQBMNkzzFOJOoLOOzvcY,1246
@@ -45,7 +45,7 @@ airtrain/integrations/fireworks/list_models.py,sha256=o4fP0K3qstBopO7va2LysLp4_K
45
45
  airtrain/integrations/fireworks/models.py,sha256=yo4xtweSi4qQftg04r4naRddx3KjU9Jluzqf5C7V9f4,4626
46
46
  airtrain/integrations/fireworks/requests_skills.py,sha256=h6HRV5dGvV7t3zyjD-awW47RyeDbu8onNevhcgSSy94,8235
47
47
  airtrain/integrations/fireworks/skills.py,sha256=Ns1tXXTVtTeeVYadzm4dnmmOboo430WTMu2o56oWTDc,7156
48
- airtrain/integrations/fireworks/structured_completion_skills.py,sha256=airYakYWXzYRS9nfNfrH90N3eeN8YW7GaY3ygLSiBO8,6622
48
+ airtrain/integrations/fireworks/structured_completion_skills.py,sha256=1cOt6aGYJGrO0Hy74jbLMzHGlgsUfynCSFLBf_6_Q9U,6692
49
49
  airtrain/integrations/fireworks/structured_requests_skills.py,sha256=uQR-nygtWmdGTwvU-aUdMNOMit_PiBVPYRa80ZloHLs,11852
50
50
  airtrain/integrations/fireworks/structured_skills.py,sha256=1wZ_7QDUhKWCSv_1lSEF6VnAqEeEA3jWHq7n0fWicgw,3897
51
51
  airtrain/integrations/google/__init__.py,sha256=ElwgcXfbg_gGMm6zbkMXCQPFKZUb-yTJk986o19A7Cs,214
@@ -63,7 +63,7 @@ airtrain/integrations/openai/chinese_assistant.py,sha256=F8bMeUUDly7BYG6wO648cAE
63
63
  airtrain/integrations/openai/credentials.py,sha256=NfRyp1QgEtgm8cxt2-BOLq-6d0X-Pcm80NnfHM8p0FY,1470
64
64
  airtrain/integrations/openai/list_models.py,sha256=vg8pZwLZ3F2Fx42X18WykpJOzZD9JG-2KJi49XWgSKo,4121
65
65
  airtrain/integrations/openai/models_config.py,sha256=W9mu_z9tCC4ZUKHSJ6Hk4X09TRZLqEhT7TtRY5JEk5g,8007
66
- airtrain/integrations/openai/skills.py,sha256=kEDe5q0Zlv_X-JGOYtb552ktb3aQQYVUYczVwMH0jxA,12823
66
+ airtrain/integrations/openai/skills.py,sha256=7Rn5byvKVhk-OQI87Z-iCEP9ujrOyqcMM5K1gv8I1CM,15604
67
67
  airtrain/integrations/perplexity/__init__.py,sha256=asVQs-oVXVhnLmAOZ6l9R8KoigeHmw8D5OONb_aGgnY,1189
68
68
  airtrain/integrations/perplexity/credentials.py,sha256=5acl2DcF0dk7DQJWRwmFkBXPQ1zfKN4ARW31dP-4rpQ,1524
69
69
  airtrain/integrations/perplexity/list_models.py,sha256=iOOxQ50BfLc226ps0gSdFl-enxa-SmStiuU2vFTUhJ0,4298
@@ -94,15 +94,16 @@ airtrain/integrations/together/vision_models_config.py,sha256=m28HwYDk2Kup_J-a1F
94
94
  airtrain/telemetry/__init__.py,sha256=_xDHzSmQvRCqihT0QTmF7bS9yKEl2LaZ3Zq05hXaI3k,1108
95
95
  airtrain/telemetry/service.py,sha256=6IB2-FL3fGsYYgMIVQ9MNCp9UlSE5cVhLlB3VBYkAnY,5592
96
96
  airtrain/telemetry/views.py,sha256=qob1swyLNEk_UIpDi5VTwMDsEsGZhJheFQrGbP8T5hw,8115
97
- airtrain/tools/__init__.py,sha256=AauO_EEBcK09mkyCuvvVK_Oz0L5DOrnuySViWXCOt6Y,1021
97
+ airtrain/tools/__init__.py,sha256=dL_CsjCD1uut0T-n6-LVXaH9k3Le2kvTIUDoQzjd1Hw,1073
98
98
  airtrain/tools/command.py,sha256=dxvs6RzppjWmkUe1oMtxOc7w2mFOGFFZ9Gylwnm37Sw,13355
99
99
  airtrain/tools/filesystem.py,sha256=-YYdHj_KeSWPYXeRhWhIX9s_KujVA1R5tF3r93zRVTU,6324
100
100
  airtrain/tools/network.py,sha256=YR0AtMXDXkhCsXcx7_t2d12ItnKY8XXTmyP1kdj2M4c,3883
101
101
  airtrain/tools/registry.py,sha256=K-1H5EipYcDNDx2jdpsEY9gjfV4aNCGI1pY2UsgSpC0,10246
102
102
  airtrain/tools/search.py,sha256=MJNi17g6aBPSqbF0ChV8ZgMlzz_PoKSPAIpe_dazdt8,15081
103
103
  airtrain/tools/testing.py,sha256=q4ALEPRzukiadY6wFSPY7vA-T1o3XInLhXt18dsf6yY,4397
104
- airtrain-0.1.62.dist-info/METADATA,sha256=fAOQQht25N594JsKGN6sFuiqjVMPjKMtly_RP6JW9co,6503
105
- airtrain-0.1.62.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
106
- airtrain-0.1.62.dist-info/entry_points.txt,sha256=rrJ36IUsyq6n1dSfTWXqVAgpQLPRWDfCqwd6_3B-G0U,52
107
- airtrain-0.1.62.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
108
- airtrain-0.1.62.dist-info/RECORD,,
104
+ airtrain/tools/weather.py,sha256=cOP79XF2GOHD_TKnwW7OA5DzykixugB06CzCQLIyONQ,2787
105
+ airtrain-0.1.67.dist-info/METADATA,sha256=9xpWaDyUUyso_W_Z49Ff4hSXqKwM3zStRt61EvSMGng,6503
106
+ airtrain-0.1.67.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
107
+ airtrain-0.1.67.dist-info/entry_points.txt,sha256=rrJ36IUsyq6n1dSfTWXqVAgpQLPRWDfCqwd6_3B-G0U,52
108
+ airtrain-0.1.67.dist-info/top_level.txt,sha256=cFWW1vY6VMCb3AGVdz6jBDpZ65xxBRSqlsPyySxTkxY,9
109
+ airtrain-0.1.67.dist-info/RECORD,,