athena-intelligence 0.1.124__py3-none-any.whl → 0.1.126__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. athena/__init__.py +3 -0
  2. athena/agents/client.py +88 -36
  3. athena/agents/drive/client.py +80 -32
  4. athena/agents/general/client.py +222 -91
  5. athena/agents/research/client.py +80 -32
  6. athena/agents/sql/client.py +80 -32
  7. athena/base_client.py +13 -11
  8. athena/client.py +191 -80
  9. athena/core/__init__.py +21 -4
  10. athena/core/client_wrapper.py +9 -10
  11. athena/core/file.py +37 -8
  12. athena/core/http_client.py +97 -41
  13. athena/core/jsonable_encoder.py +33 -31
  14. athena/core/pydantic_utilities.py +272 -4
  15. athena/core/query_encoder.py +38 -13
  16. athena/core/request_options.py +5 -2
  17. athena/core/serialization.py +272 -0
  18. athena/errors/internal_server_error.py +2 -3
  19. athena/errors/unauthorized_error.py +2 -3
  20. athena/errors/unprocessable_entity_error.py +2 -3
  21. athena/query/client.py +208 -58
  22. athena/tools/calendar/client.py +82 -30
  23. athena/tools/client.py +576 -184
  24. athena/tools/email/client.py +117 -43
  25. athena/tools/structured_data_extractor/client.py +118 -67
  26. athena/tools/tasks/client.py +41 -17
  27. athena/types/asset_node.py +14 -24
  28. athena/types/asset_not_found_error.py +11 -21
  29. athena/types/chunk.py +11 -21
  30. athena/types/chunk_content_item.py +21 -41
  31. athena/types/chunk_result.py +13 -23
  32. athena/types/custom_agent_response.py +12 -22
  33. athena/types/data_frame_request_out.py +11 -21
  34. athena/types/data_frame_unknown_format_error.py +11 -21
  35. athena/types/document_chunk.py +12 -22
  36. athena/types/drive_agent_response.py +12 -22
  37. athena/types/file_chunk_request_out.py +11 -21
  38. athena/types/file_too_large_error.py +11 -21
  39. athena/types/folder_response.py +11 -21
  40. athena/types/general_agent_config.py +11 -21
  41. athena/types/general_agent_config_enabled_tools_item.py +0 -1
  42. athena/types/general_agent_request.py +13 -23
  43. athena/types/general_agent_response.py +12 -22
  44. athena/types/image_url_content.py +11 -21
  45. athena/types/parent_folder_error.py +11 -21
  46. athena/types/prompt_message.py +12 -22
  47. athena/types/research_agent_response.py +12 -22
  48. athena/types/save_asset_request_out.py +11 -21
  49. athena/types/sql_agent_response.py +13 -23
  50. athena/types/structured_data_extractor_response.py +15 -25
  51. athena/types/text_content.py +11 -21
  52. athena/types/tool.py +1 -13
  53. athena/types/type.py +1 -21
  54. athena/version.py +0 -1
  55. {athena_intelligence-0.1.124.dist-info → athena_intelligence-0.1.126.dist-info}/METADATA +12 -4
  56. athena_intelligence-0.1.126.dist-info/RECORD +87 -0
  57. {athena_intelligence-0.1.124.dist-info → athena_intelligence-0.1.126.dist-info}/WHEEL +1 -1
  58. athena_intelligence-0.1.124.dist-info/RECORD +0 -86
athena/client.py CHANGED
@@ -7,8 +7,9 @@ import typing
7
7
  import warnings
8
8
 
9
9
  import httpx
10
- from typing import cast, List, Union
10
+ from typing import cast, List, Tuple, Union, Optional
11
11
  from typing_extensions import TypeVar, ParamSpec
12
+ from langserve import RemoteRunnable
12
13
 
13
14
  from . import core
14
15
  from .base_client import BaseAthena, AsyncBaseAthena
@@ -24,23 +25,23 @@ if typing.TYPE_CHECKING:
24
25
 
25
26
  P = ParamSpec("P")
26
27
  T = TypeVar("T")
27
- U = TypeVar('U')
28
+ U = TypeVar("U")
28
29
 
29
30
 
30
31
  def _inherit_signature_and_doc(
31
- f: typing.Callable[P, T],
32
- replace_in_doc: typing.Dict[str, str]
32
+ f: typing.Callable[P, T], replace_in_doc: typing.Dict[str, str]
33
33
  ) -> typing.Callable[..., typing.Callable[P, U]]:
34
34
  def decorator(decorated):
35
35
  for old, new in replace_in_doc.items():
36
36
  assert old in f.__doc__
37
37
  decorated.__doc__ = f.__doc__.replace(old, new)
38
38
  return decorated
39
+
39
40
  return decorator
40
41
 
41
42
 
42
43
  class SpecialEnvironments(enum.Enum):
43
- AUTODETECT_ENVIRONMENT = 'AUTO'
44
+ AUTODETECT_ENVIRONMENT = "AUTO"
44
45
 
45
46
 
46
47
  @dataclasses.dataclass
@@ -53,7 +54,7 @@ class AthenaAsset:
53
54
  if self.media_type == "application/sql":
54
55
  # it is safe to import IPython in `_repr_mimebundle_`
55
56
  # as this is only intended to be invoked by IPython.
56
- from IPython import display # type: ignore[import]
57
+ from IPython import display # type: ignore[import]
57
58
 
58
59
  code = display.Code(
59
60
  data=self.data.decode(),
@@ -84,17 +85,19 @@ class WrappedToolsClient(ToolsClient):
84
85
  file_io = client.tools.get_file(asset_id="asset_id")
85
86
  pl.read_csv(file_io)
86
87
  """
87
- file_bytes = b''.join(self.raw_data(asset_id=asset_id))
88
+ file_bytes = b"".join(self.raw_data(asset_id=asset_id))
88
89
  bytes_io = io.BytesIO(file_bytes)
89
90
  return bytes_io
90
91
 
91
- @_inherit_signature_and_doc(ToolsClient.data_frame, {'DataFrameRequestOut': 'pd.DataFrame'})
92
- def data_frame(self, *, asset_id: str, **kwargs) -> 'pd.DataFrame':
92
+ @_inherit_signature_and_doc(
93
+ ToolsClient.data_frame, {"DataFrameRequestOut": "pd.DataFrame"}
94
+ )
95
+ def data_frame(self, *, asset_id: str, **kwargs) -> "pd.DataFrame":
93
96
  _check_pandas_installed()
94
97
  model = super().data_frame(asset_id=asset_id, **kwargs)
95
98
  return _read_json_frame(model)
96
99
 
97
- def read_data_frame(self, asset_id: str, *args, **kwargs) -> 'pd.DataFrame':
100
+ def read_data_frame(self, asset_id: str, *args, **kwargs) -> "pd.DataFrame":
98
101
  """
99
102
  Parameters
100
103
  ----------
@@ -116,11 +119,10 @@ class WrappedToolsClient(ToolsClient):
116
119
  client.tools.read_data_frame(asset_id="asset_id")
117
120
  """
118
121
  _check_pandas_installed()
119
- file_bytes = self.get_file(asset_id)
120
- return _to_pandas_df(file_bytes, *args, **kwargs)
122
+ file_bytes, media_type = self._get_file_and_media_type(asset_id=asset_id)
123
+ return _to_pandas_df(file_bytes, *args, media_type=media_type, **kwargs)
121
124
 
122
-
123
- def save_asset( # type: ignore[override]
125
+ def save_asset( # type: ignore[override]
124
126
  self,
125
127
  asset_object: Union["pd.DataFrame", "pd.Series", core.File],
126
128
  *,
@@ -131,7 +133,7 @@ class WrappedToolsClient(ToolsClient):
131
133
  """
132
134
  Parameters
133
135
  ----------
134
- asset_object : pd.DataFrame | pd.Series | core.File
136
+ asset_object : pd.DataFrame | pd.Series | matplotlib.figure.Figure | core.File
135
137
  A pandas data frame, series, matplotlib figure, or core.File
136
138
 
137
139
  parent_folder_id : typing.Optional[str]
@@ -143,6 +145,8 @@ class WrappedToolsClient(ToolsClient):
143
145
  request_options : typing.Optional[RequestOptions]
144
146
  Request-specific configuration.
145
147
 
148
+ **kwargs : passed down to conversion methods
149
+
146
150
  Returns
147
151
  -------
148
152
  SaveAssetRequestOut
@@ -155,27 +159,14 @@ class WrappedToolsClient(ToolsClient):
155
159
  client = Athena(api_key="YOUR_API_KEY")
156
160
  client.tools.save_asset(df)
157
161
  """
158
- asset_object = _convert_asset_object(asset_object=asset_object, name=name)
159
- return super().save_asset(
160
- file=asset_object, parent_folder_id=parent_folder_id, **kwargs
162
+ asset_object = _convert_asset_object(
163
+ asset_object=asset_object, name=name, **kwargs
161
164
  )
165
+ return super().save_asset(file=asset_object, parent_folder_id=parent_folder_id)
162
166
 
163
- def get_asset(self, asset_id: str) -> Union["pd.DataFrame", AthenaAsset]:
167
+ def _get_file_and_media_type(self, asset_id: str) -> Tuple[io.BytesIO, str]:
164
168
  """
165
- Parameters
166
- ----------
167
- asset_id : str
168
-
169
- Returns
170
- -------
171
- pd.DataFrame or AthenaAsset
172
-
173
- Examples
174
- --------
175
- from athena.client import Athena
176
-
177
- client = Athena(api_key="YOUR_API_KEY")
178
- client.tools.get_asset(asset_id="asset_id")
169
+ Gets the file together with media type returned by server
179
170
  """
180
171
  # while we wait for https://github.com/fern-api/fern/issues/4316
181
172
  result = self._client_wrapper.httpx_client.request(
@@ -196,6 +187,27 @@ class WrappedToolsClient(ToolsClient):
196
187
  # fallback to `libmagic` inference
197
188
  media_type = _infer_media_type(bytes_io=file_bytes)
198
189
 
190
+ return file_bytes, media_type
191
+
192
+ def get_asset(self, asset_id: str) -> Union["pd.DataFrame", AthenaAsset]:
193
+ """
194
+ Parameters
195
+ ----------
196
+ asset_id : str
197
+
198
+ Returns
199
+ -------
200
+ pd.DataFrame or AthenaAsset
201
+
202
+ Examples
203
+ --------
204
+ from athena.client import Athena
205
+
206
+ client = Athena(api_key="YOUR_API_KEY")
207
+ client.tools.get_asset(asset_id="asset_id")
208
+ """
209
+ file_bytes, media_type = self._get_file_and_media_type(asset_id=asset_id)
210
+
199
211
  media_type_aliases = {"image/jpg": "image/jpeg"}
200
212
  media_type = media_type_aliases.get(media_type, media_type)
201
213
 
@@ -226,22 +238,26 @@ class WrappedToolsClient(ToolsClient):
226
238
  return _to_pandas_df(file_bytes, media_type=media_type)
227
239
 
228
240
  raise NotImplementedError("Assets of `{media_type}` type are not yet supported")
229
-
241
+
230
242
 
231
243
  class WrappedQueryClient(QueryClient):
232
244
 
233
- @_inherit_signature_and_doc(QueryClient.execute, {'DataFrameRequestOut': 'pd.DataFrame'})
234
- def execute(self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs) -> 'pd.DataFrame':
245
+ @_inherit_signature_and_doc(
246
+ QueryClient.execute, {"DataFrameRequestOut": "pd.DataFrame"}
247
+ )
248
+ def execute(
249
+ self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs
250
+ ) -> "pd.DataFrame":
235
251
  _check_pandas_installed()
236
252
  model = super().execute(
237
- sql_command=sql_command,
238
- database_asset_ids=database_asset_ids,
239
- **kwargs
253
+ sql_command=sql_command, database_asset_ids=database_asset_ids, **kwargs
240
254
  )
241
255
  return _read_json_frame(model)
242
256
 
243
- @_inherit_signature_and_doc(QueryClient.execute_snippet, {'DataFrameRequestOut': 'pd.DataFrame'})
244
- def execute_snippet(self, *, snippet_asset_id: str, **kwargs) -> 'pd.DataFrame':
257
+ @_inherit_signature_and_doc(
258
+ QueryClient.execute_snippet, {"DataFrameRequestOut": "pd.DataFrame"}
259
+ )
260
+ def execute_snippet(self, *, snippet_asset_id: str, **kwargs) -> "pd.DataFrame":
245
261
  _check_pandas_installed()
246
262
  model = super().execute_snippet(snippet_asset_id=snippet_asset_id, **kwargs)
247
263
  return _read_json_frame(model)
@@ -251,12 +267,11 @@ def _add_docs_for_async_variant(obj):
251
267
  def decorator(decorated):
252
268
  doc = obj.__doc__
253
269
  name = obj.__name__
254
- decorated.__doc__ = (
255
- doc
256
- .replace('client = Athena', 'client = AsyncAthena')
257
- .replace(f'client.tools.{name}', f'await client.tools.{name}')
258
- )
270
+ decorated.__doc__ = doc.replace(
271
+ "client = Athena", "client = AsyncAthena"
272
+ ).replace(f"client.tools.{name}", f"await client.tools.{name}")
259
273
  return decorated
274
+
260
275
  return decorator
261
276
 
262
277
 
@@ -264,18 +279,20 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
264
279
 
265
280
  @_add_docs_for_async_variant(WrappedToolsClient.get_file)
266
281
  async def get_file(self, asset_id: str) -> io.BytesIO:
267
- file_bytes = b''.join([gen async for gen in self.raw_data(asset_id=asset_id)])
282
+ file_bytes = b"".join([gen async for gen in self.raw_data(asset_id=asset_id)])
268
283
  bytes_io = io.BytesIO(file_bytes)
269
284
  return bytes_io
270
285
 
271
- @_inherit_signature_and_doc(AsyncToolsClient.data_frame, {'DataFrameRequestOut': 'pd.DataFrame'})
272
- async def data_frame(self, *, asset_id: str, **kwargs) -> 'pd.DataFrame':
286
+ @_inherit_signature_and_doc(
287
+ AsyncToolsClient.data_frame, {"DataFrameRequestOut": "pd.DataFrame"}
288
+ )
289
+ async def data_frame(self, *, asset_id: str, **kwargs) -> "pd.DataFrame":
273
290
  _check_pandas_installed()
274
291
  model = await super().data_frame(asset_id=asset_id, **kwargs)
275
292
  return _read_json_frame(model)
276
293
 
277
294
  @_add_docs_for_async_variant(WrappedToolsClient.read_data_frame)
278
- async def read_data_frame(self, asset_id: str, *args, **kwargs) -> 'pd.DataFrame':
295
+ async def read_data_frame(self, asset_id: str, *args, **kwargs) -> "pd.DataFrame":
279
296
  _check_pandas_installed()
280
297
  file_bytes = await self.get_file(asset_id)
281
298
  return _to_pandas_df(file_bytes, *args, **kwargs)
@@ -289,31 +306,119 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
289
306
  name: Union[str, None] = None,
290
307
  **kwargs,
291
308
  ) -> SaveAssetRequestOut:
292
- asset_object = _convert_asset_object(asset_object=asset_object, name=name)
309
+ asset_object = _convert_asset_object(
310
+ asset_object=asset_object, name=name, **kwargs
311
+ )
293
312
  return await super().save_asset(
294
- file=asset_object, parent_folder_id=parent_folder_id, **kwargs
313
+ file=asset_object, parent_folder_id=parent_folder_id
295
314
  )
296
315
 
297
316
 
298
317
  class WrappedAsyncQueryClient(AsyncQueryClient):
299
318
 
300
- @_inherit_signature_and_doc(AsyncQueryClient.execute, {'DataFrameRequestOut': 'pd.DataFrame'})
301
- async def execute(self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs) -> 'pd.DataFrame':
319
+ @_inherit_signature_and_doc(
320
+ AsyncQueryClient.execute, {"DataFrameRequestOut": "pd.DataFrame"}
321
+ )
322
+ async def execute(
323
+ self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs
324
+ ) -> "pd.DataFrame":
302
325
  _check_pandas_installed()
303
326
  model = await super().execute(
304
- sql_command=sql_command,
305
- database_asset_ids=database_asset_ids,
306
- **kwargs
327
+ sql_command=sql_command, database_asset_ids=database_asset_ids, **kwargs
307
328
  )
308
329
  return _read_json_frame(model)
309
330
 
310
- @_inherit_signature_and_doc(AsyncQueryClient.execute_snippet, {'DataFrameRequestOut': 'pd.DataFrame'})
311
- async def execute_snippet(self, *, snippet_asset_id: str, **kwargs) -> 'pd.DataFrame':
331
+ @_inherit_signature_and_doc(
332
+ AsyncQueryClient.execute_snippet, {"DataFrameRequestOut": "pd.DataFrame"}
333
+ )
334
+ async def execute_snippet(
335
+ self, *, snippet_asset_id: str, **kwargs
336
+ ) -> "pd.DataFrame":
312
337
  _check_pandas_installed()
313
- model = await super().execute_snippet(snippet_asset_id=snippet_asset_id, **kwargs)
338
+ model = await super().execute_snippet(
339
+ snippet_asset_id=snippet_asset_id, **kwargs
340
+ )
314
341
  return _read_json_frame(model)
315
342
 
316
343
 
344
+ class AthenaModel(RemoteRunnable):
345
+ """Use Athena's models directly with a Langchain-compatible client.
346
+
347
+ The Langchain Runnable interface is supported:
348
+ - `invoke`: Invoke the model with a string or a Langchain message.
349
+ - `batch`: Batch invoke the model on multiple inputs.
350
+ - `astream_events`: Streaming for real-time applications.
351
+
352
+ See Langchain documentation for more details.
353
+
354
+ Examples
355
+ --------
356
+ from src.athena.client import Athena
357
+ import asyncio
358
+ from langchain_core.messages import HumanMessage
359
+
360
+ client = Athena()
361
+ llm = client.llm
362
+
363
+ # sync invoke -- use strings or langchain messages
364
+ result = llm.invoke("Hello")
365
+ print(result.content)
366
+
367
+ result = llm.invoke("Hello")
368
+ print(result.content)
369
+
370
+ # choose the model explicitly
371
+ claude = llm.with_config(configurable={"model": "claude_3_7_sonnet"})
372
+ print(claude.invoke("Who are you?").content)
373
+
374
+ # batch (for multiple parallel requests)
375
+ results = llm.batch(["Hello", "World"] * 5)
376
+ print([r.content for r in results])
377
+
378
+ # handle stream events
379
+ async def stream_response():
380
+ separator = "---------------------------------------"
381
+ print(separator)
382
+ print("Starting stream")
383
+ print(separator)
384
+ async for event in llm.astream_events("Hello. Please respond in 5 sentences."):
385
+ data = event["data"]
386
+ if "chunk" in data:
387
+ print(data["chunk"].content)
388
+ elif "output" in data:
389
+ print(separator)
390
+ print("Final response")
391
+ print(separator)
392
+ print(data["output"].content)
393
+ else:
394
+ print("(other event) ", event.keys())
395
+
396
+
397
+ asyncio.run(stream_response())
398
+ from athena.client import Athena
399
+ client = Athena()
400
+ llm = client.llm
401
+
402
+ # sync invoke -- use strings or langchain messages
403
+ result = llm.invoke("Hello")
404
+ print(result.content)
405
+
406
+ # batch
407
+ results = llm.batch(["Hello", "World"])
408
+ print(r.content for r in results)
409
+ """
410
+
411
+ def __init__(self, base_url: str, api_key: str, timeout: Optional[float] = None):
412
+ self.base_url = base_url
413
+ self.api_key = api_key
414
+ self.timeout = timeout
415
+ super().__init__(
416
+ base_url + "/api/models/general",
417
+ headers={"X-API-KEY": api_key},
418
+ timeout=timeout,
419
+ )
420
+
421
+
317
422
  class Athena(BaseAthena):
318
423
  """
319
424
  Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propagate to these functions.
@@ -350,37 +455,40 @@ class Athena(BaseAthena):
350
455
  self,
351
456
  *,
352
457
  base_url: typing.Optional[str] = None,
353
- environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
458
+ environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
354
459
  api_key: typing.Optional[str] = None,
355
460
  timeout: typing.Optional[float] = 60,
356
- httpx_client: typing.Optional[httpx.Client] = None
461
+ httpx_client: typing.Optional[httpx.Client] = None,
357
462
  ):
358
463
  if api_key is None:
359
464
  try:
360
- api_key = os.environ['ATHENA_API_KEY']
465
+ api_key = os.environ["ATHENA_API_KEY"]
361
466
  except KeyError:
362
467
  raise TypeError(
363
468
  "Athena() missing 1 required keyword-only argument: 'api_key'"
364
- ' (ATHENA_API_KEY environment variable not found)'
469
+ " (ATHENA_API_KEY environment variable not found)"
365
470
  )
366
471
  if environment == SpecialEnvironments.AUTODETECT_ENVIRONMENT:
367
- if 'ATHENA_API_URL' in os.environ:
472
+ if "ATHENA_API_URL" in os.environ:
368
473
 
369
474
  class AutodetectedEnvironments(enum.Enum):
370
- CURRENT = os.environ['ATHENA_API_URL']
475
+ CURRENT = os.environ["ATHENA_API_URL"]
371
476
 
372
477
  environment = cast(AthenaEnvironment, AutodetectedEnvironments.CURRENT)
373
478
  else:
374
479
  environment = AthenaEnvironment.PRODUCTION
375
480
  super().__init__(
376
481
  base_url=base_url,
377
- environment=environment, # type: ignore[arg-type]
482
+ environment=environment, # type: ignore[arg-type]
378
483
  api_key=api_key,
379
484
  timeout=timeout,
380
485
  httpx_client=httpx_client,
381
486
  )
382
487
  self.tools = WrappedToolsClient(client_wrapper=self._client_wrapper)
383
488
  self.query = WrappedQueryClient(client_wrapper=self._client_wrapper)
489
+ self.llm = AthenaModel(
490
+ base_url=base_url or environment.value, api_key=api_key, timeout=timeout
491
+ )
384
492
 
385
493
 
386
494
  class AsyncAthena(AsyncBaseAthena):
@@ -420,31 +528,31 @@ class AsyncAthena(AsyncBaseAthena):
420
528
  self,
421
529
  *,
422
530
  base_url: typing.Optional[str] = None,
423
- environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
531
+ environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
424
532
  api_key: typing.Optional[str] = None,
425
533
  timeout: typing.Optional[float] = 60,
426
- httpx_client: typing.Optional[httpx.AsyncClient] = None
534
+ httpx_client: typing.Optional[httpx.AsyncClient] = None,
427
535
  ):
428
536
  if api_key is None:
429
537
  try:
430
- api_key = os.environ['ATHENA_API_KEY']
538
+ api_key = os.environ["ATHENA_API_KEY"]
431
539
  except KeyError:
432
540
  raise TypeError(
433
541
  "AsyncAthena() missing 1 required keyword-only argument: 'api_key'"
434
- ' (ATHENA_API_KEY environment variable not found)'
542
+ " (ATHENA_API_KEY environment variable not found)"
435
543
  )
436
544
  if environment == SpecialEnvironments.AUTODETECT_ENVIRONMENT:
437
- if 'ATHENA_API_URL' in os.environ:
545
+ if "ATHENA_API_URL" in os.environ:
438
546
 
439
547
  class AutodetectedEnvironments(enum.Enum):
440
- CURRENT = os.environ['ATHENA_API_URL']
548
+ CURRENT = os.environ["ATHENA_API_URL"]
441
549
 
442
550
  environment = cast(AthenaEnvironment, AutodetectedEnvironments.CURRENT)
443
551
  else:
444
552
  environment = AthenaEnvironment.PRODUCTION
445
553
  super().__init__(
446
554
  base_url=base_url,
447
- environment=environment, # type: ignore[arg-type]
555
+ environment=environment, # type: ignore[arg-type]
448
556
  api_key=api_key,
449
557
  timeout=timeout,
450
558
  httpx_client=httpx_client,
@@ -453,19 +561,20 @@ class AsyncAthena(AsyncBaseAthena):
453
561
  self.query = WrappedAsyncQueryClient(client_wrapper=self._client_wrapper)
454
562
 
455
563
 
456
- def _read_json_frame(model: DataFrameRequestOut) -> 'pd.DataFrame':
564
+ def _read_json_frame(model: DataFrameRequestOut) -> "pd.DataFrame":
457
565
  import pandas as pd
458
566
 
459
567
  string_io = io.StringIO(model.json())
460
568
 
461
569
  with warnings.catch_warnings():
462
570
  # Filter warnings due to https://github.com/pandas-dev/pandas/issues/59511
463
- warnings.simplefilter(action='ignore', category=FutureWarning)
464
- return pd.read_json(string_io, orient='split')
571
+ warnings.simplefilter(action="ignore", category=FutureWarning)
572
+ return pd.read_json(string_io, orient="split")
465
573
 
466
574
 
467
575
  def _check_pandas_installed():
468
576
  import pandas
577
+
469
578
  assert pandas
470
579
 
471
580
 
@@ -514,10 +623,12 @@ def _to_pandas_df(
514
623
  def _convert_asset_object(
515
624
  asset_object: Union["pd.DataFrame", "pd.Series", core.File],
516
625
  name: Union[str, None] = None,
626
+ **kwargs,
517
627
  ) -> core.File:
518
628
  import pandas as pd
629
+
519
630
  try:
520
- from IPython.core.formatters import format_display_data # type: ignore[import]
631
+ from IPython.core.formatters import format_display_data # type: ignore[import]
521
632
  except ImportError:
522
633
  format_display_data = None
523
634
 
@@ -526,7 +637,7 @@ def _convert_asset_object(
526
637
  if isinstance(asset_object, pd.DataFrame):
527
638
  return (
528
639
  name or "Uploaded data frame",
529
- asset_object.to_parquet(path=None),
640
+ asset_object.to_parquet(path=None, **kwargs),
530
641
  "application/vnd.apache.parquet",
531
642
  )
532
643
  if format_display_data:
@@ -548,4 +659,4 @@ def _convert_asset_object(
548
659
  image_bytes,
549
660
  media_type,
550
661
  )
551
- return asset_object # type: ignore[return-value]
662
+ return asset_object # type: ignore[return-value]
athena/core/__init__.py CHANGED
@@ -3,28 +3,45 @@
3
3
  from .api_error import ApiError
4
4
  from .client_wrapper import AsyncClientWrapper, BaseClientWrapper, SyncClientWrapper
5
5
  from .datetime_utils import serialize_datetime
6
- from .file import File, convert_file_dict_to_httpx_tuples
6
+ from .file import File, convert_file_dict_to_httpx_tuples, with_content_type
7
7
  from .http_client import AsyncHttpClient, HttpClient
8
8
  from .jsonable_encoder import jsonable_encoder
9
- from .pydantic_utilities import deep_union_pydantic_dicts, pydantic_v1
9
+ from .pydantic_utilities import (
10
+ IS_PYDANTIC_V2,
11
+ UniversalBaseModel,
12
+ UniversalRootModel,
13
+ parse_obj_as,
14
+ universal_field_validator,
15
+ universal_root_validator,
16
+ update_forward_refs,
17
+ )
10
18
  from .query_encoder import encode_query
11
19
  from .remove_none_from_dict import remove_none_from_dict
12
20
  from .request_options import RequestOptions
21
+ from .serialization import FieldMetadata, convert_and_respect_annotation_metadata
13
22
 
14
23
  __all__ = [
15
24
  "ApiError",
16
25
  "AsyncClientWrapper",
17
26
  "AsyncHttpClient",
18
27
  "BaseClientWrapper",
28
+ "FieldMetadata",
19
29
  "File",
20
30
  "HttpClient",
31
+ "IS_PYDANTIC_V2",
21
32
  "RequestOptions",
22
33
  "SyncClientWrapper",
34
+ "UniversalBaseModel",
35
+ "UniversalRootModel",
36
+ "convert_and_respect_annotation_metadata",
23
37
  "convert_file_dict_to_httpx_tuples",
24
- "deep_union_pydantic_dicts",
25
38
  "encode_query",
26
39
  "jsonable_encoder",
27
- "pydantic_v1",
40
+ "parse_obj_as",
28
41
  "remove_none_from_dict",
29
42
  "serialize_datetime",
43
+ "universal_field_validator",
44
+ "universal_root_validator",
45
+ "update_forward_refs",
46
+ "with_content_type",
30
47
  ]
athena/core/client_wrapper.py CHANGED
@@ -1,10 +1,9 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
3
  import typing
4
-
5
4
  import httpx
6
-
7
- from .http_client import AsyncHttpClient, HttpClient
5
+ from .http_client import HttpClient
6
+ from .http_client import AsyncHttpClient
8
7
 
9
8
 
10
9
  class BaseClientWrapper:
@@ -17,7 +16,7 @@ class BaseClientWrapper:
17
16
  headers: typing.Dict[str, str] = {
18
17
  "X-Fern-Language": "Python",
19
18
  "X-Fern-SDK-Name": "athena-intelligence",
20
- "X-Fern-SDK-Version": "0.1.124",
19
+ "X-Fern-SDK-Version": "0.1.126",
21
20
  }
22
21
  headers["X-API-KEY"] = self.api_key
23
22
  return headers
@@ -36,9 +35,9 @@ class SyncClientWrapper(BaseClientWrapper):
36
35
  super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
37
36
  self.httpx_client = HttpClient(
38
37
  httpx_client=httpx_client,
39
- base_headers=self.get_headers(),
40
- base_timeout=self.get_timeout(),
41
- base_url=self.get_base_url(),
38
+ base_headers=self.get_headers,
39
+ base_timeout=self.get_timeout,
40
+ base_url=self.get_base_url,
42
41
  )
43
42
 
44
43
 
@@ -49,7 +48,7 @@ class AsyncClientWrapper(BaseClientWrapper):
49
48
  super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
50
49
  self.httpx_client = AsyncHttpClient(
51
50
  httpx_client=httpx_client,
52
- base_headers=self.get_headers(),
53
- base_timeout=self.get_timeout(),
54
- base_url=self.get_base_url(),
51
+ base_headers=self.get_headers,
52
+ base_timeout=self.get_timeout,
53
+ base_url=self.get_base_url,
55
54
  )
athena/core/file.py CHANGED
@@ -1,25 +1,30 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
- import typing
3
+ from typing import IO, Dict, List, Mapping, Optional, Tuple, Union, cast
4
4
 
5
5
  # File typing inspired by the flexibility of types within the httpx library
6
6
  # https://github.com/encode/httpx/blob/master/httpx/_types.py
7
- FileContent = typing.Union[typing.IO[bytes], bytes, str]
8
- File = typing.Union[
7
+ FileContent = Union[IO[bytes], bytes, str]
8
+ File = Union[
9
9
  # file (or bytes)
10
10
  FileContent,
11
11
  # (filename, file (or bytes))
12
- typing.Tuple[typing.Optional[str], FileContent],
12
+ Tuple[Optional[str], FileContent],
13
13
  # (filename, file (or bytes), content_type)
14
- typing.Tuple[typing.Optional[str], FileContent, typing.Optional[str]],
14
+ Tuple[Optional[str], FileContent, Optional[str]],
15
15
  # (filename, file (or bytes), content_type, headers)
16
- typing.Tuple[typing.Optional[str], FileContent, typing.Optional[str], typing.Mapping[str, str]],
16
+ Tuple[
17
+ Optional[str],
18
+ FileContent,
19
+ Optional[str],
20
+ Mapping[str, str],
21
+ ],
17
22
  ]
18
23
 
19
24
 
20
25
  def convert_file_dict_to_httpx_tuples(
21
- d: typing.Dict[str, typing.Union[File, typing.List[File]]]
22
- ) -> typing.List[typing.Tuple[str, File]]:
26
+ d: Dict[str, Union[File, List[File]]],
27
+ ) -> List[Tuple[str, File]]:
23
28
  """
24
29
  The format we use is a list of tuples, where the first element is the
25
30
  name of the file and the second is the file object. Typically HTTPX wants
@@ -36,3 +41,27 @@ def convert_file_dict_to_httpx_tuples(
36
41
  else:
37
42
  httpx_tuples.append((key, file_like))
38
43
  return httpx_tuples
44
+
45
+
46
+ def with_content_type(*, file: File, default_content_type: str) -> File:
47
+ """
48
+ This function resolves to the file's content type, if provided, and defaults
49
+ to the default_content_type value if not.
50
+ """
51
+ if isinstance(file, tuple):
52
+ if len(file) == 2:
53
+ filename, content = cast(Tuple[Optional[str], FileContent], file) # type: ignore
54
+ return (filename, content, default_content_type)
55
+ elif len(file) == 3:
56
+ filename, content, file_content_type = cast(Tuple[Optional[str], FileContent, Optional[str]], file) # type: ignore
57
+ out_content_type = file_content_type or default_content_type
58
+ return (filename, content, out_content_type)
59
+ elif len(file) == 4:
60
+ filename, content, file_content_type, headers = cast( # type: ignore
61
+ Tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], file
62
+ )
63
+ out_content_type = file_content_type or default_content_type
64
+ return (filename, content, out_content_type, headers)
65
+ else:
66
+ raise ValueError(f"Unexpected tuple length: {len(file)}")
67
+ return (None, file, default_content_type)