athena-intelligence 0.1.124__py3-none-any.whl → 0.1.126__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- athena/__init__.py +3 -0
- athena/agents/client.py +88 -36
- athena/agents/drive/client.py +80 -32
- athena/agents/general/client.py +222 -91
- athena/agents/research/client.py +80 -32
- athena/agents/sql/client.py +80 -32
- athena/base_client.py +13 -11
- athena/client.py +191 -80
- athena/core/__init__.py +21 -4
- athena/core/client_wrapper.py +9 -10
- athena/core/file.py +37 -8
- athena/core/http_client.py +97 -41
- athena/core/jsonable_encoder.py +33 -31
- athena/core/pydantic_utilities.py +272 -4
- athena/core/query_encoder.py +38 -13
- athena/core/request_options.py +5 -2
- athena/core/serialization.py +272 -0
- athena/errors/internal_server_error.py +2 -3
- athena/errors/unauthorized_error.py +2 -3
- athena/errors/unprocessable_entity_error.py +2 -3
- athena/query/client.py +208 -58
- athena/tools/calendar/client.py +82 -30
- athena/tools/client.py +576 -184
- athena/tools/email/client.py +117 -43
- athena/tools/structured_data_extractor/client.py +118 -67
- athena/tools/tasks/client.py +41 -17
- athena/types/asset_node.py +14 -24
- athena/types/asset_not_found_error.py +11 -21
- athena/types/chunk.py +11 -21
- athena/types/chunk_content_item.py +21 -41
- athena/types/chunk_result.py +13 -23
- athena/types/custom_agent_response.py +12 -22
- athena/types/data_frame_request_out.py +11 -21
- athena/types/data_frame_unknown_format_error.py +11 -21
- athena/types/document_chunk.py +12 -22
- athena/types/drive_agent_response.py +12 -22
- athena/types/file_chunk_request_out.py +11 -21
- athena/types/file_too_large_error.py +11 -21
- athena/types/folder_response.py +11 -21
- athena/types/general_agent_config.py +11 -21
- athena/types/general_agent_config_enabled_tools_item.py +0 -1
- athena/types/general_agent_request.py +13 -23
- athena/types/general_agent_response.py +12 -22
- athena/types/image_url_content.py +11 -21
- athena/types/parent_folder_error.py +11 -21
- athena/types/prompt_message.py +12 -22
- athena/types/research_agent_response.py +12 -22
- athena/types/save_asset_request_out.py +11 -21
- athena/types/sql_agent_response.py +13 -23
- athena/types/structured_data_extractor_response.py +15 -25
- athena/types/text_content.py +11 -21
- athena/types/tool.py +1 -13
- athena/types/type.py +1 -21
- athena/version.py +0 -1
- {athena_intelligence-0.1.124.dist-info → athena_intelligence-0.1.126.dist-info}/METADATA +12 -4
- athena_intelligence-0.1.126.dist-info/RECORD +87 -0
- {athena_intelligence-0.1.124.dist-info → athena_intelligence-0.1.126.dist-info}/WHEEL +1 -1
- athena_intelligence-0.1.124.dist-info/RECORD +0 -86
athena/client.py
CHANGED
@@ -7,8 +7,9 @@ import typing
|
|
7
7
|
import warnings
|
8
8
|
|
9
9
|
import httpx
|
10
|
-
from typing import cast, List, Union
|
10
|
+
from typing import cast, List, Tuple, Union, Optional
|
11
11
|
from typing_extensions import TypeVar, ParamSpec
|
12
|
+
from langserve import RemoteRunnable
|
12
13
|
|
13
14
|
from . import core
|
14
15
|
from .base_client import BaseAthena, AsyncBaseAthena
|
@@ -24,23 +25,23 @@ if typing.TYPE_CHECKING:
|
|
24
25
|
|
25
26
|
P = ParamSpec("P")
|
26
27
|
T = TypeVar("T")
|
27
|
-
U = TypeVar(
|
28
|
+
U = TypeVar("U")
|
28
29
|
|
29
30
|
|
30
31
|
def _inherit_signature_and_doc(
|
31
|
-
f: typing.Callable[P, T],
|
32
|
-
replace_in_doc: typing.Dict[str, str]
|
32
|
+
f: typing.Callable[P, T], replace_in_doc: typing.Dict[str, str]
|
33
33
|
) -> typing.Callable[..., typing.Callable[P, U]]:
|
34
34
|
def decorator(decorated):
|
35
35
|
for old, new in replace_in_doc.items():
|
36
36
|
assert old in f.__doc__
|
37
37
|
decorated.__doc__ = f.__doc__.replace(old, new)
|
38
38
|
return decorated
|
39
|
+
|
39
40
|
return decorator
|
40
41
|
|
41
42
|
|
42
43
|
class SpecialEnvironments(enum.Enum):
|
43
|
-
AUTODETECT_ENVIRONMENT =
|
44
|
+
AUTODETECT_ENVIRONMENT = "AUTO"
|
44
45
|
|
45
46
|
|
46
47
|
@dataclasses.dataclass
|
@@ -53,7 +54,7 @@ class AthenaAsset:
|
|
53
54
|
if self.media_type == "application/sql":
|
54
55
|
# it is safe to import IPython in `_repr_mimebundle_`
|
55
56
|
# as this is only intended to be invoked by IPython.
|
56
|
-
from IPython import display
|
57
|
+
from IPython import display # type: ignore[import]
|
57
58
|
|
58
59
|
code = display.Code(
|
59
60
|
data=self.data.decode(),
|
@@ -84,17 +85,19 @@ class WrappedToolsClient(ToolsClient):
|
|
84
85
|
file_io = client.tools.get_file(asset_id="asset_id")
|
85
86
|
pl.read_csv(file_io)
|
86
87
|
"""
|
87
|
-
file_bytes = b
|
88
|
+
file_bytes = b"".join(self.raw_data(asset_id=asset_id))
|
88
89
|
bytes_io = io.BytesIO(file_bytes)
|
89
90
|
return bytes_io
|
90
91
|
|
91
|
-
@_inherit_signature_and_doc(
|
92
|
-
|
92
|
+
@_inherit_signature_and_doc(
|
93
|
+
ToolsClient.data_frame, {"DataFrameRequestOut": "pd.DataFrame"}
|
94
|
+
)
|
95
|
+
def data_frame(self, *, asset_id: str, **kwargs) -> "pd.DataFrame":
|
93
96
|
_check_pandas_installed()
|
94
97
|
model = super().data_frame(asset_id=asset_id, **kwargs)
|
95
98
|
return _read_json_frame(model)
|
96
99
|
|
97
|
-
def read_data_frame(self, asset_id: str, *args, **kwargs) ->
|
100
|
+
def read_data_frame(self, asset_id: str, *args, **kwargs) -> "pd.DataFrame":
|
98
101
|
"""
|
99
102
|
Parameters
|
100
103
|
----------
|
@@ -116,11 +119,10 @@ class WrappedToolsClient(ToolsClient):
|
|
116
119
|
client.tools.read_data_frame(asset_id="asset_id")
|
117
120
|
"""
|
118
121
|
_check_pandas_installed()
|
119
|
-
file_bytes = self.
|
120
|
-
return _to_pandas_df(file_bytes, *args, **kwargs)
|
122
|
+
file_bytes, media_type = self._get_file_and_media_type(asset_id=asset_id)
|
123
|
+
return _to_pandas_df(file_bytes, *args, media_type=media_type, **kwargs)
|
121
124
|
|
122
|
-
|
123
|
-
def save_asset( # type: ignore[override]
|
125
|
+
def save_asset( # type: ignore[override]
|
124
126
|
self,
|
125
127
|
asset_object: Union["pd.DataFrame", "pd.Series", core.File],
|
126
128
|
*,
|
@@ -131,7 +133,7 @@ class WrappedToolsClient(ToolsClient):
|
|
131
133
|
"""
|
132
134
|
Parameters
|
133
135
|
----------
|
134
|
-
asset_object : pd.DataFrame | pd.Series | core.File
|
136
|
+
asset_object : pd.DataFrame | pd.Series | matplotlib.figure.Figure | core.File
|
135
137
|
A pandas data frame, series, matplotlib figure, or core.File
|
136
138
|
|
137
139
|
parent_folder_id : typing.Optional[str]
|
@@ -143,6 +145,8 @@ class WrappedToolsClient(ToolsClient):
|
|
143
145
|
request_options : typing.Optional[RequestOptions]
|
144
146
|
Request-specific configuration.
|
145
147
|
|
148
|
+
**kwargs : passed down to conversion methods
|
149
|
+
|
146
150
|
Returns
|
147
151
|
-------
|
148
152
|
SaveAssetRequestOut
|
@@ -155,27 +159,14 @@ class WrappedToolsClient(ToolsClient):
|
|
155
159
|
client = Athena(api_key="YOUR_API_KEY")
|
156
160
|
client.tools.save_asset(df)
|
157
161
|
"""
|
158
|
-
asset_object = _convert_asset_object(
|
159
|
-
|
160
|
-
file=asset_object, parent_folder_id=parent_folder_id, **kwargs
|
162
|
+
asset_object = _convert_asset_object(
|
163
|
+
asset_object=asset_object, name=name, **kwargs
|
161
164
|
)
|
165
|
+
return super().save_asset(file=asset_object, parent_folder_id=parent_folder_id)
|
162
166
|
|
163
|
-
def
|
167
|
+
def _get_file_and_media_type(self, asset_id: str) -> Tuple[io.BytesIO, str]:
|
164
168
|
"""
|
165
|
-
|
166
|
-
----------
|
167
|
-
asset_id : str
|
168
|
-
|
169
|
-
Returns
|
170
|
-
-------
|
171
|
-
pd.DataFrame or AthenaAsset
|
172
|
-
|
173
|
-
Examples
|
174
|
-
--------
|
175
|
-
from athena.client import Athena
|
176
|
-
|
177
|
-
client = Athena(api_key="YOUR_API_KEY")
|
178
|
-
client.tools.get_asset(asset_id="asset_id")
|
169
|
+
Gets the file togehter with media type returned by server
|
179
170
|
"""
|
180
171
|
# while we wait for https://github.com/fern-api/fern/issues/4316
|
181
172
|
result = self._client_wrapper.httpx_client.request(
|
@@ -196,6 +187,27 @@ class WrappedToolsClient(ToolsClient):
|
|
196
187
|
# fallback to `libmagic` inference
|
197
188
|
media_type = _infer_media_type(bytes_io=file_bytes)
|
198
189
|
|
190
|
+
return file_bytes, media_type
|
191
|
+
|
192
|
+
def get_asset(self, asset_id: str) -> Union["pd.DataFrame", AthenaAsset]:
|
193
|
+
"""
|
194
|
+
Parameters
|
195
|
+
----------
|
196
|
+
asset_id : str
|
197
|
+
|
198
|
+
Returns
|
199
|
+
-------
|
200
|
+
pd.DataFrame or AthenaAsset
|
201
|
+
|
202
|
+
Examples
|
203
|
+
--------
|
204
|
+
from athena.client import Athena
|
205
|
+
|
206
|
+
client = Athena(api_key="YOUR_API_KEY")
|
207
|
+
client.tools.get_asset(asset_id="asset_id")
|
208
|
+
"""
|
209
|
+
file_bytes, media_type = self._get_file_and_media_type(asset_id=asset_id)
|
210
|
+
|
199
211
|
media_type_aliases = {"image/jpg": "image/jpeg"}
|
200
212
|
media_type = media_type_aliases.get(media_type, media_type)
|
201
213
|
|
@@ -226,22 +238,26 @@ class WrappedToolsClient(ToolsClient):
|
|
226
238
|
return _to_pandas_df(file_bytes, media_type=media_type)
|
227
239
|
|
228
240
|
raise NotImplementedError("Assets of `{media_type}` type are not yet supported")
|
229
|
-
|
241
|
+
|
230
242
|
|
231
243
|
class WrappedQueryClient(QueryClient):
|
232
244
|
|
233
|
-
@_inherit_signature_and_doc(
|
234
|
-
|
245
|
+
@_inherit_signature_and_doc(
|
246
|
+
QueryClient.execute, {"DataFrameRequestOut": "pd.DataFrame"}
|
247
|
+
)
|
248
|
+
def execute(
|
249
|
+
self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs
|
250
|
+
) -> "pd.DataFrame":
|
235
251
|
_check_pandas_installed()
|
236
252
|
model = super().execute(
|
237
|
-
sql_command=sql_command,
|
238
|
-
database_asset_ids=database_asset_ids,
|
239
|
-
**kwargs
|
253
|
+
sql_command=sql_command, database_asset_ids=database_asset_ids, **kwargs
|
240
254
|
)
|
241
255
|
return _read_json_frame(model)
|
242
256
|
|
243
|
-
@_inherit_signature_and_doc(
|
244
|
-
|
257
|
+
@_inherit_signature_and_doc(
|
258
|
+
QueryClient.execute_snippet, {"DataFrameRequestOut": "pd.DataFrame"}
|
259
|
+
)
|
260
|
+
def execute_snippet(self, *, snippet_asset_id: str, **kwargs) -> "pd.DataFrame":
|
245
261
|
_check_pandas_installed()
|
246
262
|
model = super().execute_snippet(snippet_asset_id=snippet_asset_id, **kwargs)
|
247
263
|
return _read_json_frame(model)
|
@@ -251,12 +267,11 @@ def _add_docs_for_async_variant(obj):
|
|
251
267
|
def decorator(decorated):
|
252
268
|
doc = obj.__doc__
|
253
269
|
name = obj.__name__
|
254
|
-
decorated.__doc__ = (
|
255
|
-
|
256
|
-
|
257
|
-
.replace(f'client.tools.{name}', f'await client.tools.{name}')
|
258
|
-
)
|
270
|
+
decorated.__doc__ = doc.replace(
|
271
|
+
"client = Athena", "client = AsyncAthena"
|
272
|
+
).replace(f"client.tools.{name}", f"await client.tools.{name}")
|
259
273
|
return decorated
|
274
|
+
|
260
275
|
return decorator
|
261
276
|
|
262
277
|
|
@@ -264,18 +279,20 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
|
|
264
279
|
|
265
280
|
@_add_docs_for_async_variant(WrappedToolsClient.get_file)
|
266
281
|
async def get_file(self, asset_id: str) -> io.BytesIO:
|
267
|
-
file_bytes = b
|
282
|
+
file_bytes = b"".join([gen async for gen in self.raw_data(asset_id=asset_id)])
|
268
283
|
bytes_io = io.BytesIO(file_bytes)
|
269
284
|
return bytes_io
|
270
285
|
|
271
|
-
@_inherit_signature_and_doc(
|
272
|
-
|
286
|
+
@_inherit_signature_and_doc(
|
287
|
+
AsyncToolsClient.data_frame, {"DataFrameRequestOut": "pd.DataFrame"}
|
288
|
+
)
|
289
|
+
async def data_frame(self, *, asset_id: str, **kwargs) -> "pd.DataFrame":
|
273
290
|
_check_pandas_installed()
|
274
291
|
model = await super().data_frame(asset_id=asset_id, **kwargs)
|
275
292
|
return _read_json_frame(model)
|
276
293
|
|
277
294
|
@_add_docs_for_async_variant(WrappedToolsClient.read_data_frame)
|
278
|
-
async def read_data_frame(self, asset_id: str, *args, **kwargs) ->
|
295
|
+
async def read_data_frame(self, asset_id: str, *args, **kwargs) -> "pd.DataFrame":
|
279
296
|
_check_pandas_installed()
|
280
297
|
file_bytes = await self.get_file(asset_id)
|
281
298
|
return _to_pandas_df(file_bytes, *args, **kwargs)
|
@@ -289,31 +306,119 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
|
|
289
306
|
name: Union[str, None] = None,
|
290
307
|
**kwargs,
|
291
308
|
) -> SaveAssetRequestOut:
|
292
|
-
asset_object = _convert_asset_object(
|
309
|
+
asset_object = _convert_asset_object(
|
310
|
+
asset_object=asset_object, name=name, **kwargs
|
311
|
+
)
|
293
312
|
return await super().save_asset(
|
294
|
-
file=asset_object, parent_folder_id=parent_folder_id
|
313
|
+
file=asset_object, parent_folder_id=parent_folder_id
|
295
314
|
)
|
296
315
|
|
297
316
|
|
298
317
|
class WrappedAsyncQueryClient(AsyncQueryClient):
|
299
318
|
|
300
|
-
@_inherit_signature_and_doc(
|
301
|
-
|
319
|
+
@_inherit_signature_and_doc(
|
320
|
+
AsyncQueryClient.execute, {"DataFrameRequestOut": "pd.DataFrame"}
|
321
|
+
)
|
322
|
+
async def execute(
|
323
|
+
self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs
|
324
|
+
) -> "pd.DataFrame":
|
302
325
|
_check_pandas_installed()
|
303
326
|
model = await super().execute(
|
304
|
-
sql_command=sql_command,
|
305
|
-
database_asset_ids=database_asset_ids,
|
306
|
-
**kwargs
|
327
|
+
sql_command=sql_command, database_asset_ids=database_asset_ids, **kwargs
|
307
328
|
)
|
308
329
|
return _read_json_frame(model)
|
309
330
|
|
310
|
-
@_inherit_signature_and_doc(
|
311
|
-
|
331
|
+
@_inherit_signature_and_doc(
|
332
|
+
AsyncQueryClient.execute_snippet, {"DataFrameRequestOut": "pd.DataFrame"}
|
333
|
+
)
|
334
|
+
async def execute_snippet(
|
335
|
+
self, *, snippet_asset_id: str, **kwargs
|
336
|
+
) -> "pd.DataFrame":
|
312
337
|
_check_pandas_installed()
|
313
|
-
model = await super().execute_snippet(
|
338
|
+
model = await super().execute_snippet(
|
339
|
+
snippet_asset_id=snippet_asset_id, **kwargs
|
340
|
+
)
|
314
341
|
return _read_json_frame(model)
|
315
342
|
|
316
343
|
|
344
|
+
class AthenaModel(RemoteRunnable):
|
345
|
+
"""Use Athena's models directly with a Langchain-compatible client.
|
346
|
+
|
347
|
+
The Langchain Runnable interface is supported:
|
348
|
+
- `invoke`: Invoke the model with a string or a Langchain message.
|
349
|
+
- `batch`: Batch invoke the model on multiple inputs.
|
350
|
+
- `astream_events`: Streaming for real-time applications.
|
351
|
+
|
352
|
+
See Langchain documentation for more details.
|
353
|
+
|
354
|
+
Examples
|
355
|
+
--------
|
356
|
+
from src.athena.client import Athena
|
357
|
+
import asyncio
|
358
|
+
from langchain_core.messages import HumanMessage
|
359
|
+
|
360
|
+
client = Athena()
|
361
|
+
llm = client.llm
|
362
|
+
|
363
|
+
# sync invoke -- use strings or langchain messages
|
364
|
+
result = llm.invoke("Hello")
|
365
|
+
print(result.content)
|
366
|
+
|
367
|
+
result = llm.invoke("Hello")
|
368
|
+
print(result.content)
|
369
|
+
|
370
|
+
# choose the model explicitly
|
371
|
+
claude = llm.with_config(configurable={"model": "claude_3_7_sonnet"})
|
372
|
+
print(claude.invoke("Who are you?").content)
|
373
|
+
|
374
|
+
# batch (for multiple parallel requests)
|
375
|
+
results = llm.batch(["Hello", "World"] * 5)
|
376
|
+
print([r.content for r in results])
|
377
|
+
|
378
|
+
# handle stream events
|
379
|
+
async def stream_response():
|
380
|
+
separator = "---------------------------------------"
|
381
|
+
print(separator)
|
382
|
+
print("Starting stream")
|
383
|
+
print(separator)
|
384
|
+
async for event in llm.astream_events("Hello. Please respond in 5 sentences."):
|
385
|
+
data = event["data"]
|
386
|
+
if "chunk" in data:
|
387
|
+
print(data["chunk"].content)
|
388
|
+
elif "output" in data:
|
389
|
+
print(separator)
|
390
|
+
print("Final response")
|
391
|
+
print(separator)
|
392
|
+
print(data["output"].content)
|
393
|
+
else:
|
394
|
+
print("(other event) ", event.keys())
|
395
|
+
|
396
|
+
|
397
|
+
asyncio.run(stream_response())
|
398
|
+
from athena.client import Athena
|
399
|
+
client = Athena()
|
400
|
+
llm = client.llm
|
401
|
+
|
402
|
+
# sync invoke -- use strings or langchain messages
|
403
|
+
result = llm.invoke("Hello")
|
404
|
+
print(result.content)
|
405
|
+
|
406
|
+
# batch
|
407
|
+
results = llm.batch(["Hello", "World"])
|
408
|
+
print(r.content for r in results)
|
409
|
+
"""
|
410
|
+
|
411
|
+
def __init__(self, base_url: str, api_key: str, timeout: Optional[float] = None):
|
412
|
+
self.base_url = base_url
|
413
|
+
self.api_key = api_key
|
414
|
+
self.timeout = timeout
|
415
|
+
super().__init__(
|
416
|
+
base_url + "/api/models/general",
|
417
|
+
headers={"X-API-KEY": api_key},
|
418
|
+
timeout=timeout,
|
419
|
+
)
|
420
|
+
|
421
|
+
|
317
422
|
class Athena(BaseAthena):
|
318
423
|
"""
|
319
424
|
Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propogate to these functions.
|
@@ -350,37 +455,40 @@ class Athena(BaseAthena):
|
|
350
455
|
self,
|
351
456
|
*,
|
352
457
|
base_url: typing.Optional[str] = None,
|
353
|
-
environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT,
|
458
|
+
environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
|
354
459
|
api_key: typing.Optional[str] = None,
|
355
460
|
timeout: typing.Optional[float] = 60,
|
356
|
-
httpx_client: typing.Optional[httpx.Client] = None
|
461
|
+
httpx_client: typing.Optional[httpx.Client] = None,
|
357
462
|
):
|
358
463
|
if api_key is None:
|
359
464
|
try:
|
360
|
-
api_key = os.environ[
|
465
|
+
api_key = os.environ["ATHENA_API_KEY"]
|
361
466
|
except KeyError:
|
362
467
|
raise TypeError(
|
363
468
|
"Athena() missing 1 required keyword-only argument: 'api_key'"
|
364
|
-
|
469
|
+
" (ATHENA_API_KEY environment variable not found)"
|
365
470
|
)
|
366
471
|
if environment == SpecialEnvironments.AUTODETECT_ENVIRONMENT:
|
367
|
-
if
|
472
|
+
if "ATHENA_API_URL" in os.environ:
|
368
473
|
|
369
474
|
class AutodetectedEnvironments(enum.Enum):
|
370
|
-
CURRENT = os.environ[
|
475
|
+
CURRENT = os.environ["ATHENA_API_URL"]
|
371
476
|
|
372
477
|
environment = cast(AthenaEnvironment, AutodetectedEnvironments.CURRENT)
|
373
478
|
else:
|
374
479
|
environment = AthenaEnvironment.PRODUCTION
|
375
480
|
super().__init__(
|
376
481
|
base_url=base_url,
|
377
|
-
environment=environment,
|
482
|
+
environment=environment, # type: ignore[arg-type]
|
378
483
|
api_key=api_key,
|
379
484
|
timeout=timeout,
|
380
485
|
httpx_client=httpx_client,
|
381
486
|
)
|
382
487
|
self.tools = WrappedToolsClient(client_wrapper=self._client_wrapper)
|
383
488
|
self.query = WrappedQueryClient(client_wrapper=self._client_wrapper)
|
489
|
+
self.llm = AthenaModel(
|
490
|
+
base_url=base_url or environment.value, api_key=api_key, timeout=timeout
|
491
|
+
)
|
384
492
|
|
385
493
|
|
386
494
|
class AsyncAthena(AsyncBaseAthena):
|
@@ -420,31 +528,31 @@ class AsyncAthena(AsyncBaseAthena):
|
|
420
528
|
self,
|
421
529
|
*,
|
422
530
|
base_url: typing.Optional[str] = None,
|
423
|
-
environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT,
|
531
|
+
environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
|
424
532
|
api_key: typing.Optional[str] = None,
|
425
533
|
timeout: typing.Optional[float] = 60,
|
426
|
-
httpx_client: typing.Optional[httpx.AsyncClient] = None
|
534
|
+
httpx_client: typing.Optional[httpx.AsyncClient] = None,
|
427
535
|
):
|
428
536
|
if api_key is None:
|
429
537
|
try:
|
430
|
-
api_key = os.environ[
|
538
|
+
api_key = os.environ["ATHENA_API_KEY"]
|
431
539
|
except KeyError:
|
432
540
|
raise TypeError(
|
433
541
|
"AsyncAthena() missing 1 required keyword-only argument: 'api_key'"
|
434
|
-
|
542
|
+
" (ATHENA_API_KEY environment variable not found)"
|
435
543
|
)
|
436
544
|
if environment == SpecialEnvironments.AUTODETECT_ENVIRONMENT:
|
437
|
-
if
|
545
|
+
if "ATHENA_API_URL" in os.environ:
|
438
546
|
|
439
547
|
class AutodetectedEnvironments(enum.Enum):
|
440
|
-
CURRENT = os.environ[
|
548
|
+
CURRENT = os.environ["ATHENA_API_URL"]
|
441
549
|
|
442
550
|
environment = cast(AthenaEnvironment, AutodetectedEnvironments.CURRENT)
|
443
551
|
else:
|
444
552
|
environment = AthenaEnvironment.PRODUCTION
|
445
553
|
super().__init__(
|
446
554
|
base_url=base_url,
|
447
|
-
environment=environment,
|
555
|
+
environment=environment, # type: ignore[arg-type]
|
448
556
|
api_key=api_key,
|
449
557
|
timeout=timeout,
|
450
558
|
httpx_client=httpx_client,
|
@@ -453,19 +561,20 @@ class AsyncAthena(AsyncBaseAthena):
|
|
453
561
|
self.query = WrappedAsyncQueryClient(client_wrapper=self._client_wrapper)
|
454
562
|
|
455
563
|
|
456
|
-
def _read_json_frame(model: DataFrameRequestOut) ->
|
564
|
+
def _read_json_frame(model: DataFrameRequestOut) -> "pd.DataFrame":
|
457
565
|
import pandas as pd
|
458
566
|
|
459
567
|
string_io = io.StringIO(model.json())
|
460
568
|
|
461
569
|
with warnings.catch_warnings():
|
462
570
|
# Filter warnings due to https://github.com/pandas-dev/pandas/issues/59511
|
463
|
-
warnings.simplefilter(action=
|
464
|
-
return pd.read_json(string_io, orient=
|
571
|
+
warnings.simplefilter(action="ignore", category=FutureWarning)
|
572
|
+
return pd.read_json(string_io, orient="split")
|
465
573
|
|
466
574
|
|
467
575
|
def _check_pandas_installed():
|
468
576
|
import pandas
|
577
|
+
|
469
578
|
assert pandas
|
470
579
|
|
471
580
|
|
@@ -514,10 +623,12 @@ def _to_pandas_df(
|
|
514
623
|
def _convert_asset_object(
|
515
624
|
asset_object: Union["pd.DataFrame", "pd.Series", core.File],
|
516
625
|
name: Union[str, None] = None,
|
626
|
+
**kwargs,
|
517
627
|
) -> core.File:
|
518
628
|
import pandas as pd
|
629
|
+
|
519
630
|
try:
|
520
|
-
from IPython.core.formatters import format_display_data
|
631
|
+
from IPython.core.formatters import format_display_data # type: ignore[import]
|
521
632
|
except ImportError:
|
522
633
|
format_display_data = None
|
523
634
|
|
@@ -526,7 +637,7 @@ def _convert_asset_object(
|
|
526
637
|
if isinstance(asset_object, pd.DataFrame):
|
527
638
|
return (
|
528
639
|
name or "Uploaded data frame",
|
529
|
-
asset_object.to_parquet(path=None),
|
640
|
+
asset_object.to_parquet(path=None, **kwargs),
|
530
641
|
"application/vnd.apache.parquet",
|
531
642
|
)
|
532
643
|
if format_display_data:
|
@@ -548,4 +659,4 @@ def _convert_asset_object(
|
|
548
659
|
image_bytes,
|
549
660
|
media_type,
|
550
661
|
)
|
551
|
-
return asset_object
|
662
|
+
return asset_object # type: ignore[return-value]
|
athena/core/__init__.py
CHANGED
@@ -3,28 +3,45 @@
|
|
3
3
|
from .api_error import ApiError
|
4
4
|
from .client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
|
5
5
|
from .datetime_utils import serialize_datetime
|
6
|
-
from .file import File, convert_file_dict_to_httpx_tuples
|
6
|
+
from .file import File, convert_file_dict_to_httpx_tuples, with_content_type
|
7
7
|
from .http_client import AsyncHttpClient, HttpClient
|
8
8
|
from .jsonable_encoder import jsonable_encoder
|
9
|
-
from .pydantic_utilities import
|
9
|
+
from .pydantic_utilities import (
|
10
|
+
IS_PYDANTIC_V2,
|
11
|
+
UniversalBaseModel,
|
12
|
+
UniversalRootModel,
|
13
|
+
parse_obj_as,
|
14
|
+
universal_field_validator,
|
15
|
+
universal_root_validator,
|
16
|
+
update_forward_refs,
|
17
|
+
)
|
10
18
|
from .query_encoder import encode_query
|
11
19
|
from .remove_none_from_dict import remove_none_from_dict
|
12
20
|
from .request_options import RequestOptions
|
21
|
+
from .serialization import FieldMetadata, convert_and_respect_annotation_metadata
|
13
22
|
|
14
23
|
__all__ = [
|
15
24
|
"ApiError",
|
16
25
|
"AsyncClientWrapper",
|
17
26
|
"AsyncHttpClient",
|
18
27
|
"BaseClientWrapper",
|
28
|
+
"FieldMetadata",
|
19
29
|
"File",
|
20
30
|
"HttpClient",
|
31
|
+
"IS_PYDANTIC_V2",
|
21
32
|
"RequestOptions",
|
22
33
|
"SyncClientWrapper",
|
34
|
+
"UniversalBaseModel",
|
35
|
+
"UniversalRootModel",
|
36
|
+
"convert_and_respect_annotation_metadata",
|
23
37
|
"convert_file_dict_to_httpx_tuples",
|
24
|
-
"deep_union_pydantic_dicts",
|
25
38
|
"encode_query",
|
26
39
|
"jsonable_encoder",
|
27
|
-
"
|
40
|
+
"parse_obj_as",
|
28
41
|
"remove_none_from_dict",
|
29
42
|
"serialize_datetime",
|
43
|
+
"universal_field_validator",
|
44
|
+
"universal_root_validator",
|
45
|
+
"update_forward_refs",
|
46
|
+
"with_content_type",
|
30
47
|
]
|
athena/core/client_wrapper.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
3
|
import typing
|
4
|
-
|
5
4
|
import httpx
|
6
|
-
|
7
|
-
from .http_client import AsyncHttpClient
|
5
|
+
from .http_client import HttpClient
|
6
|
+
from .http_client import AsyncHttpClient
|
8
7
|
|
9
8
|
|
10
9
|
class BaseClientWrapper:
|
@@ -17,7 +16,7 @@ class BaseClientWrapper:
|
|
17
16
|
headers: typing.Dict[str, str] = {
|
18
17
|
"X-Fern-Language": "Python",
|
19
18
|
"X-Fern-SDK-Name": "athena-intelligence",
|
20
|
-
"X-Fern-SDK-Version": "0.1.
|
19
|
+
"X-Fern-SDK-Version": "0.1.126",
|
21
20
|
}
|
22
21
|
headers["X-API-KEY"] = self.api_key
|
23
22
|
return headers
|
@@ -36,9 +35,9 @@ class SyncClientWrapper(BaseClientWrapper):
|
|
36
35
|
super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
|
37
36
|
self.httpx_client = HttpClient(
|
38
37
|
httpx_client=httpx_client,
|
39
|
-
base_headers=self.get_headers
|
40
|
-
base_timeout=self.get_timeout
|
41
|
-
base_url=self.get_base_url
|
38
|
+
base_headers=self.get_headers,
|
39
|
+
base_timeout=self.get_timeout,
|
40
|
+
base_url=self.get_base_url,
|
42
41
|
)
|
43
42
|
|
44
43
|
|
@@ -49,7 +48,7 @@ class AsyncClientWrapper(BaseClientWrapper):
|
|
49
48
|
super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
|
50
49
|
self.httpx_client = AsyncHttpClient(
|
51
50
|
httpx_client=httpx_client,
|
52
|
-
base_headers=self.get_headers
|
53
|
-
base_timeout=self.get_timeout
|
54
|
-
base_url=self.get_base_url
|
51
|
+
base_headers=self.get_headers,
|
52
|
+
base_timeout=self.get_timeout,
|
53
|
+
base_url=self.get_base_url,
|
55
54
|
)
|
athena/core/file.py
CHANGED
@@ -1,25 +1,30 @@
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
2
2
|
|
3
|
-
import
|
3
|
+
from typing import IO, Dict, List, Mapping, Optional, Tuple, Union, cast
|
4
4
|
|
5
5
|
# File typing inspired by the flexibility of types within the httpx library
|
6
6
|
# https://github.com/encode/httpx/blob/master/httpx/_types.py
|
7
|
-
FileContent =
|
8
|
-
File =
|
7
|
+
FileContent = Union[IO[bytes], bytes, str]
|
8
|
+
File = Union[
|
9
9
|
# file (or bytes)
|
10
10
|
FileContent,
|
11
11
|
# (filename, file (or bytes))
|
12
|
-
|
12
|
+
Tuple[Optional[str], FileContent],
|
13
13
|
# (filename, file (or bytes), content_type)
|
14
|
-
|
14
|
+
Tuple[Optional[str], FileContent, Optional[str]],
|
15
15
|
# (filename, file (or bytes), content_type, headers)
|
16
|
-
|
16
|
+
Tuple[
|
17
|
+
Optional[str],
|
18
|
+
FileContent,
|
19
|
+
Optional[str],
|
20
|
+
Mapping[str, str],
|
21
|
+
],
|
17
22
|
]
|
18
23
|
|
19
24
|
|
20
25
|
def convert_file_dict_to_httpx_tuples(
|
21
|
-
d:
|
22
|
-
) ->
|
26
|
+
d: Dict[str, Union[File, List[File]]],
|
27
|
+
) -> List[Tuple[str, File]]:
|
23
28
|
"""
|
24
29
|
The format we use is a list of tuples, where the first element is the
|
25
30
|
name of the file and the second is the file object. Typically HTTPX wants
|
@@ -36,3 +41,27 @@ def convert_file_dict_to_httpx_tuples(
|
|
36
41
|
else:
|
37
42
|
httpx_tuples.append((key, file_like))
|
38
43
|
return httpx_tuples
|
44
|
+
|
45
|
+
|
46
|
+
def with_content_type(*, file: File, default_content_type: str) -> File:
|
47
|
+
"""
|
48
|
+
This function resolves to the file's content type, if provided, and defaults
|
49
|
+
to the default_content_type value if not.
|
50
|
+
"""
|
51
|
+
if isinstance(file, tuple):
|
52
|
+
if len(file) == 2:
|
53
|
+
filename, content = cast(Tuple[Optional[str], FileContent], file) # type: ignore
|
54
|
+
return (filename, content, default_content_type)
|
55
|
+
elif len(file) == 3:
|
56
|
+
filename, content, file_content_type = cast(Tuple[Optional[str], FileContent, Optional[str]], file) # type: ignore
|
57
|
+
out_content_type = file_content_type or default_content_type
|
58
|
+
return (filename, content, out_content_type)
|
59
|
+
elif len(file) == 4:
|
60
|
+
filename, content, file_content_type, headers = cast( # type: ignore
|
61
|
+
Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], file
|
62
|
+
)
|
63
|
+
out_content_type = file_content_type or default_content_type
|
64
|
+
return (filename, content, out_content_type, headers)
|
65
|
+
else:
|
66
|
+
raise ValueError(f"Unexpected tuple length: {len(file)}")
|
67
|
+
return (None, file, default_content_type)
|