athena-intelligence 0.1.45__py3-none-any.whl → 0.1.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- athena/__init__.py +12 -3
- athena/base_client.py +27 -6
- athena/chain/__init__.py +0 -3
- athena/chain/client.py +42 -44
- athena/core/__init__.py +2 -0
- athena/core/client_wrapper.py +14 -6
- athena/core/http_client.py +8 -3
- athena/core/jsonable_encoder.py +7 -11
- athena/core/pydantic_utilities.py +12 -0
- athena/dataset/client.py +15 -15
- athena/message/client.py +33 -25
- athena/query/client.py +15 -15
- athena/report/client.py +15 -15
- athena/search/client.py +15 -15
- athena/snippet/client.py +15 -15
- athena/tasks/__init__.py +2 -0
- athena/tasks/client.py +191 -0
- athena/tools/client.py +178 -25
- athena/types/__init__.py +8 -0
- athena/types/dataset.py +3 -6
- athena/types/document.py +31 -0
- athena/types/excecute_tool_first_workflow_out.py +3 -6
- athena/types/firecrawl_scrape_url_data_reponse_dto.py +3 -6
- athena/types/firecrawl_scrape_url_metadata.py +5 -7
- athena/types/get_datasets_response.py +3 -6
- athena/types/get_snippets_response.py +3 -6
- athena/types/http_validation_error.py +3 -6
- athena/types/langchain_documents_request_out.py +26 -0
- athena/types/llm_model.py +93 -0
- athena/types/message_out.py +3 -6
- athena/types/message_out_dto.py +3 -6
- athena/types/model.py +0 -4
- athena/types/plan_execute_out.py +32 -0
- athena/types/report.py +3 -6
- athena/types/snippet.py +3 -6
- athena/types/sql_results.py +3 -6
- athena/types/structured_parse_result.py +3 -6
- athena/types/url_result.py +3 -6
- athena/types/validation_error.py +3 -6
- athena/version.py +4 -0
- {athena_intelligence-0.1.45.dist-info → athena_intelligence-0.1.49.dist-info}/METADATA +1 -1
- athena_intelligence-0.1.49.dist-info/RECORD +65 -0
- athena/chain/types/__init__.py +0 -5
- athena/chain/types/structured_parse_in_parsing_model.py +0 -53
- athena_intelligence-0.1.45.dist-info/RECORD +0 -59
- {athena_intelligence-0.1.45.dist-info → athena_intelligence-0.1.49.dist-info}/WHEEL +0 -0
athena/__init__.py
CHANGED
@@ -2,15 +2,19 @@
|
|
2
2
|
|
3
3
|
from .types import (
|
4
4
|
Dataset,
|
5
|
+
Document,
|
5
6
|
ExcecuteToolFirstWorkflowOut,
|
6
7
|
FirecrawlScrapeUrlDataReponseDto,
|
7
8
|
FirecrawlScrapeUrlMetadata,
|
8
9
|
GetDatasetsResponse,
|
9
10
|
GetSnippetsResponse,
|
10
11
|
HttpValidationError,
|
12
|
+
LangchainDocumentsRequestOut,
|
13
|
+
LlmModel,
|
11
14
|
MessageOut,
|
12
15
|
MessageOutDto,
|
13
16
|
Model,
|
17
|
+
PlanExecuteOut,
|
14
18
|
Report,
|
15
19
|
Snippet,
|
16
20
|
SqlResults,
|
@@ -23,27 +27,30 @@ from .types import (
|
|
23
27
|
ValidationErrorLocItem,
|
24
28
|
)
|
25
29
|
from .errors import UnprocessableEntityError
|
26
|
-
from . import chain, dataset, message, query, report, search, snippet, tools
|
27
|
-
from .chain import StructuredParseInParsingModel
|
30
|
+
from . import chain, dataset, message, query, report, search, snippet, tasks, tools
|
28
31
|
from .environment import AthenaEnvironment
|
32
|
+
from .version import __version__
|
29
33
|
|
30
34
|
__all__ = [
|
31
35
|
"AthenaEnvironment",
|
32
36
|
"Dataset",
|
37
|
+
"Document",
|
33
38
|
"ExcecuteToolFirstWorkflowOut",
|
34
39
|
"FirecrawlScrapeUrlDataReponseDto",
|
35
40
|
"FirecrawlScrapeUrlMetadata",
|
36
41
|
"GetDatasetsResponse",
|
37
42
|
"GetSnippetsResponse",
|
38
43
|
"HttpValidationError",
|
44
|
+
"LangchainDocumentsRequestOut",
|
45
|
+
"LlmModel",
|
39
46
|
"MessageOut",
|
40
47
|
"MessageOutDto",
|
41
48
|
"Model",
|
49
|
+
"PlanExecuteOut",
|
42
50
|
"Report",
|
43
51
|
"Snippet",
|
44
52
|
"SqlResults",
|
45
53
|
"StatusEnum",
|
46
|
-
"StructuredParseInParsingModel",
|
47
54
|
"StructuredParseResult",
|
48
55
|
"ToolModels",
|
49
56
|
"Tools",
|
@@ -51,6 +58,7 @@ __all__ = [
|
|
51
58
|
"UrlResult",
|
52
59
|
"ValidationError",
|
53
60
|
"ValidationErrorLocItem",
|
61
|
+
"__version__",
|
54
62
|
"chain",
|
55
63
|
"dataset",
|
56
64
|
"message",
|
@@ -58,5 +66,6 @@ __all__ = [
|
|
58
66
|
"report",
|
59
67
|
"search",
|
60
68
|
"snippet",
|
69
|
+
"tasks",
|
61
70
|
"tools",
|
62
71
|
]
|
athena/base_client.py
CHANGED
@@ -13,6 +13,7 @@ from .query.client import AsyncQueryClient, QueryClient
|
|
13
13
|
from .report.client import AsyncReportClient, ReportClient
|
14
14
|
from .search.client import AsyncSearchClient, SearchClient
|
15
15
|
from .snippet.client import AsyncSnippetClient, SnippetClient
|
16
|
+
from .tasks.client import AsyncTasksClient, TasksClient
|
16
17
|
from .tools.client import AsyncToolsClient, ToolsClient
|
17
18
|
|
18
19
|
|
@@ -29,7 +30,9 @@ class BaseAthena:
|
|
29
30
|
|
30
31
|
- api_key: str.
|
31
32
|
|
32
|
-
- timeout: typing.Optional[float]. The timeout to be used, in seconds, for requests by default the timeout is 60 seconds.
|
33
|
+
- timeout: typing.Optional[float]. The timeout to be used, in seconds, for requests by default the timeout is 60 seconds, unless a custom httpx client is used, in which case a default is not set.
|
34
|
+
|
35
|
+
- follow_redirects: typing.Optional[bool]. Whether the default httpx client follows redirects or not, this is irrelevant if a custom httpx client is passed in.
|
33
36
|
|
34
37
|
- httpx_client: typing.Optional[httpx.Client]. The httpx client to use for making requests, a preconfigured client is used by default, however this is useful should you want to pass in any custom httpx configuration.
|
35
38
|
---
|
@@ -46,13 +49,20 @@ class BaseAthena:
|
|
46
49
|
base_url: typing.Optional[str] = None,
|
47
50
|
environment: AthenaEnvironment = AthenaEnvironment.DEFAULT,
|
48
51
|
api_key: str,
|
49
|
-
timeout: typing.Optional[float] =
|
52
|
+
timeout: typing.Optional[float] = None,
|
53
|
+
follow_redirects: typing.Optional[bool] = True,
|
50
54
|
httpx_client: typing.Optional[httpx.Client] = None
|
51
55
|
):
|
56
|
+
_defaulted_timeout = timeout if timeout is not None else 60 if httpx_client is None else None
|
52
57
|
self._client_wrapper = SyncClientWrapper(
|
53
58
|
base_url=_get_base_url(base_url=base_url, environment=environment),
|
54
59
|
api_key=api_key,
|
55
|
-
httpx_client=
|
60
|
+
httpx_client=httpx_client
|
61
|
+
if httpx_client is not None
|
62
|
+
else httpx.Client(timeout=_defaulted_timeout, follow_redirects=follow_redirects)
|
63
|
+
if follow_redirects is not None
|
64
|
+
else httpx.Client(timeout=_defaulted_timeout),
|
65
|
+
timeout=_defaulted_timeout,
|
56
66
|
)
|
57
67
|
self.message = MessageClient(client_wrapper=self._client_wrapper)
|
58
68
|
self.dataset = DatasetClient(client_wrapper=self._client_wrapper)
|
@@ -62,6 +72,7 @@ class BaseAthena:
|
|
62
72
|
self.search = SearchClient(client_wrapper=self._client_wrapper)
|
63
73
|
self.chain = ChainClient(client_wrapper=self._client_wrapper)
|
64
74
|
self.tools = ToolsClient(client_wrapper=self._client_wrapper)
|
75
|
+
self.tasks = TasksClient(client_wrapper=self._client_wrapper)
|
65
76
|
|
66
77
|
|
67
78
|
class AsyncBaseAthena:
|
@@ -77,7 +88,9 @@ class AsyncBaseAthena:
|
|
77
88
|
|
78
89
|
- api_key: str.
|
79
90
|
|
80
|
-
- timeout: typing.Optional[float]. The timeout to be used, in seconds, for requests by default the timeout is 60 seconds.
|
91
|
+
- timeout: typing.Optional[float]. The timeout to be used, in seconds, for requests by default the timeout is 60 seconds, unless a custom httpx client is used, in which case a default is not set.
|
92
|
+
|
93
|
+
- follow_redirects: typing.Optional[bool]. Whether the default httpx client follows redirects or not, this is irrelevant if a custom httpx client is passed in.
|
81
94
|
|
82
95
|
- httpx_client: typing.Optional[httpx.AsyncClient]. The httpx client to use for making requests, a preconfigured client is used by default, however this is useful should you want to pass in any custom httpx configuration.
|
83
96
|
---
|
@@ -94,13 +107,20 @@ class AsyncBaseAthena:
|
|
94
107
|
base_url: typing.Optional[str] = None,
|
95
108
|
environment: AthenaEnvironment = AthenaEnvironment.DEFAULT,
|
96
109
|
api_key: str,
|
97
|
-
timeout: typing.Optional[float] =
|
110
|
+
timeout: typing.Optional[float] = None,
|
111
|
+
follow_redirects: typing.Optional[bool] = True,
|
98
112
|
httpx_client: typing.Optional[httpx.AsyncClient] = None
|
99
113
|
):
|
114
|
+
_defaulted_timeout = timeout if timeout is not None else 60 if httpx_client is None else None
|
100
115
|
self._client_wrapper = AsyncClientWrapper(
|
101
116
|
base_url=_get_base_url(base_url=base_url, environment=environment),
|
102
117
|
api_key=api_key,
|
103
|
-
httpx_client=
|
118
|
+
httpx_client=httpx_client
|
119
|
+
if httpx_client is not None
|
120
|
+
else httpx.AsyncClient(timeout=_defaulted_timeout, follow_redirects=follow_redirects)
|
121
|
+
if follow_redirects is not None
|
122
|
+
else httpx.AsyncClient(timeout=_defaulted_timeout),
|
123
|
+
timeout=_defaulted_timeout,
|
104
124
|
)
|
105
125
|
self.message = AsyncMessageClient(client_wrapper=self._client_wrapper)
|
106
126
|
self.dataset = AsyncDatasetClient(client_wrapper=self._client_wrapper)
|
@@ -110,6 +130,7 @@ class AsyncBaseAthena:
|
|
110
130
|
self.search = AsyncSearchClient(client_wrapper=self._client_wrapper)
|
111
131
|
self.chain = AsyncChainClient(client_wrapper=self._client_wrapper)
|
112
132
|
self.tools = AsyncToolsClient(client_wrapper=self._client_wrapper)
|
133
|
+
self.tasks = AsyncTasksClient(client_wrapper=self._client_wrapper)
|
113
134
|
|
114
135
|
|
115
136
|
def _get_base_url(*, base_url: typing.Optional[str] = None, environment: AthenaEnvironment) -> str:
|
athena/chain/__init__.py
CHANGED
athena/chain/client.py
CHANGED
@@ -7,17 +7,13 @@ from json.decoder import JSONDecodeError
|
|
7
7
|
from ..core.api_error import ApiError
|
8
8
|
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
|
9
9
|
from ..core.jsonable_encoder import jsonable_encoder
|
10
|
+
from ..core.pydantic_utilities import pydantic_v1
|
10
11
|
from ..core.remove_none_from_dict import remove_none_from_dict
|
11
12
|
from ..core.request_options import RequestOptions
|
12
13
|
from ..errors.unprocessable_entity_error import UnprocessableEntityError
|
13
14
|
from ..types.http_validation_error import HttpValidationError
|
15
|
+
from ..types.llm_model import LlmModel
|
14
16
|
from ..types.structured_parse_result import StructuredParseResult
|
15
|
-
from .types.structured_parse_in_parsing_model import StructuredParseInParsingModel
|
16
|
-
|
17
|
-
try:
|
18
|
-
import pydantic.v1 as pydantic # type: ignore
|
19
|
-
except ImportError:
|
20
|
-
import pydantic # type: ignore
|
21
17
|
|
22
18
|
# this is used as the default value for optional parameters
|
23
19
|
OMIT = typing.cast(typing.Any, ...)
|
@@ -32,7 +28,7 @@ class ChainClient:
|
|
32
28
|
*,
|
33
29
|
text_input: str,
|
34
30
|
custom_type_dict: typing.Dict[str, typing.Any],
|
35
|
-
|
31
|
+
model: LlmModel,
|
36
32
|
request_options: typing.Optional[RequestOptions] = None,
|
37
33
|
) -> StructuredParseResult:
|
38
34
|
"""
|
@@ -41,33 +37,32 @@ class ChainClient:
|
|
41
37
|
|
42
38
|
- custom_type_dict: typing.Dict[str, typing.Any]. A dictionary of field names and their default values.
|
43
39
|
|
44
|
-
-
|
40
|
+
- model: LlmModel.
|
45
41
|
|
46
42
|
- request_options: typing.Optional[RequestOptions]. Request-specific configuration.
|
47
43
|
---
|
48
|
-
from athena import
|
44
|
+
from athena import LlmModel
|
49
45
|
from athena.client import Athena
|
50
46
|
|
51
|
-
client = Athena(
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
47
|
+
client = Athena(
|
48
|
+
api_key="YOUR_API_KEY",
|
49
|
+
)
|
50
|
+
client.chain.structured_parse(
|
51
|
+
text_input='Athena is an AI-native analytics platform and artificial employee built to accelerate analytics workflows \n by offering enterprise teams co-pilot and auto-pilot modes. Athena learns your workflow as a co-pilot, \n allowing you to hand over controls to her for autonomous execution with confidence." \n \n Give me all of the modes Athena provides.',
|
52
|
+
custom_type_dict={"modes": {}},
|
53
|
+
model=LlmModel.GPT_35_TURBO,
|
54
|
+
)
|
57
55
|
"""
|
58
|
-
_request: typing.Dict[str, typing.Any] = {"text_input": text_input, "custom_type_dict": custom_type_dict}
|
59
|
-
if parsing_model is not OMIT:
|
60
|
-
_request["parsing_model"] = parsing_model
|
61
56
|
_response = self._client_wrapper.httpx_client.request(
|
62
|
-
"POST",
|
63
|
-
urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/structured-parse"),
|
57
|
+
method="POST",
|
58
|
+
url=urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/structured-parse"),
|
64
59
|
params=jsonable_encoder(
|
65
60
|
request_options.get("additional_query_parameters") if request_options is not None else None
|
66
61
|
),
|
67
|
-
json=jsonable_encoder(
|
62
|
+
json=jsonable_encoder({"text_input": text_input, "custom_type_dict": custom_type_dict, "model": model})
|
68
63
|
if request_options is None or request_options.get("additional_body_parameters") is None
|
69
64
|
else {
|
70
|
-
**jsonable_encoder(
|
65
|
+
**jsonable_encoder({"text_input": text_input, "custom_type_dict": custom_type_dict, "model": model}),
|
71
66
|
**(jsonable_encoder(remove_none_from_dict(request_options.get("additional_body_parameters", {})))),
|
72
67
|
},
|
73
68
|
headers=jsonable_encoder(
|
@@ -80,14 +75,16 @@ class ChainClient:
|
|
80
75
|
),
|
81
76
|
timeout=request_options.get("timeout_in_seconds")
|
82
77
|
if request_options is not None and request_options.get("timeout_in_seconds") is not None
|
83
|
-
else
|
78
|
+
else self._client_wrapper.get_timeout(),
|
84
79
|
retries=0,
|
85
80
|
max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
|
86
81
|
)
|
87
82
|
if 200 <= _response.status_code < 300:
|
88
|
-
return
|
83
|
+
return pydantic_v1.parse_obj_as(StructuredParseResult, _response.json()) # type: ignore
|
89
84
|
if _response.status_code == 422:
|
90
|
-
raise UnprocessableEntityError(
|
85
|
+
raise UnprocessableEntityError(
|
86
|
+
pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
|
87
|
+
)
|
91
88
|
try:
|
92
89
|
_response_json = _response.json()
|
93
90
|
except JSONDecodeError:
|
@@ -104,7 +101,7 @@ class AsyncChainClient:
|
|
104
101
|
*,
|
105
102
|
text_input: str,
|
106
103
|
custom_type_dict: typing.Dict[str, typing.Any],
|
107
|
-
|
104
|
+
model: LlmModel,
|
108
105
|
request_options: typing.Optional[RequestOptions] = None,
|
109
106
|
) -> StructuredParseResult:
|
110
107
|
"""
|
@@ -113,33 +110,32 @@ class AsyncChainClient:
|
|
113
110
|
|
114
111
|
- custom_type_dict: typing.Dict[str, typing.Any]. A dictionary of field names and their default values.
|
115
112
|
|
116
|
-
-
|
113
|
+
- model: LlmModel.
|
117
114
|
|
118
115
|
- request_options: typing.Optional[RequestOptions]. Request-specific configuration.
|
119
116
|
---
|
120
|
-
from athena import
|
117
|
+
from athena import LlmModel
|
121
118
|
from athena.client import AsyncAthena
|
122
119
|
|
123
|
-
client = AsyncAthena(
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
120
|
+
client = AsyncAthena(
|
121
|
+
api_key="YOUR_API_KEY",
|
122
|
+
)
|
123
|
+
await client.chain.structured_parse(
|
124
|
+
text_input='Athena is an AI-native analytics platform and artificial employee built to accelerate analytics workflows \n by offering enterprise teams co-pilot and auto-pilot modes. Athena learns your workflow as a co-pilot, \n allowing you to hand over controls to her for autonomous execution with confidence." \n \n Give me all of the modes Athena provides.',
|
125
|
+
custom_type_dict={"modes": {}},
|
126
|
+
model=LlmModel.GPT_35_TURBO,
|
127
|
+
)
|
129
128
|
"""
|
130
|
-
_request: typing.Dict[str, typing.Any] = {"text_input": text_input, "custom_type_dict": custom_type_dict}
|
131
|
-
if parsing_model is not OMIT:
|
132
|
-
_request["parsing_model"] = parsing_model
|
133
129
|
_response = await self._client_wrapper.httpx_client.request(
|
134
|
-
"POST",
|
135
|
-
urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/structured-parse"),
|
130
|
+
method="POST",
|
131
|
+
url=urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/structured-parse"),
|
136
132
|
params=jsonable_encoder(
|
137
133
|
request_options.get("additional_query_parameters") if request_options is not None else None
|
138
134
|
),
|
139
|
-
json=jsonable_encoder(
|
135
|
+
json=jsonable_encoder({"text_input": text_input, "custom_type_dict": custom_type_dict, "model": model})
|
140
136
|
if request_options is None or request_options.get("additional_body_parameters") is None
|
141
137
|
else {
|
142
|
-
**jsonable_encoder(
|
138
|
+
**jsonable_encoder({"text_input": text_input, "custom_type_dict": custom_type_dict, "model": model}),
|
143
139
|
**(jsonable_encoder(remove_none_from_dict(request_options.get("additional_body_parameters", {})))),
|
144
140
|
},
|
145
141
|
headers=jsonable_encoder(
|
@@ -152,14 +148,16 @@ class AsyncChainClient:
|
|
152
148
|
),
|
153
149
|
timeout=request_options.get("timeout_in_seconds")
|
154
150
|
if request_options is not None and request_options.get("timeout_in_seconds") is not None
|
155
|
-
else
|
151
|
+
else self._client_wrapper.get_timeout(),
|
156
152
|
retries=0,
|
157
153
|
max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
|
158
154
|
)
|
159
155
|
if 200 <= _response.status_code < 300:
|
160
|
-
return
|
156
|
+
return pydantic_v1.parse_obj_as(StructuredParseResult, _response.json()) # type: ignore
|
161
157
|
if _response.status_code == 422:
|
162
|
-
raise UnprocessableEntityError(
|
158
|
+
raise UnprocessableEntityError(
|
159
|
+
pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
|
160
|
+
)
|
163
161
|
try:
|
164
162
|
_response_json = _response.json()
|
165
163
|
except JSONDecodeError:
|
athena/core/__init__.py
CHANGED
@@ -6,6 +6,7 @@ from .datetime_utils import serialize_datetime
|
|
6
6
|
from .file import File, convert_file_dict_to_httpx_tuples
|
7
7
|
from .http_client import AsyncHttpClient, HttpClient
|
8
8
|
from .jsonable_encoder import jsonable_encoder
|
9
|
+
from .pydantic_utilities import pydantic_v1
|
9
10
|
from .remove_none_from_dict import remove_none_from_dict
|
10
11
|
from .request_options import RequestOptions
|
11
12
|
|
@@ -20,6 +21,7 @@ __all__ = [
|
|
20
21
|
"SyncClientWrapper",
|
21
22
|
"convert_file_dict_to_httpx_tuples",
|
22
23
|
"jsonable_encoder",
|
24
|
+
"pydantic_v1",
|
23
25
|
"remove_none_from_dict",
|
24
26
|
"serialize_datetime",
|
25
27
|
]
|
athena/core/client_wrapper.py
CHANGED
@@ -8,15 +8,16 @@ from .http_client import AsyncHttpClient, HttpClient
|
|
8
8
|
|
9
9
|
|
10
10
|
class BaseClientWrapper:
|
11
|
-
def __init__(self, *, api_key: str, base_url: str):
|
11
|
+
def __init__(self, *, api_key: str, base_url: str, timeout: typing.Optional[float] = None):
|
12
12
|
self.api_key = api_key
|
13
13
|
self._base_url = base_url
|
14
|
+
self._timeout = timeout
|
14
15
|
|
15
16
|
def get_headers(self) -> typing.Dict[str, str]:
|
16
17
|
headers: typing.Dict[str, str] = {
|
17
18
|
"X-Fern-Language": "Python",
|
18
19
|
"X-Fern-SDK-Name": "athena-intelligence",
|
19
|
-
"X-Fern-SDK-Version": "0.1.
|
20
|
+
"X-Fern-SDK-Version": "0.1.49",
|
20
21
|
}
|
21
22
|
headers["X-API-KEY"] = self.api_key
|
22
23
|
return headers
|
@@ -24,14 +25,21 @@ class BaseClientWrapper:
|
|
24
25
|
def get_base_url(self) -> str:
|
25
26
|
return self._base_url
|
26
27
|
|
28
|
+
def get_timeout(self) -> typing.Optional[float]:
|
29
|
+
return self._timeout
|
30
|
+
|
27
31
|
|
28
32
|
class SyncClientWrapper(BaseClientWrapper):
|
29
|
-
def __init__(
|
30
|
-
|
33
|
+
def __init__(
|
34
|
+
self, *, api_key: str, base_url: str, timeout: typing.Optional[float] = None, httpx_client: httpx.Client
|
35
|
+
):
|
36
|
+
super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
|
31
37
|
self.httpx_client = HttpClient(httpx_client=httpx_client)
|
32
38
|
|
33
39
|
|
34
40
|
class AsyncClientWrapper(BaseClientWrapper):
|
35
|
-
def __init__(
|
36
|
-
|
41
|
+
def __init__(
|
42
|
+
self, *, api_key: str, base_url: str, timeout: typing.Optional[float] = None, httpx_client: httpx.AsyncClient
|
43
|
+
):
|
44
|
+
super().__init__(api_key=api_key, base_url=base_url, timeout=timeout)
|
37
45
|
self.httpx_client = AsyncHttpClient(httpx_client=httpx_client)
|
athena/core/http_client.py
CHANGED
@@ -5,6 +5,7 @@ import email.utils
|
|
5
5
|
import re
|
6
6
|
import time
|
7
7
|
import typing
|
8
|
+
from contextlib import asynccontextmanager, contextmanager
|
8
9
|
from functools import wraps
|
9
10
|
from random import random
|
10
11
|
|
@@ -98,8 +99,10 @@ class HttpClient:
|
|
98
99
|
return response
|
99
100
|
|
100
101
|
@wraps(httpx.Client.stream)
|
102
|
+
@contextmanager
|
101
103
|
def stream(self, *args: typing.Any, max_retries: int = 0, retries: int = 0, **kwargs: typing.Any) -> typing.Any:
|
102
|
-
|
104
|
+
with self.httpx_client.stream(*args, **kwargs) as stream:
|
105
|
+
yield stream
|
103
106
|
|
104
107
|
|
105
108
|
class AsyncHttpClient:
|
@@ -118,8 +121,10 @@ class AsyncHttpClient:
|
|
118
121
|
return await self.request(max_retries=max_retries, retries=retries + 1, *args, **kwargs)
|
119
122
|
return response
|
120
123
|
|
121
|
-
@wraps(httpx.AsyncClient.
|
124
|
+
@wraps(httpx.AsyncClient.stream)
|
125
|
+
@asynccontextmanager
|
122
126
|
async def stream(
|
123
127
|
self, *args: typing.Any, max_retries: int = 0, retries: int = 0, **kwargs: typing.Any
|
124
128
|
) -> typing.Any:
|
125
|
-
|
129
|
+
async with self.httpx_client.stream(*args, **kwargs) as stream:
|
130
|
+
yield stream
|
athena/core/jsonable_encoder.py
CHANGED
@@ -16,12 +16,8 @@ from pathlib import PurePath
|
|
16
16
|
from types import GeneratorType
|
17
17
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
18
18
|
|
19
|
-
try:
|
20
|
-
import pydantic.v1 as pydantic # type: ignore
|
21
|
-
except ImportError:
|
22
|
-
import pydantic # type: ignore
|
23
|
-
|
24
19
|
from .datetime_utils import serialize_datetime
|
20
|
+
from .pydantic_utilities import pydantic_v1
|
25
21
|
|
26
22
|
SetIntStr = Set[Union[int, str]]
|
27
23
|
DictIntStrAny = Dict[Union[int, str], Any]
|
@@ -36,7 +32,7 @@ def generate_encoders_by_class_tuples(
|
|
36
32
|
return encoders_by_class_tuples
|
37
33
|
|
38
34
|
|
39
|
-
encoders_by_class_tuples = generate_encoders_by_class_tuples(
|
35
|
+
encoders_by_class_tuples = generate_encoders_by_class_tuples(pydantic_v1.json.ENCODERS_BY_TYPE)
|
40
36
|
|
41
37
|
|
42
38
|
def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None) -> Any:
|
@@ -48,7 +44,7 @@ def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any]
|
|
48
44
|
for encoder_type, encoder_instance in custom_encoder.items():
|
49
45
|
if isinstance(obj, encoder_type):
|
50
46
|
return encoder_instance(obj)
|
51
|
-
if isinstance(obj,
|
47
|
+
if isinstance(obj, pydantic_v1.BaseModel):
|
52
48
|
encoder = getattr(obj.__config__, "json_encoders", {})
|
53
49
|
if custom_encoder:
|
54
50
|
encoder.update(custom_encoder)
|
@@ -65,10 +61,10 @@ def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any]
|
|
65
61
|
return str(obj)
|
66
62
|
if isinstance(obj, (str, int, float, type(None))):
|
67
63
|
return obj
|
68
|
-
if isinstance(obj, dt.date):
|
69
|
-
return str(obj)
|
70
64
|
if isinstance(obj, dt.datetime):
|
71
65
|
return serialize_datetime(obj)
|
66
|
+
if isinstance(obj, dt.date):
|
67
|
+
return str(obj)
|
72
68
|
if isinstance(obj, dict):
|
73
69
|
encoded_dict = {}
|
74
70
|
allowed_keys = set(obj.keys())
|
@@ -84,8 +80,8 @@ def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any]
|
|
84
80
|
encoded_list.append(jsonable_encoder(item, custom_encoder=custom_encoder))
|
85
81
|
return encoded_list
|
86
82
|
|
87
|
-
if type(obj) in
|
88
|
-
return
|
83
|
+
if type(obj) in pydantic_v1.json.ENCODERS_BY_TYPE:
|
84
|
+
return pydantic_v1.json.ENCODERS_BY_TYPE[type(obj)](obj)
|
89
85
|
for encoder, classes_tuple in encoders_by_class_tuples.items():
|
90
86
|
if isinstance(obj, classes_tuple):
|
91
87
|
return encoder(obj)
|
@@ -0,0 +1,12 @@
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
2
|
+
|
3
|
+
import pydantic
|
4
|
+
|
5
|
+
IS_PYDANTIC_V2 = pydantic.VERSION.startswith("2.")
|
6
|
+
|
7
|
+
if IS_PYDANTIC_V2:
|
8
|
+
import pydantic.v1 as pydantic_v1 # type: ignore # nopycln: import
|
9
|
+
else:
|
10
|
+
import pydantic as pydantic_v1 # type: ignore # nopycln: import
|
11
|
+
|
12
|
+
__all__ = ["pydantic_v1"]
|
athena/dataset/client.py
CHANGED
@@ -7,17 +7,13 @@ from json.decoder import JSONDecodeError
|
|
7
7
|
from ..core.api_error import ApiError
|
8
8
|
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
|
9
9
|
from ..core.jsonable_encoder import jsonable_encoder
|
10
|
+
from ..core.pydantic_utilities import pydantic_v1
|
10
11
|
from ..core.remove_none_from_dict import remove_none_from_dict
|
11
12
|
from ..core.request_options import RequestOptions
|
12
13
|
from ..errors.unprocessable_entity_error import UnprocessableEntityError
|
13
14
|
from ..types.get_datasets_response import GetDatasetsResponse
|
14
15
|
from ..types.http_validation_error import HttpValidationError
|
15
16
|
|
16
|
-
try:
|
17
|
-
import pydantic.v1 as pydantic # type: ignore
|
18
|
-
except ImportError:
|
19
|
-
import pydantic # type: ignore
|
20
|
-
|
21
17
|
|
22
18
|
class DatasetClient:
|
23
19
|
def __init__(self, *, client_wrapper: SyncClientWrapper):
|
@@ -46,8 +42,8 @@ class DatasetClient:
|
|
46
42
|
client.dataset.get()
|
47
43
|
"""
|
48
44
|
_response = self._client_wrapper.httpx_client.request(
|
49
|
-
"GET",
|
50
|
-
urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/datasets"),
|
45
|
+
method="GET",
|
46
|
+
url=urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/datasets"),
|
51
47
|
params=jsonable_encoder(
|
52
48
|
remove_none_from_dict(
|
53
49
|
{
|
@@ -71,14 +67,16 @@ class DatasetClient:
|
|
71
67
|
),
|
72
68
|
timeout=request_options.get("timeout_in_seconds")
|
73
69
|
if request_options is not None and request_options.get("timeout_in_seconds") is not None
|
74
|
-
else
|
70
|
+
else self._client_wrapper.get_timeout(),
|
75
71
|
retries=0,
|
76
72
|
max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
|
77
73
|
)
|
78
74
|
if 200 <= _response.status_code < 300:
|
79
|
-
return
|
75
|
+
return pydantic_v1.parse_obj_as(GetDatasetsResponse, _response.json()) # type: ignore
|
80
76
|
if _response.status_code == 422:
|
81
|
-
raise UnprocessableEntityError(
|
77
|
+
raise UnprocessableEntityError(
|
78
|
+
pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
|
79
|
+
)
|
82
80
|
try:
|
83
81
|
_response_json = _response.json()
|
84
82
|
except JSONDecodeError:
|
@@ -113,8 +111,8 @@ class AsyncDatasetClient:
|
|
113
111
|
await client.dataset.get()
|
114
112
|
"""
|
115
113
|
_response = await self._client_wrapper.httpx_client.request(
|
116
|
-
"GET",
|
117
|
-
urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/datasets"),
|
114
|
+
method="GET",
|
115
|
+
url=urllib.parse.urljoin(f"{self._client_wrapper.get_base_url()}/", "api/v0/datasets"),
|
118
116
|
params=jsonable_encoder(
|
119
117
|
remove_none_from_dict(
|
120
118
|
{
|
@@ -138,14 +136,16 @@ class AsyncDatasetClient:
|
|
138
136
|
),
|
139
137
|
timeout=request_options.get("timeout_in_seconds")
|
140
138
|
if request_options is not None and request_options.get("timeout_in_seconds") is not None
|
141
|
-
else
|
139
|
+
else self._client_wrapper.get_timeout(),
|
142
140
|
retries=0,
|
143
141
|
max_retries=request_options.get("max_retries") if request_options is not None else 0, # type: ignore
|
144
142
|
)
|
145
143
|
if 200 <= _response.status_code < 300:
|
146
|
-
return
|
144
|
+
return pydantic_v1.parse_obj_as(GetDatasetsResponse, _response.json()) # type: ignore
|
147
145
|
if _response.status_code == 422:
|
148
|
-
raise UnprocessableEntityError(
|
146
|
+
raise UnprocessableEntityError(
|
147
|
+
pydantic_v1.parse_obj_as(HttpValidationError, _response.json()) # type: ignore
|
148
|
+
)
|
149
149
|
try:
|
150
150
|
_response_json = _response.json()
|
151
151
|
except JSONDecodeError:
|