athena-intelligence 0.1.125__py3-none-any.whl → 0.1.126__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- athena/__init__.py +3 -0
- athena/agents/client.py +88 -36
- athena/agents/drive/client.py +80 -32
- athena/agents/general/client.py +222 -91
- athena/agents/research/client.py +80 -32
- athena/agents/sql/client.py +80 -32
- athena/base_client.py +13 -11
- athena/client.py +161 -61
- athena/core/__init__.py +21 -4
- athena/core/client_wrapper.py +9 -10
- athena/core/file.py +37 -8
- athena/core/http_client.py +97 -41
- athena/core/jsonable_encoder.py +33 -31
- athena/core/pydantic_utilities.py +272 -4
- athena/core/query_encoder.py +38 -13
- athena/core/request_options.py +5 -2
- athena/core/serialization.py +272 -0
- athena/errors/internal_server_error.py +2 -3
- athena/errors/unauthorized_error.py +2 -3
- athena/errors/unprocessable_entity_error.py +2 -3
- athena/query/client.py +208 -58
- athena/tools/calendar/client.py +82 -30
- athena/tools/client.py +576 -184
- athena/tools/email/client.py +117 -43
- athena/tools/structured_data_extractor/client.py +118 -67
- athena/tools/tasks/client.py +41 -17
- athena/types/asset_node.py +14 -24
- athena/types/asset_not_found_error.py +11 -21
- athena/types/chunk.py +11 -21
- athena/types/chunk_content_item.py +21 -41
- athena/types/chunk_result.py +13 -23
- athena/types/custom_agent_response.py +12 -22
- athena/types/data_frame_request_out.py +11 -21
- athena/types/data_frame_unknown_format_error.py +11 -21
- athena/types/document_chunk.py +12 -22
- athena/types/drive_agent_response.py +12 -22
- athena/types/file_chunk_request_out.py +11 -21
- athena/types/file_too_large_error.py +11 -21
- athena/types/folder_response.py +11 -21
- athena/types/general_agent_config.py +11 -21
- athena/types/general_agent_config_enabled_tools_item.py +0 -1
- athena/types/general_agent_request.py +13 -23
- athena/types/general_agent_response.py +12 -22
- athena/types/image_url_content.py +11 -21
- athena/types/parent_folder_error.py +11 -21
- athena/types/prompt_message.py +12 -22
- athena/types/research_agent_response.py +12 -22
- athena/types/save_asset_request_out.py +11 -21
- athena/types/sql_agent_response.py +13 -23
- athena/types/structured_data_extractor_response.py +15 -25
- athena/types/text_content.py +11 -21
- athena/types/tool.py +1 -13
- athena/types/type.py +1 -21
- athena/version.py +0 -1
- {athena_intelligence-0.1.125.dist-info → athena_intelligence-0.1.126.dist-info}/METADATA +12 -4
- athena_intelligence-0.1.126.dist-info/RECORD +87 -0
- {athena_intelligence-0.1.125.dist-info → athena_intelligence-0.1.126.dist-info}/WHEEL +1 -1
- athena_intelligence-0.1.125.dist-info/RECORD +0 -86
athena/agents/sql/client.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
3
|
import typing
|
4
|
-
from
|
5
|
-
|
6
|
-
from ...core.api_error import ApiError
|
7
|
-
from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
|
8
|
-
from ...core.pydantic_utilities import pydantic_v1
|
4
|
+
from ...core.client_wrapper import SyncClientWrapper
|
9
5
|
from ...core.request_options import RequestOptions
|
10
|
-
from ...errors.unprocessable_entity_error import UnprocessableEntityError
|
11
6
|
from ...types.sql_agent_response import SqlAgentResponse
|
7
|
+
from ...core.pydantic_utilities import parse_obj_as
|
8
|
+
from ...errors.unprocessable_entity_error import UnprocessableEntityError
|
9
|
+
from json.decoder import JSONDecodeError
|
10
|
+
from ...core.api_error import ApiError
|
11
|
+
from ...core.client_wrapper import AsyncClientWrapper
|
12
12
|
|
13
13
|
# this is used as the default value for optional parameters
|
14
14
|
OMIT = typing.cast(typing.Any, ...)
|
@@ -21,19 +21,19 @@ class SqlClient:
|
|
21
21
|
def invoke(
|
22
22
|
self,
|
23
23
|
*,
|
24
|
-
config: typing.Dict[str, typing.Any],
|
25
|
-
messages: typing.Sequence[typing.Dict[str, typing.Any]],
|
26
|
-
request_options: typing.Optional[RequestOptions] = None
|
24
|
+
config: typing.Dict[str, typing.Optional[typing.Any]],
|
25
|
+
messages: typing.Sequence[typing.Dict[str, typing.Optional[typing.Any]]],
|
26
|
+
request_options: typing.Optional[RequestOptions] = None,
|
27
27
|
) -> SqlAgentResponse:
|
28
28
|
"""
|
29
29
|
Coming soon! Generate, execute, and test SQL queries. Returns an asset ID for the query object.
|
30
30
|
|
31
31
|
Parameters
|
32
32
|
----------
|
33
|
-
config : typing.Dict[str, typing.Any]
|
33
|
+
config : typing.Dict[str, typing.Optional[typing.Any]]
|
34
34
|
Configuration for the SQL agent including database connection details and query parameters
|
35
35
|
|
36
|
-
messages : typing.Sequence[typing.Dict[str, typing.Any]]
|
36
|
+
messages : typing.Sequence[typing.Dict[str, typing.Optional[typing.Any]]]
|
37
37
|
The messages to send to the SQL agent
|
38
38
|
|
39
39
|
request_options : typing.Optional[RequestOptions]
|
@@ -46,7 +46,7 @@ class SqlClient:
|
|
46
46
|
|
47
47
|
Examples
|
48
48
|
--------
|
49
|
-
from athena
|
49
|
+
from athena import Athena
|
50
50
|
|
51
51
|
client = Athena(
|
52
52
|
api_key="YOUR_API_KEY",
|
@@ -59,15 +59,35 @@ class SqlClient:
|
|
59
59
|
_response = self._client_wrapper.httpx_client.request(
|
60
60
|
"api/v0/agents/sql/invoke",
|
61
61
|
method="POST",
|
62
|
-
json={
|
62
|
+
json={
|
63
|
+
"config": config,
|
64
|
+
"messages": messages,
|
65
|
+
},
|
66
|
+
headers={
|
67
|
+
"content-type": "application/json",
|
68
|
+
},
|
63
69
|
request_options=request_options,
|
64
70
|
omit=OMIT,
|
65
71
|
)
|
66
|
-
if 200 <= _response.status_code < 300:
|
67
|
-
return pydantic_v1.parse_obj_as(SqlAgentResponse, _response.json()) # type: ignore
|
68
|
-
if _response.status_code == 422:
|
69
|
-
raise UnprocessableEntityError(pydantic_v1.parse_obj_as(typing.Any, _response.json())) # type: ignore
|
70
72
|
try:
|
73
|
+
if 200 <= _response.status_code < 300:
|
74
|
+
return typing.cast(
|
75
|
+
SqlAgentResponse,
|
76
|
+
parse_obj_as(
|
77
|
+
type_=SqlAgentResponse, # type: ignore
|
78
|
+
object_=_response.json(),
|
79
|
+
),
|
80
|
+
)
|
81
|
+
if _response.status_code == 422:
|
82
|
+
raise UnprocessableEntityError(
|
83
|
+
typing.cast(
|
84
|
+
typing.Optional[typing.Any],
|
85
|
+
parse_obj_as(
|
86
|
+
type_=typing.Optional[typing.Any], # type: ignore
|
87
|
+
object_=_response.json(),
|
88
|
+
),
|
89
|
+
)
|
90
|
+
)
|
71
91
|
_response_json = _response.json()
|
72
92
|
except JSONDecodeError:
|
73
93
|
raise ApiError(status_code=_response.status_code, body=_response.text)
|
@@ -81,19 +101,19 @@ class AsyncSqlClient:
|
|
81
101
|
async def invoke(
|
82
102
|
self,
|
83
103
|
*,
|
84
|
-
config: typing.Dict[str, typing.Any],
|
85
|
-
messages: typing.Sequence[typing.Dict[str, typing.Any]],
|
86
|
-
request_options: typing.Optional[RequestOptions] = None
|
104
|
+
config: typing.Dict[str, typing.Optional[typing.Any]],
|
105
|
+
messages: typing.Sequence[typing.Dict[str, typing.Optional[typing.Any]]],
|
106
|
+
request_options: typing.Optional[RequestOptions] = None,
|
87
107
|
) -> SqlAgentResponse:
|
88
108
|
"""
|
89
109
|
Coming soon! Generate, execute, and test SQL queries. Returns an asset ID for the query object.
|
90
110
|
|
91
111
|
Parameters
|
92
112
|
----------
|
93
|
-
config : typing.Dict[str, typing.Any]
|
113
|
+
config : typing.Dict[str, typing.Optional[typing.Any]]
|
94
114
|
Configuration for the SQL agent including database connection details and query parameters
|
95
115
|
|
96
|
-
messages : typing.Sequence[typing.Dict[str, typing.Any]]
|
116
|
+
messages : typing.Sequence[typing.Dict[str, typing.Optional[typing.Any]]]
|
97
117
|
The messages to send to the SQL agent
|
98
118
|
|
99
119
|
request_options : typing.Optional[RequestOptions]
|
@@ -106,28 +126,56 @@ class AsyncSqlClient:
|
|
106
126
|
|
107
127
|
Examples
|
108
128
|
--------
|
109
|
-
|
129
|
+
import asyncio
|
130
|
+
|
131
|
+
from athena import AsyncAthena
|
110
132
|
|
111
133
|
client = AsyncAthena(
|
112
134
|
api_key="YOUR_API_KEY",
|
113
135
|
)
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
136
|
+
|
137
|
+
|
138
|
+
async def main() -> None:
|
139
|
+
await client.agents.sql.invoke(
|
140
|
+
config={"key": "value"},
|
141
|
+
messages=[{"key": "value"}],
|
142
|
+
)
|
143
|
+
|
144
|
+
|
145
|
+
asyncio.run(main())
|
118
146
|
"""
|
119
147
|
_response = await self._client_wrapper.httpx_client.request(
|
120
148
|
"api/v0/agents/sql/invoke",
|
121
149
|
method="POST",
|
122
|
-
json={
|
150
|
+
json={
|
151
|
+
"config": config,
|
152
|
+
"messages": messages,
|
153
|
+
},
|
154
|
+
headers={
|
155
|
+
"content-type": "application/json",
|
156
|
+
},
|
123
157
|
request_options=request_options,
|
124
158
|
omit=OMIT,
|
125
159
|
)
|
126
|
-
if 200 <= _response.status_code < 300:
|
127
|
-
return pydantic_v1.parse_obj_as(SqlAgentResponse, _response.json()) # type: ignore
|
128
|
-
if _response.status_code == 422:
|
129
|
-
raise UnprocessableEntityError(pydantic_v1.parse_obj_as(typing.Any, _response.json())) # type: ignore
|
130
160
|
try:
|
161
|
+
if 200 <= _response.status_code < 300:
|
162
|
+
return typing.cast(
|
163
|
+
SqlAgentResponse,
|
164
|
+
parse_obj_as(
|
165
|
+
type_=SqlAgentResponse, # type: ignore
|
166
|
+
object_=_response.json(),
|
167
|
+
),
|
168
|
+
)
|
169
|
+
if _response.status_code == 422:
|
170
|
+
raise UnprocessableEntityError(
|
171
|
+
typing.cast(
|
172
|
+
typing.Optional[typing.Any],
|
173
|
+
parse_obj_as(
|
174
|
+
type_=typing.Optional[typing.Any], # type: ignore
|
175
|
+
object_=_response.json(),
|
176
|
+
),
|
177
|
+
)
|
178
|
+
)
|
131
179
|
_response_json = _response.json()
|
132
180
|
except JSONDecodeError:
|
133
181
|
raise ApiError(status_code=_response.status_code, body=_response.text)
|
athena/base_client.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
3
|
import typing
|
4
|
-
|
5
|
-
import httpx
|
6
|
-
|
7
|
-
from .agents.client import AgentsClient, AsyncAgentsClient
|
8
|
-
from .core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
|
9
4
|
from .environment import AthenaEnvironment
|
10
|
-
|
11
|
-
from .
|
5
|
+
import httpx
|
6
|
+
from .core.client_wrapper import SyncClientWrapper
|
7
|
+
from .agents.client import AgentsClient
|
8
|
+
from .query.client import QueryClient
|
9
|
+
from .tools.client import ToolsClient
|
10
|
+
from .core.client_wrapper import AsyncClientWrapper
|
11
|
+
from .agents.client import AsyncAgentsClient
|
12
|
+
from .query.client import AsyncQueryClient
|
13
|
+
from .tools.client import AsyncToolsClient
|
12
14
|
|
13
15
|
|
14
16
|
class BaseAthena:
|
@@ -41,7 +43,7 @@ class BaseAthena:
|
|
41
43
|
|
42
44
|
Examples
|
43
45
|
--------
|
44
|
-
from athena
|
46
|
+
from athena import Athena
|
45
47
|
|
46
48
|
client = Athena(
|
47
49
|
api_key="YOUR_API_KEY",
|
@@ -56,7 +58,7 @@ class BaseAthena:
|
|
56
58
|
api_key: str,
|
57
59
|
timeout: typing.Optional[float] = None,
|
58
60
|
follow_redirects: typing.Optional[bool] = True,
|
59
|
-
httpx_client: typing.Optional[httpx.Client] = None
|
61
|
+
httpx_client: typing.Optional[httpx.Client] = None,
|
60
62
|
):
|
61
63
|
_defaulted_timeout = timeout if timeout is not None else 60 if httpx_client is None else None
|
62
64
|
self._client_wrapper = SyncClientWrapper(
|
@@ -104,7 +106,7 @@ class AsyncBaseAthena:
|
|
104
106
|
|
105
107
|
Examples
|
106
108
|
--------
|
107
|
-
from athena
|
109
|
+
from athena import AsyncAthena
|
108
110
|
|
109
111
|
client = AsyncAthena(
|
110
112
|
api_key="YOUR_API_KEY",
|
@@ -119,7 +121,7 @@ class AsyncBaseAthena:
|
|
119
121
|
api_key: str,
|
120
122
|
timeout: typing.Optional[float] = None,
|
121
123
|
follow_redirects: typing.Optional[bool] = True,
|
122
|
-
httpx_client: typing.Optional[httpx.AsyncClient] = None
|
124
|
+
httpx_client: typing.Optional[httpx.AsyncClient] = None,
|
123
125
|
):
|
124
126
|
_defaulted_timeout = timeout if timeout is not None else 60 if httpx_client is None else None
|
125
127
|
self._client_wrapper = AsyncClientWrapper(
|
athena/client.py
CHANGED
@@ -7,8 +7,9 @@ import typing
|
|
7
7
|
import warnings
|
8
8
|
|
9
9
|
import httpx
|
10
|
-
from typing import cast, List, Tuple, Union
|
10
|
+
from typing import cast, List, Tuple, Union, Optional
|
11
11
|
from typing_extensions import TypeVar, ParamSpec
|
12
|
+
from langserve import RemoteRunnable
|
12
13
|
|
13
14
|
from . import core
|
14
15
|
from .base_client import BaseAthena, AsyncBaseAthena
|
@@ -24,23 +25,23 @@ if typing.TYPE_CHECKING:
|
|
24
25
|
|
25
26
|
P = ParamSpec("P")
|
26
27
|
T = TypeVar("T")
|
27
|
-
U = TypeVar(
|
28
|
+
U = TypeVar("U")
|
28
29
|
|
29
30
|
|
30
31
|
def _inherit_signature_and_doc(
|
31
|
-
f: typing.Callable[P, T],
|
32
|
-
replace_in_doc: typing.Dict[str, str]
|
32
|
+
f: typing.Callable[P, T], replace_in_doc: typing.Dict[str, str]
|
33
33
|
) -> typing.Callable[..., typing.Callable[P, U]]:
|
34
34
|
def decorator(decorated):
|
35
35
|
for old, new in replace_in_doc.items():
|
36
36
|
assert old in f.__doc__
|
37
37
|
decorated.__doc__ = f.__doc__.replace(old, new)
|
38
38
|
return decorated
|
39
|
+
|
39
40
|
return decorator
|
40
41
|
|
41
42
|
|
42
43
|
class SpecialEnvironments(enum.Enum):
|
43
|
-
AUTODETECT_ENVIRONMENT =
|
44
|
+
AUTODETECT_ENVIRONMENT = "AUTO"
|
44
45
|
|
45
46
|
|
46
47
|
@dataclasses.dataclass
|
@@ -53,7 +54,7 @@ class AthenaAsset:
|
|
53
54
|
if self.media_type == "application/sql":
|
54
55
|
# it is safe to import IPython in `_repr_mimebundle_`
|
55
56
|
# as this is only intended to be invoked by IPython.
|
56
|
-
from IPython import display
|
57
|
+
from IPython import display # type: ignore[import]
|
57
58
|
|
58
59
|
code = display.Code(
|
59
60
|
data=self.data.decode(),
|
@@ -84,17 +85,19 @@ class WrappedToolsClient(ToolsClient):
|
|
84
85
|
file_io = client.tools.get_file(asset_id="asset_id")
|
85
86
|
pl.read_csv(file_io)
|
86
87
|
"""
|
87
|
-
file_bytes = b
|
88
|
+
file_bytes = b"".join(self.raw_data(asset_id=asset_id))
|
88
89
|
bytes_io = io.BytesIO(file_bytes)
|
89
90
|
return bytes_io
|
90
91
|
|
91
|
-
@_inherit_signature_and_doc(
|
92
|
-
|
92
|
+
@_inherit_signature_and_doc(
|
93
|
+
ToolsClient.data_frame, {"DataFrameRequestOut": "pd.DataFrame"}
|
94
|
+
)
|
95
|
+
def data_frame(self, *, asset_id: str, **kwargs) -> "pd.DataFrame":
|
93
96
|
_check_pandas_installed()
|
94
97
|
model = super().data_frame(asset_id=asset_id, **kwargs)
|
95
98
|
return _read_json_frame(model)
|
96
99
|
|
97
|
-
def read_data_frame(self, asset_id: str, *args, **kwargs) ->
|
100
|
+
def read_data_frame(self, asset_id: str, *args, **kwargs) -> "pd.DataFrame":
|
98
101
|
"""
|
99
102
|
Parameters
|
100
103
|
----------
|
@@ -119,8 +122,7 @@ class WrappedToolsClient(ToolsClient):
|
|
119
122
|
file_bytes, media_type = self._get_file_and_media_type(asset_id=asset_id)
|
120
123
|
return _to_pandas_df(file_bytes, *args, media_type=media_type, **kwargs)
|
121
124
|
|
122
|
-
|
123
|
-
def save_asset( # type: ignore[override]
|
125
|
+
def save_asset( # type: ignore[override]
|
124
126
|
self,
|
125
127
|
asset_object: Union["pd.DataFrame", "pd.Series", core.File],
|
126
128
|
*,
|
@@ -157,10 +159,10 @@ class WrappedToolsClient(ToolsClient):
|
|
157
159
|
client = Athena(api_key="YOUR_API_KEY")
|
158
160
|
client.tools.save_asset(df)
|
159
161
|
"""
|
160
|
-
asset_object = _convert_asset_object(
|
161
|
-
|
162
|
-
file=asset_object, parent_folder_id=parent_folder_id
|
162
|
+
asset_object = _convert_asset_object(
|
163
|
+
asset_object=asset_object, name=name, **kwargs
|
163
164
|
)
|
165
|
+
return super().save_asset(file=asset_object, parent_folder_id=parent_folder_id)
|
164
166
|
|
165
167
|
def _get_file_and_media_type(self, asset_id: str) -> Tuple[io.BytesIO, str]:
|
166
168
|
"""
|
@@ -236,22 +238,26 @@ class WrappedToolsClient(ToolsClient):
|
|
236
238
|
return _to_pandas_df(file_bytes, media_type=media_type)
|
237
239
|
|
238
240
|
raise NotImplementedError("Assets of `{media_type}` type are not yet supported")
|
239
|
-
|
241
|
+
|
240
242
|
|
241
243
|
class WrappedQueryClient(QueryClient):
|
242
244
|
|
243
|
-
@_inherit_signature_and_doc(
|
244
|
-
|
245
|
+
@_inherit_signature_and_doc(
|
246
|
+
QueryClient.execute, {"DataFrameRequestOut": "pd.DataFrame"}
|
247
|
+
)
|
248
|
+
def execute(
|
249
|
+
self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs
|
250
|
+
) -> "pd.DataFrame":
|
245
251
|
_check_pandas_installed()
|
246
252
|
model = super().execute(
|
247
|
-
sql_command=sql_command,
|
248
|
-
database_asset_ids=database_asset_ids,
|
249
|
-
**kwargs
|
253
|
+
sql_command=sql_command, database_asset_ids=database_asset_ids, **kwargs
|
250
254
|
)
|
251
255
|
return _read_json_frame(model)
|
252
256
|
|
253
|
-
@_inherit_signature_and_doc(
|
254
|
-
|
257
|
+
@_inherit_signature_and_doc(
|
258
|
+
QueryClient.execute_snippet, {"DataFrameRequestOut": "pd.DataFrame"}
|
259
|
+
)
|
260
|
+
def execute_snippet(self, *, snippet_asset_id: str, **kwargs) -> "pd.DataFrame":
|
255
261
|
_check_pandas_installed()
|
256
262
|
model = super().execute_snippet(snippet_asset_id=snippet_asset_id, **kwargs)
|
257
263
|
return _read_json_frame(model)
|
@@ -261,12 +267,11 @@ def _add_docs_for_async_variant(obj):
|
|
261
267
|
def decorator(decorated):
|
262
268
|
doc = obj.__doc__
|
263
269
|
name = obj.__name__
|
264
|
-
decorated.__doc__ = (
|
265
|
-
|
266
|
-
|
267
|
-
.replace(f'client.tools.{name}', f'await client.tools.{name}')
|
268
|
-
)
|
270
|
+
decorated.__doc__ = doc.replace(
|
271
|
+
"client = Athena", "client = AsyncAthena"
|
272
|
+
).replace(f"client.tools.{name}", f"await client.tools.{name}")
|
269
273
|
return decorated
|
274
|
+
|
270
275
|
return decorator
|
271
276
|
|
272
277
|
|
@@ -274,18 +279,20 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
|
|
274
279
|
|
275
280
|
@_add_docs_for_async_variant(WrappedToolsClient.get_file)
|
276
281
|
async def get_file(self, asset_id: str) -> io.BytesIO:
|
277
|
-
file_bytes = b
|
282
|
+
file_bytes = b"".join([gen async for gen in self.raw_data(asset_id=asset_id)])
|
278
283
|
bytes_io = io.BytesIO(file_bytes)
|
279
284
|
return bytes_io
|
280
285
|
|
281
|
-
@_inherit_signature_and_doc(
|
282
|
-
|
286
|
+
@_inherit_signature_and_doc(
|
287
|
+
AsyncToolsClient.data_frame, {"DataFrameRequestOut": "pd.DataFrame"}
|
288
|
+
)
|
289
|
+
async def data_frame(self, *, asset_id: str, **kwargs) -> "pd.DataFrame":
|
283
290
|
_check_pandas_installed()
|
284
291
|
model = await super().data_frame(asset_id=asset_id, **kwargs)
|
285
292
|
return _read_json_frame(model)
|
286
293
|
|
287
294
|
@_add_docs_for_async_variant(WrappedToolsClient.read_data_frame)
|
288
|
-
async def read_data_frame(self, asset_id: str, *args, **kwargs) ->
|
295
|
+
async def read_data_frame(self, asset_id: str, *args, **kwargs) -> "pd.DataFrame":
|
289
296
|
_check_pandas_installed()
|
290
297
|
file_bytes = await self.get_file(asset_id)
|
291
298
|
return _to_pandas_df(file_bytes, *args, **kwargs)
|
@@ -299,7 +306,9 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
|
|
299
306
|
name: Union[str, None] = None,
|
300
307
|
**kwargs,
|
301
308
|
) -> SaveAssetRequestOut:
|
302
|
-
asset_object = _convert_asset_object(
|
309
|
+
asset_object = _convert_asset_object(
|
310
|
+
asset_object=asset_object, name=name, **kwargs
|
311
|
+
)
|
303
312
|
return await super().save_asset(
|
304
313
|
file=asset_object, parent_folder_id=parent_folder_id
|
305
314
|
)
|
@@ -307,23 +316,109 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
|
|
307
316
|
|
308
317
|
class WrappedAsyncQueryClient(AsyncQueryClient):
|
309
318
|
|
310
|
-
@_inherit_signature_and_doc(
|
311
|
-
|
319
|
+
@_inherit_signature_and_doc(
|
320
|
+
AsyncQueryClient.execute, {"DataFrameRequestOut": "pd.DataFrame"}
|
321
|
+
)
|
322
|
+
async def execute(
|
323
|
+
self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs
|
324
|
+
) -> "pd.DataFrame":
|
312
325
|
_check_pandas_installed()
|
313
326
|
model = await super().execute(
|
314
|
-
sql_command=sql_command,
|
315
|
-
database_asset_ids=database_asset_ids,
|
316
|
-
**kwargs
|
327
|
+
sql_command=sql_command, database_asset_ids=database_asset_ids, **kwargs
|
317
328
|
)
|
318
329
|
return _read_json_frame(model)
|
319
330
|
|
320
|
-
@_inherit_signature_and_doc(
|
321
|
-
|
331
|
+
@_inherit_signature_and_doc(
|
332
|
+
AsyncQueryClient.execute_snippet, {"DataFrameRequestOut": "pd.DataFrame"}
|
333
|
+
)
|
334
|
+
async def execute_snippet(
|
335
|
+
self, *, snippet_asset_id: str, **kwargs
|
336
|
+
) -> "pd.DataFrame":
|
322
337
|
_check_pandas_installed()
|
323
|
-
model = await super().execute_snippet(
|
338
|
+
model = await super().execute_snippet(
|
339
|
+
snippet_asset_id=snippet_asset_id, **kwargs
|
340
|
+
)
|
324
341
|
return _read_json_frame(model)
|
325
342
|
|
326
343
|
|
344
|
+
class AthenaModel(RemoteRunnable):
|
345
|
+
"""Use Athena's models directly with a Langchain-compatible client.
|
346
|
+
|
347
|
+
The Langchain Runnable interface is supported:
|
348
|
+
- `invoke`: Invoke the model with a string or a Langchain message.
|
349
|
+
- `batch`: Batch invoke the model on multiple inputs.
|
350
|
+
- `astream_events`: Streaming for real-time applications.
|
351
|
+
|
352
|
+
See Langchain documentation for more details.
|
353
|
+
|
354
|
+
Examples
|
355
|
+
--------
|
356
|
+
from src.athena.client import Athena
|
357
|
+
import asyncio
|
358
|
+
from langchain_core.messages import HumanMessage
|
359
|
+
|
360
|
+
client = Athena()
|
361
|
+
llm = client.llm
|
362
|
+
|
363
|
+
# sync invoke -- use strings or langchain messages
|
364
|
+
result = llm.invoke("Hello")
|
365
|
+
print(result.content)
|
366
|
+
|
367
|
+
result = llm.invoke("Hello")
|
368
|
+
print(result.content)
|
369
|
+
|
370
|
+
# choose the model explicitly
|
371
|
+
claude = llm.with_config(configurable={"model": "claude_3_7_sonnet"})
|
372
|
+
print(claude.invoke("Who are you?").content)
|
373
|
+
|
374
|
+
# batch (for multiple parallel requests)
|
375
|
+
results = llm.batch(["Hello", "World"] * 5)
|
376
|
+
print([r.content for r in results])
|
377
|
+
|
378
|
+
# handle stream events
|
379
|
+
async def stream_response():
|
380
|
+
separator = "---------------------------------------"
|
381
|
+
print(separator)
|
382
|
+
print("Starting stream")
|
383
|
+
print(separator)
|
384
|
+
async for event in llm.astream_events("Hello. Please respond in 5 sentences."):
|
385
|
+
data = event["data"]
|
386
|
+
if "chunk" in data:
|
387
|
+
print(data["chunk"].content)
|
388
|
+
elif "output" in data:
|
389
|
+
print(separator)
|
390
|
+
print("Final response")
|
391
|
+
print(separator)
|
392
|
+
print(data["output"].content)
|
393
|
+
else:
|
394
|
+
print("(other event) ", event.keys())
|
395
|
+
|
396
|
+
|
397
|
+
asyncio.run(stream_response())
|
398
|
+
from athena.client import Athena
|
399
|
+
client = Athena()
|
400
|
+
llm = client.llm
|
401
|
+
|
402
|
+
# sync invoke -- use strings or langchain messages
|
403
|
+
result = llm.invoke("Hello")
|
404
|
+
print(result.content)
|
405
|
+
|
406
|
+
# batch
|
407
|
+
results = llm.batch(["Hello", "World"])
|
408
|
+
print(r.content for r in results)
|
409
|
+
"""
|
410
|
+
|
411
|
+
def __init__(self, base_url: str, api_key: str, timeout: Optional[float] = None):
|
412
|
+
self.base_url = base_url
|
413
|
+
self.api_key = api_key
|
414
|
+
self.timeout = timeout
|
415
|
+
super().__init__(
|
416
|
+
base_url + "/api/models/general",
|
417
|
+
headers={"X-API-KEY": api_key},
|
418
|
+
timeout=timeout,
|
419
|
+
)
|
420
|
+
|
421
|
+
|
327
422
|
class Athena(BaseAthena):
|
328
423
|
"""
|
329
424
|
Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propogate to these functions.
|
@@ -360,37 +455,40 @@ class Athena(BaseAthena):
|
|
360
455
|
self,
|
361
456
|
*,
|
362
457
|
base_url: typing.Optional[str] = None,
|
363
|
-
environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT,
|
458
|
+
environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
|
364
459
|
api_key: typing.Optional[str] = None,
|
365
460
|
timeout: typing.Optional[float] = 60,
|
366
|
-
httpx_client: typing.Optional[httpx.Client] = None
|
461
|
+
httpx_client: typing.Optional[httpx.Client] = None,
|
367
462
|
):
|
368
463
|
if api_key is None:
|
369
464
|
try:
|
370
|
-
api_key = os.environ[
|
465
|
+
api_key = os.environ["ATHENA_API_KEY"]
|
371
466
|
except KeyError:
|
372
467
|
raise TypeError(
|
373
468
|
"Athena() missing 1 required keyword-only argument: 'api_key'"
|
374
|
-
|
469
|
+
" (ATHENA_API_KEY environment variable not found)"
|
375
470
|
)
|
376
471
|
if environment == SpecialEnvironments.AUTODETECT_ENVIRONMENT:
|
377
|
-
if
|
472
|
+
if "ATHENA_API_URL" in os.environ:
|
378
473
|
|
379
474
|
class AutodetectedEnvironments(enum.Enum):
|
380
|
-
CURRENT = os.environ[
|
475
|
+
CURRENT = os.environ["ATHENA_API_URL"]
|
381
476
|
|
382
477
|
environment = cast(AthenaEnvironment, AutodetectedEnvironments.CURRENT)
|
383
478
|
else:
|
384
479
|
environment = AthenaEnvironment.PRODUCTION
|
385
480
|
super().__init__(
|
386
481
|
base_url=base_url,
|
387
|
-
environment=environment,
|
482
|
+
environment=environment, # type: ignore[arg-type]
|
388
483
|
api_key=api_key,
|
389
484
|
timeout=timeout,
|
390
485
|
httpx_client=httpx_client,
|
391
486
|
)
|
392
487
|
self.tools = WrappedToolsClient(client_wrapper=self._client_wrapper)
|
393
488
|
self.query = WrappedQueryClient(client_wrapper=self._client_wrapper)
|
489
|
+
self.llm = AthenaModel(
|
490
|
+
base_url=base_url or environment.value, api_key=api_key, timeout=timeout
|
491
|
+
)
|
394
492
|
|
395
493
|
|
396
494
|
class AsyncAthena(AsyncBaseAthena):
|
@@ -430,31 +528,31 @@ class AsyncAthena(AsyncBaseAthena):
|
|
430
528
|
self,
|
431
529
|
*,
|
432
530
|
base_url: typing.Optional[str] = None,
|
433
|
-
environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT,
|
531
|
+
environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
|
434
532
|
api_key: typing.Optional[str] = None,
|
435
533
|
timeout: typing.Optional[float] = 60,
|
436
|
-
httpx_client: typing.Optional[httpx.AsyncClient] = None
|
534
|
+
httpx_client: typing.Optional[httpx.AsyncClient] = None,
|
437
535
|
):
|
438
536
|
if api_key is None:
|
439
537
|
try:
|
440
|
-
api_key = os.environ[
|
538
|
+
api_key = os.environ["ATHENA_API_KEY"]
|
441
539
|
except KeyError:
|
442
540
|
raise TypeError(
|
443
541
|
"AsyncAthena() missing 1 required keyword-only argument: 'api_key'"
|
444
|
-
|
542
|
+
" (ATHENA_API_KEY environment variable not found)"
|
445
543
|
)
|
446
544
|
if environment == SpecialEnvironments.AUTODETECT_ENVIRONMENT:
|
447
|
-
if
|
545
|
+
if "ATHENA_API_URL" in os.environ:
|
448
546
|
|
449
547
|
class AutodetectedEnvironments(enum.Enum):
|
450
|
-
CURRENT = os.environ[
|
548
|
+
CURRENT = os.environ["ATHENA_API_URL"]
|
451
549
|
|
452
550
|
environment = cast(AthenaEnvironment, AutodetectedEnvironments.CURRENT)
|
453
551
|
else:
|
454
552
|
environment = AthenaEnvironment.PRODUCTION
|
455
553
|
super().__init__(
|
456
554
|
base_url=base_url,
|
457
|
-
environment=environment,
|
555
|
+
environment=environment, # type: ignore[arg-type]
|
458
556
|
api_key=api_key,
|
459
557
|
timeout=timeout,
|
460
558
|
httpx_client=httpx_client,
|
@@ -463,19 +561,20 @@ class AsyncAthena(AsyncBaseAthena):
|
|
463
561
|
self.query = WrappedAsyncQueryClient(client_wrapper=self._client_wrapper)
|
464
562
|
|
465
563
|
|
466
|
-
def _read_json_frame(model: DataFrameRequestOut) ->
|
564
|
+
def _read_json_frame(model: DataFrameRequestOut) -> "pd.DataFrame":
|
467
565
|
import pandas as pd
|
468
566
|
|
469
567
|
string_io = io.StringIO(model.json())
|
470
568
|
|
471
569
|
with warnings.catch_warnings():
|
472
570
|
# Filter warnings due to https://github.com/pandas-dev/pandas/issues/59511
|
473
|
-
warnings.simplefilter(action=
|
474
|
-
return pd.read_json(string_io, orient=
|
571
|
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
572
|
+
return pd.read_json(string_io, orient="split")
|
475
573
|
|
476
574
|
|
477
575
|
def _check_pandas_installed():
|
478
576
|
import pandas
|
577
|
+
|
479
578
|
assert pandas
|
480
579
|
|
481
580
|
|
@@ -524,11 +623,12 @@ def _to_pandas_df(
|
|
524
623
|
def _convert_asset_object(
|
525
624
|
asset_object: Union["pd.DataFrame", "pd.Series", core.File],
|
526
625
|
name: Union[str, None] = None,
|
527
|
-
**kwargs
|
626
|
+
**kwargs,
|
528
627
|
) -> core.File:
|
529
628
|
import pandas as pd
|
629
|
+
|
530
630
|
try:
|
531
|
-
from IPython.core.formatters import format_display_data
|
631
|
+
from IPython.core.formatters import format_display_data # type: ignore[import]
|
532
632
|
except ImportError:
|
533
633
|
format_display_data = None
|
534
634
|
|
@@ -559,4 +659,4 @@ def _convert_asset_object(
|
|
559
659
|
image_bytes,
|
560
660
|
media_type,
|
561
661
|
)
|
562
|
-
return asset_object
|
662
|
+
return asset_object # type: ignore[return-value]
|