athena-intelligence 0.1.125__py3-none-any.whl → 0.1.127__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. athena/__init__.py +7 -0
  2. athena/agents/client.py +88 -36
  3. athena/agents/drive/client.py +80 -32
  4. athena/agents/general/client.py +222 -91
  5. athena/agents/research/client.py +80 -32
  6. athena/agents/sql/client.py +80 -32
  7. athena/base_client.py +13 -11
  8. athena/client.py +161 -61
  9. athena/core/__init__.py +21 -4
  10. athena/core/client_wrapper.py +9 -10
  11. athena/core/file.py +37 -8
  12. athena/core/http_client.py +97 -41
  13. athena/core/jsonable_encoder.py +33 -31
  14. athena/core/pydantic_utilities.py +272 -4
  15. athena/core/query_encoder.py +38 -13
  16. athena/core/request_options.py +5 -2
  17. athena/core/serialization.py +272 -0
  18. athena/errors/internal_server_error.py +2 -3
  19. athena/errors/unauthorized_error.py +2 -3
  20. athena/errors/unprocessable_entity_error.py +2 -3
  21. athena/query/client.py +208 -58
  22. athena/tools/calendar/client.py +82 -30
  23. athena/tools/client.py +956 -188
  24. athena/tools/email/client.py +117 -43
  25. athena/tools/structured_data_extractor/client.py +118 -67
  26. athena/tools/tasks/client.py +41 -17
  27. athena/types/__init__.py +4 -0
  28. athena/types/asset_content_request_out.py +26 -0
  29. athena/types/asset_node.py +14 -24
  30. athena/types/asset_not_found_error.py +11 -21
  31. athena/types/asset_screenshot_response_out.py +43 -0
  32. athena/types/chunk.py +11 -21
  33. athena/types/chunk_content_item.py +21 -41
  34. athena/types/chunk_result.py +13 -23
  35. athena/types/custom_agent_response.py +12 -22
  36. athena/types/data_frame_request_out.py +11 -21
  37. athena/types/data_frame_unknown_format_error.py +11 -21
  38. athena/types/document_chunk.py +12 -22
  39. athena/types/drive_agent_response.py +12 -22
  40. athena/types/file_chunk_request_out.py +11 -21
  41. athena/types/file_too_large_error.py +11 -21
  42. athena/types/folder_response.py +11 -21
  43. athena/types/general_agent_config.py +12 -21
  44. athena/types/general_agent_config_enabled_tools_item.py +0 -1
  45. athena/types/general_agent_request.py +13 -23
  46. athena/types/general_agent_response.py +12 -22
  47. athena/types/image_url_content.py +11 -21
  48. athena/types/parent_folder_error.py +11 -21
  49. athena/types/prompt_message.py +12 -22
  50. athena/types/research_agent_response.py +12 -22
  51. athena/types/save_asset_request_out.py +11 -21
  52. athena/types/sql_agent_response.py +13 -23
  53. athena/types/structured_data_extractor_response.py +15 -25
  54. athena/types/text_content.py +11 -21
  55. athena/types/tool.py +1 -13
  56. athena/types/type.py +1 -21
  57. athena/version.py +0 -1
  58. {athena_intelligence-0.1.125.dist-info → athena_intelligence-0.1.127.dist-info}/METADATA +12 -4
  59. athena_intelligence-0.1.127.dist-info/RECORD +89 -0
  60. {athena_intelligence-0.1.125.dist-info → athena_intelligence-0.1.127.dist-info}/WHEEL +1 -1
  61. athena_intelligence-0.1.125.dist-info/RECORD +0 -86
@@ -1,14 +1,14 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
3
  import typing
4
- from json.decoder import JSONDecodeError
5
-
6
- from ...core.api_error import ApiError
7
- from ...core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
8
- from ...core.pydantic_utilities import pydantic_v1
4
+ from ...core.client_wrapper import SyncClientWrapper
9
5
  from ...core.request_options import RequestOptions
10
- from ...errors.unprocessable_entity_error import UnprocessableEntityError
11
6
  from ...types.sql_agent_response import SqlAgentResponse
7
+ from ...core.pydantic_utilities import parse_obj_as
8
+ from ...errors.unprocessable_entity_error import UnprocessableEntityError
9
+ from json.decoder import JSONDecodeError
10
+ from ...core.api_error import ApiError
11
+ from ...core.client_wrapper import AsyncClientWrapper
12
12
 
13
13
  # this is used as the default value for optional parameters
14
14
  OMIT = typing.cast(typing.Any, ...)
@@ -21,19 +21,19 @@ class SqlClient:
21
21
  def invoke(
22
22
  self,
23
23
  *,
24
- config: typing.Dict[str, typing.Any],
25
- messages: typing.Sequence[typing.Dict[str, typing.Any]],
26
- request_options: typing.Optional[RequestOptions] = None
24
+ config: typing.Dict[str, typing.Optional[typing.Any]],
25
+ messages: typing.Sequence[typing.Dict[str, typing.Optional[typing.Any]]],
26
+ request_options: typing.Optional[RequestOptions] = None,
27
27
  ) -> SqlAgentResponse:
28
28
  """
29
29
  Coming soon! Generate, execute, and test SQL queries. Returns an asset ID for the query object.
30
30
 
31
31
  Parameters
32
32
  ----------
33
- config : typing.Dict[str, typing.Any]
33
+ config : typing.Dict[str, typing.Optional[typing.Any]]
34
34
  Configuration for the SQL agent including database connection details and query parameters
35
35
 
36
- messages : typing.Sequence[typing.Dict[str, typing.Any]]
36
+ messages : typing.Sequence[typing.Dict[str, typing.Optional[typing.Any]]]
37
37
  The messages to send to the SQL agent
38
38
 
39
39
  request_options : typing.Optional[RequestOptions]
@@ -46,7 +46,7 @@ class SqlClient:
46
46
 
47
47
  Examples
48
48
  --------
49
- from athena.client import Athena
49
+ from athena import Athena
50
50
 
51
51
  client = Athena(
52
52
  api_key="YOUR_API_KEY",
@@ -59,15 +59,35 @@ class SqlClient:
59
59
  _response = self._client_wrapper.httpx_client.request(
60
60
  "api/v0/agents/sql/invoke",
61
61
  method="POST",
62
- json={"config": config, "messages": messages},
62
+ json={
63
+ "config": config,
64
+ "messages": messages,
65
+ },
66
+ headers={
67
+ "content-type": "application/json",
68
+ },
63
69
  request_options=request_options,
64
70
  omit=OMIT,
65
71
  )
66
- if 200 <= _response.status_code < 300:
67
- return pydantic_v1.parse_obj_as(SqlAgentResponse, _response.json()) # type: ignore
68
- if _response.status_code == 422:
69
- raise UnprocessableEntityError(pydantic_v1.parse_obj_as(typing.Any, _response.json())) # type: ignore
70
72
  try:
73
+ if 200 <= _response.status_code < 300:
74
+ return typing.cast(
75
+ SqlAgentResponse,
76
+ parse_obj_as(
77
+ type_=SqlAgentResponse, # type: ignore
78
+ object_=_response.json(),
79
+ ),
80
+ )
81
+ if _response.status_code == 422:
82
+ raise UnprocessableEntityError(
83
+ typing.cast(
84
+ typing.Optional[typing.Any],
85
+ parse_obj_as(
86
+ type_=typing.Optional[typing.Any], # type: ignore
87
+ object_=_response.json(),
88
+ ),
89
+ )
90
+ )
71
91
  _response_json = _response.json()
72
92
  except JSONDecodeError:
73
93
  raise ApiError(status_code=_response.status_code, body=_response.text)
@@ -81,19 +101,19 @@ class AsyncSqlClient:
81
101
  async def invoke(
82
102
  self,
83
103
  *,
84
- config: typing.Dict[str, typing.Any],
85
- messages: typing.Sequence[typing.Dict[str, typing.Any]],
86
- request_options: typing.Optional[RequestOptions] = None
104
+ config: typing.Dict[str, typing.Optional[typing.Any]],
105
+ messages: typing.Sequence[typing.Dict[str, typing.Optional[typing.Any]]],
106
+ request_options: typing.Optional[RequestOptions] = None,
87
107
  ) -> SqlAgentResponse:
88
108
  """
89
109
  Coming soon! Generate, execute, and test SQL queries. Returns an asset ID for the query object.
90
110
 
91
111
  Parameters
92
112
  ----------
93
- config : typing.Dict[str, typing.Any]
113
+ config : typing.Dict[str, typing.Optional[typing.Any]]
94
114
  Configuration for the SQL agent including database connection details and query parameters
95
115
 
96
- messages : typing.Sequence[typing.Dict[str, typing.Any]]
116
+ messages : typing.Sequence[typing.Dict[str, typing.Optional[typing.Any]]]
97
117
  The messages to send to the SQL agent
98
118
 
99
119
  request_options : typing.Optional[RequestOptions]
@@ -106,28 +126,56 @@ class AsyncSqlClient:
106
126
 
107
127
  Examples
108
128
  --------
109
- from athena.client import AsyncAthena
129
+ import asyncio
130
+
131
+ from athena import AsyncAthena
110
132
 
111
133
  client = AsyncAthena(
112
134
  api_key="YOUR_API_KEY",
113
135
  )
114
- await client.agents.sql.invoke(
115
- config={"key": "value"},
116
- messages=[{"key": "value"}],
117
- )
136
+
137
+
138
+ async def main() -> None:
139
+ await client.agents.sql.invoke(
140
+ config={"key": "value"},
141
+ messages=[{"key": "value"}],
142
+ )
143
+
144
+
145
+ asyncio.run(main())
118
146
  """
119
147
  _response = await self._client_wrapper.httpx_client.request(
120
148
  "api/v0/agents/sql/invoke",
121
149
  method="POST",
122
- json={"config": config, "messages": messages},
150
+ json={
151
+ "config": config,
152
+ "messages": messages,
153
+ },
154
+ headers={
155
+ "content-type": "application/json",
156
+ },
123
157
  request_options=request_options,
124
158
  omit=OMIT,
125
159
  )
126
- if 200 <= _response.status_code < 300:
127
- return pydantic_v1.parse_obj_as(SqlAgentResponse, _response.json()) # type: ignore
128
- if _response.status_code == 422:
129
- raise UnprocessableEntityError(pydantic_v1.parse_obj_as(typing.Any, _response.json())) # type: ignore
130
160
  try:
161
+ if 200 <= _response.status_code < 300:
162
+ return typing.cast(
163
+ SqlAgentResponse,
164
+ parse_obj_as(
165
+ type_=SqlAgentResponse, # type: ignore
166
+ object_=_response.json(),
167
+ ),
168
+ )
169
+ if _response.status_code == 422:
170
+ raise UnprocessableEntityError(
171
+ typing.cast(
172
+ typing.Optional[typing.Any],
173
+ parse_obj_as(
174
+ type_=typing.Optional[typing.Any], # type: ignore
175
+ object_=_response.json(),
176
+ ),
177
+ )
178
+ )
131
179
  _response_json = _response.json()
132
180
  except JSONDecodeError:
133
181
  raise ApiError(status_code=_response.status_code, body=_response.text)
athena/base_client.py CHANGED
@@ -1,14 +1,16 @@
1
1
  # This file was auto-generated by Fern from our API Definition.
2
2
 
3
3
  import typing
4
-
5
- import httpx
6
-
7
- from .agents.client import AgentsClient, AsyncAgentsClient
8
- from .core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
9
4
  from .environment import AthenaEnvironment
10
- from .query.client import AsyncQueryClient, QueryClient
11
- from .tools.client import AsyncToolsClient, ToolsClient
5
+ import httpx
6
+ from .core.client_wrapper import SyncClientWrapper
7
+ from .agents.client import AgentsClient
8
+ from .query.client import QueryClient
9
+ from .tools.client import ToolsClient
10
+ from .core.client_wrapper import AsyncClientWrapper
11
+ from .agents.client import AsyncAgentsClient
12
+ from .query.client import AsyncQueryClient
13
+ from .tools.client import AsyncToolsClient
12
14
 
13
15
 
14
16
  class BaseAthena:
@@ -41,7 +43,7 @@ class BaseAthena:
41
43
 
42
44
  Examples
43
45
  --------
44
- from athena.client import Athena
46
+ from athena import Athena
45
47
 
46
48
  client = Athena(
47
49
  api_key="YOUR_API_KEY",
@@ -56,7 +58,7 @@ class BaseAthena:
56
58
  api_key: str,
57
59
  timeout: typing.Optional[float] = None,
58
60
  follow_redirects: typing.Optional[bool] = True,
59
- httpx_client: typing.Optional[httpx.Client] = None
61
+ httpx_client: typing.Optional[httpx.Client] = None,
60
62
  ):
61
63
  _defaulted_timeout = timeout if timeout is not None else 60 if httpx_client is None else None
62
64
  self._client_wrapper = SyncClientWrapper(
@@ -104,7 +106,7 @@ class AsyncBaseAthena:
104
106
 
105
107
  Examples
106
108
  --------
107
- from athena.client import AsyncAthena
109
+ from athena import AsyncAthena
108
110
 
109
111
  client = AsyncAthena(
110
112
  api_key="YOUR_API_KEY",
@@ -119,7 +121,7 @@ class AsyncBaseAthena:
119
121
  api_key: str,
120
122
  timeout: typing.Optional[float] = None,
121
123
  follow_redirects: typing.Optional[bool] = True,
122
- httpx_client: typing.Optional[httpx.AsyncClient] = None
124
+ httpx_client: typing.Optional[httpx.AsyncClient] = None,
123
125
  ):
124
126
  _defaulted_timeout = timeout if timeout is not None else 60 if httpx_client is None else None
125
127
  self._client_wrapper = AsyncClientWrapper(
athena/client.py CHANGED
@@ -7,8 +7,9 @@ import typing
7
7
  import warnings
8
8
 
9
9
  import httpx
10
- from typing import cast, List, Tuple, Union
10
+ from typing import cast, List, Tuple, Union, Optional
11
11
  from typing_extensions import TypeVar, ParamSpec
12
+ from langserve import RemoteRunnable
12
13
 
13
14
  from . import core
14
15
  from .base_client import BaseAthena, AsyncBaseAthena
@@ -24,23 +25,23 @@ if typing.TYPE_CHECKING:
24
25
 
25
26
  P = ParamSpec("P")
26
27
  T = TypeVar("T")
27
- U = TypeVar('U')
28
+ U = TypeVar("U")
28
29
 
29
30
 
30
31
  def _inherit_signature_and_doc(
31
- f: typing.Callable[P, T],
32
- replace_in_doc: typing.Dict[str, str]
32
+ f: typing.Callable[P, T], replace_in_doc: typing.Dict[str, str]
33
33
  ) -> typing.Callable[..., typing.Callable[P, U]]:
34
34
  def decorator(decorated):
35
35
  for old, new in replace_in_doc.items():
36
36
  assert old in f.__doc__
37
37
  decorated.__doc__ = f.__doc__.replace(old, new)
38
38
  return decorated
39
+
39
40
  return decorator
40
41
 
41
42
 
42
43
  class SpecialEnvironments(enum.Enum):
43
- AUTODETECT_ENVIRONMENT = 'AUTO'
44
+ AUTODETECT_ENVIRONMENT = "AUTO"
44
45
 
45
46
 
46
47
  @dataclasses.dataclass
@@ -53,7 +54,7 @@ class AthenaAsset:
53
54
  if self.media_type == "application/sql":
54
55
  # it is safe to import IPython in `_repr_mimebundle_`
55
56
  # as this is only intended to be invoked by IPython.
56
- from IPython import display # type: ignore[import]
57
+ from IPython import display # type: ignore[import]
57
58
 
58
59
  code = display.Code(
59
60
  data=self.data.decode(),
@@ -84,17 +85,19 @@ class WrappedToolsClient(ToolsClient):
84
85
  file_io = client.tools.get_file(asset_id="asset_id")
85
86
  pl.read_csv(file_io)
86
87
  """
87
- file_bytes = b''.join(self.raw_data(asset_id=asset_id))
88
+ file_bytes = b"".join(self.raw_data(asset_id=asset_id))
88
89
  bytes_io = io.BytesIO(file_bytes)
89
90
  return bytes_io
90
91
 
91
- @_inherit_signature_and_doc(ToolsClient.data_frame, {'DataFrameRequestOut': 'pd.DataFrame'})
92
- def data_frame(self, *, asset_id: str, **kwargs) -> 'pd.DataFrame':
92
+ @_inherit_signature_and_doc(
93
+ ToolsClient.data_frame, {"DataFrameRequestOut": "pd.DataFrame"}
94
+ )
95
+ def data_frame(self, *, asset_id: str, **kwargs) -> "pd.DataFrame":
93
96
  _check_pandas_installed()
94
97
  model = super().data_frame(asset_id=asset_id, **kwargs)
95
98
  return _read_json_frame(model)
96
99
 
97
- def read_data_frame(self, asset_id: str, *args, **kwargs) -> 'pd.DataFrame':
100
+ def read_data_frame(self, asset_id: str, *args, **kwargs) -> "pd.DataFrame":
98
101
  """
99
102
  Parameters
100
103
  ----------
@@ -119,8 +122,7 @@ class WrappedToolsClient(ToolsClient):
119
122
  file_bytes, media_type = self._get_file_and_media_type(asset_id=asset_id)
120
123
  return _to_pandas_df(file_bytes, *args, media_type=media_type, **kwargs)
121
124
 
122
-
123
- def save_asset( # type: ignore[override]
125
+ def save_asset( # type: ignore[override]
124
126
  self,
125
127
  asset_object: Union["pd.DataFrame", "pd.Series", core.File],
126
128
  *,
@@ -157,10 +159,10 @@ class WrappedToolsClient(ToolsClient):
157
159
  client = Athena(api_key="YOUR_API_KEY")
158
160
  client.tools.save_asset(df)
159
161
  """
160
- asset_object = _convert_asset_object(asset_object=asset_object, name=name, **kwargs)
161
- return super().save_asset(
162
- file=asset_object, parent_folder_id=parent_folder_id
162
+ asset_object = _convert_asset_object(
163
+ asset_object=asset_object, name=name, **kwargs
163
164
  )
165
+ return super().save_asset(file=asset_object, parent_folder_id=parent_folder_id)
164
166
 
165
167
  def _get_file_and_media_type(self, asset_id: str) -> Tuple[io.BytesIO, str]:
166
168
  """
@@ -236,22 +238,26 @@ class WrappedToolsClient(ToolsClient):
236
238
  return _to_pandas_df(file_bytes, media_type=media_type)
237
239
 
238
240
  raise NotImplementedError("Assets of `{media_type}` type are not yet supported")
239
-
241
+
240
242
 
241
243
  class WrappedQueryClient(QueryClient):
242
244
 
243
- @_inherit_signature_and_doc(QueryClient.execute, {'DataFrameRequestOut': 'pd.DataFrame'})
244
- def execute(self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs) -> 'pd.DataFrame':
245
+ @_inherit_signature_and_doc(
246
+ QueryClient.execute, {"DataFrameRequestOut": "pd.DataFrame"}
247
+ )
248
+ def execute(
249
+ self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs
250
+ ) -> "pd.DataFrame":
245
251
  _check_pandas_installed()
246
252
  model = super().execute(
247
- sql_command=sql_command,
248
- database_asset_ids=database_asset_ids,
249
- **kwargs
253
+ sql_command=sql_command, database_asset_ids=database_asset_ids, **kwargs
250
254
  )
251
255
  return _read_json_frame(model)
252
256
 
253
- @_inherit_signature_and_doc(QueryClient.execute_snippet, {'DataFrameRequestOut': 'pd.DataFrame'})
254
- def execute_snippet(self, *, snippet_asset_id: str, **kwargs) -> 'pd.DataFrame':
257
+ @_inherit_signature_and_doc(
258
+ QueryClient.execute_snippet, {"DataFrameRequestOut": "pd.DataFrame"}
259
+ )
260
+ def execute_snippet(self, *, snippet_asset_id: str, **kwargs) -> "pd.DataFrame":
255
261
  _check_pandas_installed()
256
262
  model = super().execute_snippet(snippet_asset_id=snippet_asset_id, **kwargs)
257
263
  return _read_json_frame(model)
@@ -261,12 +267,11 @@ def _add_docs_for_async_variant(obj):
261
267
  def decorator(decorated):
262
268
  doc = obj.__doc__
263
269
  name = obj.__name__
264
- decorated.__doc__ = (
265
- doc
266
- .replace('client = Athena', 'client = AsyncAthena')
267
- .replace(f'client.tools.{name}', f'await client.tools.{name}')
268
- )
270
+ decorated.__doc__ = doc.replace(
271
+ "client = Athena", "client = AsyncAthena"
272
+ ).replace(f"client.tools.{name}", f"await client.tools.{name}")
269
273
  return decorated
274
+
270
275
  return decorator
271
276
 
272
277
 
@@ -274,18 +279,20 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
274
279
 
275
280
  @_add_docs_for_async_variant(WrappedToolsClient.get_file)
276
281
  async def get_file(self, asset_id: str) -> io.BytesIO:
277
- file_bytes = b''.join([gen async for gen in self.raw_data(asset_id=asset_id)])
282
+ file_bytes = b"".join([gen async for gen in self.raw_data(asset_id=asset_id)])
278
283
  bytes_io = io.BytesIO(file_bytes)
279
284
  return bytes_io
280
285
 
281
- @_inherit_signature_and_doc(AsyncToolsClient.data_frame, {'DataFrameRequestOut': 'pd.DataFrame'})
282
- async def data_frame(self, *, asset_id: str, **kwargs) -> 'pd.DataFrame':
286
+ @_inherit_signature_and_doc(
287
+ AsyncToolsClient.data_frame, {"DataFrameRequestOut": "pd.DataFrame"}
288
+ )
289
+ async def data_frame(self, *, asset_id: str, **kwargs) -> "pd.DataFrame":
283
290
  _check_pandas_installed()
284
291
  model = await super().data_frame(asset_id=asset_id, **kwargs)
285
292
  return _read_json_frame(model)
286
293
 
287
294
  @_add_docs_for_async_variant(WrappedToolsClient.read_data_frame)
288
- async def read_data_frame(self, asset_id: str, *args, **kwargs) -> 'pd.DataFrame':
295
+ async def read_data_frame(self, asset_id: str, *args, **kwargs) -> "pd.DataFrame":
289
296
  _check_pandas_installed()
290
297
  file_bytes = await self.get_file(asset_id)
291
298
  return _to_pandas_df(file_bytes, *args, **kwargs)
@@ -299,7 +306,9 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
299
306
  name: Union[str, None] = None,
300
307
  **kwargs,
301
308
  ) -> SaveAssetRequestOut:
302
- asset_object = _convert_asset_object(asset_object=asset_object, name=name, **kwargs)
309
+ asset_object = _convert_asset_object(
310
+ asset_object=asset_object, name=name, **kwargs
311
+ )
303
312
  return await super().save_asset(
304
313
  file=asset_object, parent_folder_id=parent_folder_id
305
314
  )
@@ -307,23 +316,109 @@ class WrappedAsyncToolsClient(AsyncToolsClient):
307
316
 
308
317
  class WrappedAsyncQueryClient(AsyncQueryClient):
309
318
 
310
- @_inherit_signature_and_doc(AsyncQueryClient.execute, {'DataFrameRequestOut': 'pd.DataFrame'})
311
- async def execute(self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs) -> 'pd.DataFrame':
319
+ @_inherit_signature_and_doc(
320
+ AsyncQueryClient.execute, {"DataFrameRequestOut": "pd.DataFrame"}
321
+ )
322
+ async def execute(
323
+ self, *, sql_command: str, database_asset_ids: Union[str, List[str]], **kwargs
324
+ ) -> "pd.DataFrame":
312
325
  _check_pandas_installed()
313
326
  model = await super().execute(
314
- sql_command=sql_command,
315
- database_asset_ids=database_asset_ids,
316
- **kwargs
327
+ sql_command=sql_command, database_asset_ids=database_asset_ids, **kwargs
317
328
  )
318
329
  return _read_json_frame(model)
319
330
 
320
- @_inherit_signature_and_doc(AsyncQueryClient.execute_snippet, {'DataFrameRequestOut': 'pd.DataFrame'})
321
- async def execute_snippet(self, *, snippet_asset_id: str, **kwargs) -> 'pd.DataFrame':
331
+ @_inherit_signature_and_doc(
332
+ AsyncQueryClient.execute_snippet, {"DataFrameRequestOut": "pd.DataFrame"}
333
+ )
334
+ async def execute_snippet(
335
+ self, *, snippet_asset_id: str, **kwargs
336
+ ) -> "pd.DataFrame":
322
337
  _check_pandas_installed()
323
- model = await super().execute_snippet(snippet_asset_id=snippet_asset_id, **kwargs)
338
+ model = await super().execute_snippet(
339
+ snippet_asset_id=snippet_asset_id, **kwargs
340
+ )
324
341
  return _read_json_frame(model)
325
342
 
326
343
 
344
+ class AthenaModel(RemoteRunnable):
345
+ """Use Athena's models directly with a Langchain-compatible client.
346
+
347
+ The Langchain Runnable interface is supported:
348
+ - `invoke`: Invoke the model with a string or a Langchain message.
349
+ - `batch`: Batch invoke the model on multiple inputs.
350
+ - `astream_events`: Streaming for real-time applications.
351
+
352
+ See Langchain documentation for more details.
353
+
354
+ Examples
355
+ --------
356
+ from src.athena.client import Athena
357
+ import asyncio
358
+ from langchain_core.messages import HumanMessage
359
+
360
+ client = Athena()
361
+ llm = client.llm
362
+
363
+ # sync invoke -- use strings or langchain messages
364
+ result = llm.invoke("Hello")
365
+ print(result.content)
366
+
367
+ result = llm.invoke("Hello")
368
+ print(result.content)
369
+
370
+ # choose the model explicitly
371
+ claude = llm.with_config(configurable={"model": "claude_3_7_sonnet"})
372
+ print(claude.invoke("Who are you?").content)
373
+
374
+ # batch (for multiple parallel requests)
375
+ results = llm.batch(["Hello", "World"] * 5)
376
+ print([r.content for r in results])
377
+
378
+ # handle stream events
379
+ async def stream_response():
380
+ separator = "---------------------------------------"
381
+ print(separator)
382
+ print("Starting stream")
383
+ print(separator)
384
+ async for event in llm.astream_events("Hello. Please respond in 5 sentences."):
385
+ data = event["data"]
386
+ if "chunk" in data:
387
+ print(data["chunk"].content)
388
+ elif "output" in data:
389
+ print(separator)
390
+ print("Final response")
391
+ print(separator)
392
+ print(data["output"].content)
393
+ else:
394
+ print("(other event) ", event.keys())
395
+
396
+
397
+ asyncio.run(stream_response())
398
+ from athena.client import Athena
399
+ client = Athena()
400
+ llm = client.llm
401
+
402
+ # sync invoke -- use strings or langchain messages
403
+ result = llm.invoke("Hello")
404
+ print(result.content)
405
+
406
+ # batch
407
+ results = llm.batch(["Hello", "World"])
408
+ print(r.content for r in results)
409
+ """
410
+
411
+ def __init__(self, base_url: str, api_key: str, timeout: Optional[float] = None):
412
+ self.base_url = base_url
413
+ self.api_key = api_key
414
+ self.timeout = timeout
415
+ super().__init__(
416
+ base_url + "/api/models/general",
417
+ headers={"X-API-KEY": api_key},
418
+ timeout=timeout,
419
+ )
420
+
421
+
327
422
  class Athena(BaseAthena):
328
423
  """
329
424
  Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propogate to these functions.
@@ -360,37 +455,40 @@ class Athena(BaseAthena):
360
455
  self,
361
456
  *,
362
457
  base_url: typing.Optional[str] = None,
363
- environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
458
+ environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
364
459
  api_key: typing.Optional[str] = None,
365
460
  timeout: typing.Optional[float] = 60,
366
- httpx_client: typing.Optional[httpx.Client] = None
461
+ httpx_client: typing.Optional[httpx.Client] = None,
367
462
  ):
368
463
  if api_key is None:
369
464
  try:
370
- api_key = os.environ['ATHENA_API_KEY']
465
+ api_key = os.environ["ATHENA_API_KEY"]
371
466
  except KeyError:
372
467
  raise TypeError(
373
468
  "Athena() missing 1 required keyword-only argument: 'api_key'"
374
- ' (ATHENA_API_KEY environment variable not found)'
469
+ " (ATHENA_API_KEY environment variable not found)"
375
470
  )
376
471
  if environment == SpecialEnvironments.AUTODETECT_ENVIRONMENT:
377
- if 'ATHENA_API_URL' in os.environ:
472
+ if "ATHENA_API_URL" in os.environ:
378
473
 
379
474
  class AutodetectedEnvironments(enum.Enum):
380
- CURRENT = os.environ['ATHENA_API_URL']
475
+ CURRENT = os.environ["ATHENA_API_URL"]
381
476
 
382
477
  environment = cast(AthenaEnvironment, AutodetectedEnvironments.CURRENT)
383
478
  else:
384
479
  environment = AthenaEnvironment.PRODUCTION
385
480
  super().__init__(
386
481
  base_url=base_url,
387
- environment=environment, # type: ignore[arg-type]
482
+ environment=environment, # type: ignore[arg-type]
388
483
  api_key=api_key,
389
484
  timeout=timeout,
390
485
  httpx_client=httpx_client,
391
486
  )
392
487
  self.tools = WrappedToolsClient(client_wrapper=self._client_wrapper)
393
488
  self.query = WrappedQueryClient(client_wrapper=self._client_wrapper)
489
+ self.llm = AthenaModel(
490
+ base_url=base_url or environment.value, api_key=api_key, timeout=timeout
491
+ )
394
492
 
395
493
 
396
494
  class AsyncAthena(AsyncBaseAthena):
@@ -430,31 +528,31 @@ class AsyncAthena(AsyncBaseAthena):
430
528
  self,
431
529
  *,
432
530
  base_url: typing.Optional[str] = None,
433
- environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
531
+ environment: Union[AthenaEnvironment, SpecialEnvironments] = SpecialEnvironments.AUTODETECT_ENVIRONMENT, # type: ignore[arg-type]
434
532
  api_key: typing.Optional[str] = None,
435
533
  timeout: typing.Optional[float] = 60,
436
- httpx_client: typing.Optional[httpx.AsyncClient] = None
534
+ httpx_client: typing.Optional[httpx.AsyncClient] = None,
437
535
  ):
438
536
  if api_key is None:
439
537
  try:
440
- api_key = os.environ['ATHENA_API_KEY']
538
+ api_key = os.environ["ATHENA_API_KEY"]
441
539
  except KeyError:
442
540
  raise TypeError(
443
541
  "AsyncAthena() missing 1 required keyword-only argument: 'api_key'"
444
- ' (ATHENA_API_KEY environment variable not found)'
542
+ " (ATHENA_API_KEY environment variable not found)"
445
543
  )
446
544
  if environment == SpecialEnvironments.AUTODETECT_ENVIRONMENT:
447
- if 'ATHENA_API_URL' in os.environ:
545
+ if "ATHENA_API_URL" in os.environ:
448
546
 
449
547
  class AutodetectedEnvironments(enum.Enum):
450
- CURRENT = os.environ['ATHENA_API_URL']
548
+ CURRENT = os.environ["ATHENA_API_URL"]
451
549
 
452
550
  environment = cast(AthenaEnvironment, AutodetectedEnvironments.CURRENT)
453
551
  else:
454
552
  environment = AthenaEnvironment.PRODUCTION
455
553
  super().__init__(
456
554
  base_url=base_url,
457
- environment=environment, # type: ignore[arg-type]
555
+ environment=environment, # type: ignore[arg-type]
458
556
  api_key=api_key,
459
557
  timeout=timeout,
460
558
  httpx_client=httpx_client,
@@ -463,19 +561,20 @@ class AsyncAthena(AsyncBaseAthena):
463
561
  self.query = WrappedAsyncQueryClient(client_wrapper=self._client_wrapper)
464
562
 
465
563
 
466
- def _read_json_frame(model: DataFrameRequestOut) -> 'pd.DataFrame':
564
+ def _read_json_frame(model: DataFrameRequestOut) -> "pd.DataFrame":
467
565
  import pandas as pd
468
566
 
469
567
  string_io = io.StringIO(model.json())
470
568
 
471
569
  with warnings.catch_warnings():
472
570
  # Filter warnings due to https://github.com/pandas-dev/pandas/issues/59511
473
- warnings.simplefilter(action='ignore', category=FutureWarning)
474
- return pd.read_json(string_io, orient='split')
571
+ warnings.simplefilter(action="ignore", category=FutureWarning)
572
+ return pd.read_json(string_io, orient="split")
475
573
 
476
574
 
477
575
  def _check_pandas_installed():
478
576
  import pandas
577
+
479
578
  assert pandas
480
579
 
481
580
 
@@ -524,11 +623,12 @@ def _to_pandas_df(
524
623
  def _convert_asset_object(
525
624
  asset_object: Union["pd.DataFrame", "pd.Series", core.File],
526
625
  name: Union[str, None] = None,
527
- **kwargs
626
+ **kwargs,
528
627
  ) -> core.File:
529
628
  import pandas as pd
629
+
530
630
  try:
531
- from IPython.core.formatters import format_display_data # type: ignore[import]
631
+ from IPython.core.formatters import format_display_data # type: ignore[import]
532
632
  except ImportError:
533
633
  format_display_data = None
534
634
 
@@ -559,4 +659,4 @@ def _convert_asset_object(
559
659
  image_bytes,
560
660
  media_type,
561
661
  )
562
- return asset_object # type: ignore[return-value]
662
+ return asset_object # type: ignore[return-value]