mlops-python-sdk 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlops/__init__.py +46 -0
- mlops/api/client/__init__.py +8 -0
- mlops/api/client/api/__init__.py +1 -0
- mlops/api/client/api/tasks/__init__.py +1 -0
- mlops/api/client/api/tasks/cancel_task.py +196 -0
- mlops/api/client/api/tasks/delete_task.py +204 -0
- mlops/api/client/api/tasks/get_task.py +196 -0
- mlops/api/client/api/tasks/list_tasks.py +255 -0
- mlops/api/client/api/tasks/submit_task.py +188 -0
- mlops/api/client/client.py +268 -0
- mlops/api/client/errors.py +16 -0
- mlops/api/client/models/__init__.py +33 -0
- mlops/api/client/models/error_response.py +68 -0
- mlops/api/client/models/message_response.py +59 -0
- mlops/api/client/models/task.py +1629 -0
- mlops/api/client/models/task_alloc_tres_type_0.py +49 -0
- mlops/api/client/models/task_gres_detail_type_0_item.py +44 -0
- mlops/api/client/models/task_job_resources_type_0.py +49 -0
- mlops/api/client/models/task_list_response.py +102 -0
- mlops/api/client/models/task_resources_type_0.py +49 -0
- mlops/api/client/models/task_status.py +15 -0
- mlops/api/client/models/task_submit_request.py +640 -0
- mlops/api/client/models/task_submit_request_environment_type_0.py +49 -0
- mlops/api/client/models/task_submit_response.py +78 -0
- mlops/api/client/models/task_tres_type_0.py +49 -0
- mlops/api/client/models/task_tres_used_type_0.py +49 -0
- mlops/api/client/py.typed +1 -0
- mlops/api/client/types.py +54 -0
- mlops/connection_config.py +106 -0
- mlops/exceptions.py +82 -0
- mlops/task/__init__.py +10 -0
- mlops/task/client.py +146 -0
- mlops/task/task.py +464 -0
- mlops_python_sdk-0.0.1.dist-info/METADATA +416 -0
- mlops_python_sdk-0.0.1.dist-info/RECORD +36 -0
- mlops_python_sdk-0.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from collections.abc import Mapping
|
|
2
|
+
from typing import Any, TypeVar, Union
|
|
3
|
+
|
|
4
|
+
from attrs import define as _attrs_define
|
|
5
|
+
from attrs import field as _attrs_field
|
|
6
|
+
|
|
7
|
+
from ..types import UNSET, Unset
|
|
8
|
+
|
|
9
|
+
T = TypeVar("T", bound="TaskSubmitResponse")


@_attrs_define
class TaskSubmitResponse:
    """Task submission response

    Attributes:
        id (Union[Unset, int]): Database ID (slurm_jobs.id) Example: 1.
        job_id (Union[Unset, str]): Slurm job ID (assigned by scheduler) Example: 12345.
        message (Union[Unset, str]): Example: Task submitted successfully.
        additional_properties (dict[str, Any]): Every key of the source mapping
            other than ``id``, ``job_id`` and ``message``, kept verbatim.
    """

    id: Union[Unset, int] = UNSET
    job_id: Union[Unset, str] = UNSET
    message: Union[Unset, str] = UNSET
    additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict.

        Extra properties are copied first. Each declared field is then written
        only when it is not UNSET, so unset fields are omitted and a set field
        overrides an extra property that happens to share its name.
        """
        payload: dict[str, Any] = dict(self.additional_properties)
        declared = (
            ("id", self.id),
            ("job_id", self.job_id),
            ("message", self.message),
        )
        for key, value in declared:
            if value is not UNSET:
                payload[key] = value
        return payload

    @classmethod
    def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
        """Build an instance from a mapping.

        Missing declared keys become UNSET, and all remaining keys are stored
        in ``additional_properties``. Values are not type-checked.
        """
        remaining = dict(src_dict)
        # Keyword arguments are evaluated left to right, so the pops happen in
        # the order id, job_id, message.
        instance = cls(
            id=remaining.pop("id", UNSET),
            job_id=remaining.pop("job_id", UNSET),
            message=remaining.pop("message", UNSET),
        )
        instance.additional_properties = remaining
        return instance

    @property
    def additional_keys(self) -> list[str]:
        """Names of the extra (undeclared) properties, in insertion order."""
        return [*self.additional_properties]

    def __getitem__(self, key: str) -> Any:
        """Return an extra property. Raises KeyError when it is absent."""
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: Any) -> None:
        """Add or replace an extra property."""
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        """Remove an extra property. Raises KeyError when it is absent."""
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        """Report whether ``key`` is an extra property (declared fields are not checked)."""
        return key in self.additional_properties
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from collections.abc import Mapping
|
|
2
|
+
from typing import Any, TypeVar
|
|
3
|
+
|
|
4
|
+
from attrs import define as _attrs_define
|
|
5
|
+
from attrs import field as _attrs_field
|
|
6
|
+
|
|
7
|
+
T = TypeVar("T", bound="TaskTresType0")


@_attrs_define
class TaskTresType0:
    """Trackable Resources

    A free-form mapping of resource name to amount. The model declares no
    fixed fields, so every key lives in ``additional_properties``.

    Example:
        {'cpu': 4, 'mem': 8589934592}

    """

    additional_properties: dict[str, int] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a shallow copy of all properties as a plain dict."""
        return dict(self.additional_properties)

    @classmethod
    def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
        """Build an instance holding a shallow copy of every key in ``src_dict``.

        Values are stored as given; they are not checked to be ints.
        """
        copied = dict(src_dict)
        instance = cls()
        instance.additional_properties = copied
        return instance

    @property
    def additional_keys(self) -> list[str]:
        """Names of all stored properties, in insertion order."""
        return [*self.additional_properties]

    def __getitem__(self, key: str) -> int:
        """Return the amount for ``key``. Raises KeyError when it is absent."""
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: int) -> None:
        """Add or replace the amount for ``key``."""
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        """Remove ``key``. Raises KeyError when it is absent."""
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        """Report whether ``key`` is stored."""
        return key in self.additional_properties
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from collections.abc import Mapping
|
|
2
|
+
from typing import Any, TypeVar
|
|
3
|
+
|
|
4
|
+
from attrs import define as _attrs_define
|
|
5
|
+
from attrs import field as _attrs_field
|
|
6
|
+
|
|
7
|
+
T = TypeVar("T", bound="TaskTresUsedType0")


@_attrs_define
class TaskTresUsedType0:
    """Trackable Resources Used

    A free-form mapping of resource name to amount used. The model declares
    no fixed fields, so every key lives in ``additional_properties``.

    Example:
        {'cpu': 4, 'mem': 4294967296}

    """

    additional_properties: dict[str, int] = _attrs_field(init=False, factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a shallow copy of all properties as a plain dict."""
        return dict(self.additional_properties)

    @classmethod
    def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T:
        """Build an instance holding a shallow copy of every key in ``src_dict``.

        Values are stored as given; they are not checked to be ints.
        """
        copied = dict(src_dict)
        instance = cls()
        instance.additional_properties = copied
        return instance

    @property
    def additional_keys(self) -> list[str]:
        """Names of all stored properties, in insertion order."""
        return [*self.additional_properties]

    def __getitem__(self, key: str) -> int:
        """Return the amount for ``key``. Raises KeyError when it is absent."""
        return self.additional_properties[key]

    def __setitem__(self, key: str, value: int) -> None:
        """Add or replace the amount for ``key``."""
        self.additional_properties[key] = value

    def __delitem__(self, key: str) -> None:
        """Remove ``key``. Raises KeyError when it is absent."""
        del self.additional_properties[key]

    def __contains__(self, key: str) -> bool:
        """Report whether ``key`` is stored."""
        return key in self.additional_properties
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Marker file for PEP 561
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Contains some shared types for properties"""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping, MutableMapping
|
|
4
|
+
from http import HTTPStatus
|
|
5
|
+
from typing import IO, BinaryIO, Generic, Literal, Optional, TypeVar, Union
|
|
6
|
+
|
|
7
|
+
from attrs import define
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Unset:
    """Sentinel type for a field that was not provided (distinct from ``None``).

    Instances are always falsy.
    """

    def __bool__(self) -> Literal[False]:
        return False


# Shared sentinel instance; model code tests fields with ``is not UNSET``.
UNSET: Unset = Unset()

# Shapes accepted by `httpx.Client(files=)`, mirrored from that library.
FileContent = Union[IO[bytes], bytes, str]
FileTypes = Union[
    # (filename, file (or bytes), content_type)
    tuple[Optional[str], FileContent, Optional[str]],
    # (filename, file (or bytes), content_type, headers)
    tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]],
]
# A multipart "files" payload: a list of (form field name, file spec) pairs.
RequestFiles = list[tuple[str, FileTypes]]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@define
class File:
    """A file to upload in a multipart/form-data request.

    Attributes:
        payload: Binary stream holding the file contents.
        file_name: Name reported for the upload, if any.
        mime_type: Content type reported for the upload, if any.
    """

    payload: BinaryIO
    file_name: Optional[str] = None
    mime_type: Optional[str] = None

    def to_tuple(self) -> FileTypes:
        """Return ``(file_name, payload, mime_type)``, the 3-tuple form httpx accepts for multipart/form-data."""
        upload = (self.file_name, self.payload, self.mime_type)
        return upload
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
T = TypeVar("T")


@define
class Response(Generic[T]):
    """A response returned by an API endpoint call.

    Attributes:
        status_code: HTTP status of the response.
        content: Raw response body.
        headers: Response headers.
        parsed: The decoded body as a model of type ``T``; may be ``None``.
    """

    status_code: HTTPStatus
    content: bytes
    headers: MutableMapping[str, str]
    parsed: Optional[T]


__all__ = [
    "UNSET",
    "File",
    "FileTypes",
    "RequestFiles",
    "Response",
    "Unset",
]
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from typing import Literal, Optional, Dict
|
|
4
|
+
from httpx._types import ProxyTypes
|
|
5
|
+
|
|
6
|
+
REQUEST_TIMEOUT: float = 30.0  # 30 seconds

KEEPALIVE_PING_INTERVAL_SEC = 50  # 50 seconds
KEEPALIVE_PING_HEADER = "Keepalive-Ping-Interval"

# API path prefix
DEFAULT_API_PATH: str = "/api/v1"


class ConnectionConfig:
    """
    Configuration for the connection to the API.

    Some arguments fall back to an environment variable:

    - ``domain`` (when falsy): ``XCLIENT_DOMAIN``, default ``"localhost:8090"``
    - ``debug`` (when ``None``): ``XCLIENT_DEBUG``; ``"true"`` (any case) enables it
    - ``api_key`` (when falsy): ``XCLIENT_API_KEY``
    - ``access_token`` (when falsy): ``XCLIENT_ACCESS_TOKEN``
    - ``api_path`` (when falsy): ``XCLIENT_API_PATH``, default ``DEFAULT_API_PATH``

    ``request_timeout`` defaults to ``REQUEST_TIMEOUT`` seconds. ``0`` disables
    the timeout (stored as ``None``).

    ``api_url`` is ``<base_url><api_path>``, with ``api_path`` always starting
    with "/". In debug mode the base URL is always ``http://localhost:8090``.
    Otherwise it is ``domain``, with ``http://`` prepended when ``domain``
    carries no scheme.
    """

    @staticmethod
    def _domain() -> str:
        return os.getenv("XCLIENT_DOMAIN", "localhost:8090")

    @staticmethod
    def _debug() -> bool:
        return os.getenv("XCLIENT_DEBUG", "false").lower() == "true"

    @staticmethod
    def _api_key() -> Optional[str]:
        return os.getenv("XCLIENT_API_KEY")

    @staticmethod
    def _access_token() -> Optional[str]:
        return os.getenv("XCLIENT_ACCESS_TOKEN")

    @staticmethod
    def _api_path() -> str:
        return os.getenv("XCLIENT_API_PATH", DEFAULT_API_PATH)

    def __init__(
        self,
        domain: Optional[str] = None,
        debug: Optional[bool] = None,
        api_key: Optional[str] = None,
        access_token: Optional[str] = None,
        request_timeout: Optional[float] = None,
        headers: Optional[Dict[str, str]] = None,
        proxy: Optional[ProxyTypes] = None,
        api_path: Optional[str] = None,
    ):
        self.domain = domain or ConnectionConfig._domain()
        # Compare against None so an explicit ``debug=False`` is honoured even
        # when XCLIENT_DEBUG=true (``or`` would treat False as "not given").
        self.debug = debug if debug is not None else ConnectionConfig._debug()
        self.api_key = api_key or ConnectionConfig._api_key()
        self.access_token = access_token or ConnectionConfig._access_token()
        self.headers = headers or {}
        self.proxy = proxy
        self.api_path = api_path or ConnectionConfig._api_path()

        # 0 -> None (no timeout), None -> REQUEST_TIMEOUT, otherwise as given.
        self.request_timeout = ConnectionConfig._get_request_timeout(
            REQUEST_TIMEOUT,
            request_timeout,
        )

        # Ensure api_path starts with /
        if not self.api_path.startswith("/"):
            self.api_path = "/" + self.api_path

        # Build API URL
        if self.debug:
            base_url = "http://localhost:8090"
        elif self.domain.startswith(("http://", "https://")):
            # Domain already includes a scheme; use it as-is.
            base_url = self.domain
        else:
            # No scheme given: default to http:// for backward compatibility.
            base_url = f"http://{self.domain}"

        # Strip a trailing "/" so it does not double up with api_path's
        # leading "/" (e.g. "https://host/" + "/api/v1").
        self.api_url = f"{base_url.rstrip('/')}{self.api_path}"

    @staticmethod
    def _get_request_timeout(
        default_timeout: Optional[float],
        request_timeout: Optional[float],
    ):
        """Resolve an effective timeout.

        ``0`` means "no timeout" and returns ``None``. ``None`` means "use
        ``default_timeout``". Any other value is returned unchanged.
        """
        if request_timeout == 0:
            return None
        if request_timeout is not None:
            return request_timeout
        return default_timeout

    def get_request_timeout(self, request_timeout: Optional[float] = None):
        """Return the timeout for one request, falling back to this config's ``request_timeout``."""
        return self._get_request_timeout(self.request_timeout, request_timeout)


# Re-export ProxyTypes for convenience
__all__ = ["ConnectionConfig", "ProxyTypes"]
|
|
106
|
+
|
mlops/exceptions.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
def format_request_timeout_error() -> "TimeoutException":
    """Build a TimeoutException whose message points at the ``request_timeout`` option."""
    return TimeoutException(
        "Request timed out — the 'request_timeout' option can be used to increase this timeout",
    )


def format_execution_timeout_error() -> "TimeoutException":
    """Build a TimeoutException whose message points at the ``timeout`` option."""
    return TimeoutException(
        "Execution timed out — the 'timeout' option can be used to increase this timeout",
    )


class XClientException(Exception):
    """
    Base class for all XClient errors.

    Raised when a general XClient exception occurs.
    """

    pass


class TimeoutException(XClientException):
    """
    Raised when a timeout occurs.

    ``format_request_timeout_error`` builds one for request timeouts (see the
    ``request_timeout`` option). ``format_execution_timeout_error`` builds one
    for execution timeouts (see the ``timeout`` option).
    """

    pass


class InvalidArgumentException(XClientException):
    """
    Raised when an invalid argument is provided.
    """

    pass


class NotEnoughSpaceException(XClientException):
    """
    Raised when there is not enough disk space.
    """

    pass


class NotFoundException(XClientException):
    """
    Raised when a resource is not found.
    """

    pass


class AuthenticationException(XClientException):
    """
    Raised when authentication fails.
    """

    pass


class RateLimitException(XClientException):
    """
    Raised when the API rate limit is exceeded.
    """

    pass


class APIException(XClientException):
    """
    Raised when an API error occurs.
    """

    pass
|
|
82
|
+
|
mlops/task/__init__.py
ADDED
mlops/task/client.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from httpx import Limits
|
|
5
|
+
|
|
6
|
+
from ..api.client.client import AuthenticatedClient
|
|
7
|
+
from ..connection_config import ConnectionConfig
|
|
8
|
+
from ..exceptions import (
|
|
9
|
+
AuthenticationException,
|
|
10
|
+
RateLimitException,
|
|
11
|
+
NotFoundException,
|
|
12
|
+
APIException,
|
|
13
|
+
)
|
|
14
|
+
from ..api.client.types import Response
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)


def handle_api_exception(e: Response):
    """Handle API exceptions and convert them to appropriate XClient exceptions.

    The exception is returned, not raised; the caller decides whether to raise it.

    Args:
        e: An error response. Only its ``status_code`` and ``content`` are read.

    Returns:
        ``AuthenticationException`` for 401, ``NotFoundException`` for 404, and
        ``RateLimitException`` for 429. Otherwise an ``APIException`` whose
        message uses the JSON body's ``error`` field, then its ``message``
        field, then the raw body text.
    """
    try:
        body = json.loads(e.content) if e.content else {}
    except ValueError:
        # json.JSONDecodeError and the UnicodeDecodeError raised for
        # undecodable bytes are both ValueError subclasses; treat either as
        # "no usable JSON body".
        body = {}
    if not isinstance(body, dict):
        # Valid JSON that is not an object (list, string, number, null) has no
        # fields to read; without this, body.get() below would raise.
        body = {}

    if e.status_code == 401:
        return AuthenticationException(
            f"Authentication failed: {body.get('error', 'Invalid credentials')}"
        )

    if e.status_code == 404:
        return NotFoundException(
            f"Resource not found: {body.get('error', 'The requested resource was not found')}"
        )

    if e.status_code == 429:
        return RateLimitException(
            f"{e.status_code}: Rate limit exceeded, please try again later."
        )

    if "error" in body:
        return APIException(f"{e.status_code}: {body['error']}")

    if "message" in body:
        return APIException(f"{e.status_code}: {body['message']}")

    raw = e.content
    if isinstance(raw, (bytes, bytearray)):
        # Show the body text rather than its b'...' repr.
        raw = raw.decode("utf-8", errors="replace")
    return APIException(f"{e.status_code}: {raw}")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class TaskClient(AuthenticatedClient):
    """
    The client for interacting with the XClient Task API.

    Exactly one credential kind is used, chosen by ``require_api_key`` or
    ``require_access_token``:

    - API key: taken from ``config.api_key`` and sent in the ``X-API-Key``
      header with no prefix.
    - Access token: taken from ``config.access_token`` and sent in the
      ``Authorization`` header with the ``Bearer`` prefix.
    """

    def __init__(
        self,
        config: ConnectionConfig,
        require_api_key: bool = True,
        require_access_token: bool = False,
        limits: Optional[Limits] = None,
        *args,
        **kwargs,
    ):
        """
        Args:
            config: Connection settings (API URL, credentials, headers, proxy).
            require_api_key: Authenticate with ``config.api_key``.
            require_access_token: Authenticate with ``config.access_token``.
            limits: Optional httpx connection-pool limits.
            *args, **kwargs: Forwarded to ``AuthenticatedClient.__init__``.

        Raises:
            AuthenticationException: If both or neither credential kinds are
                required, or the required credential is missing or empty.
        """
        if require_api_key and require_access_token:
            raise AuthenticationException(
                "Only one of api_key or access_token can be required, not both",
            )

        if not require_api_key and not require_access_token:
            raise AuthenticationException(
                "Either api_key or access_token is required",
            )

        # After the checks above, exactly one of the two flags is set.
        if require_api_key:
            # `not` also rejects an empty key (e.g. XCLIENT_API_KEY=""), which
            # would otherwise be sent and only fail server-side.
            if not config.api_key:
                raise AuthenticationException(
                    "API key is required. "
                    "You can either set the environment variable `XCLIENT_API_KEY` "
                    'or pass it via ConnectionConfig(api_key="xclient_...")',
                )
            token = config.api_key
            # API Key header: X-API-Key (per OpenAPI spec)
            auth_header_name = "X-API-Key"
            prefix = ""
        else:
            if not config.access_token:
                raise AuthenticationException(
                    "Access token is required. "
                    "You can set the environment variable `XCLIENT_ACCESS_TOKEN` "
                    "or pass it via ConnectionConfig(access_token=...).",
                )
            token = config.access_token
            # JWT header: Authorization: Bearer <token>
            auth_header_name = "Authorization"
            prefix = "Bearer"

        # Copy so the client does not share the config's headers dict.
        headers = dict(config.headers or {})

        # Hooks are looked up on `self`, so AsyncTaskClient's async overrides
        # are the ones registered for that subclass.
        httpx_args = {
            "event_hooks": {
                "request": [self._log_request],
                "response": [self._log_response],
            },
        }

        if config.proxy is not None:
            httpx_args["proxy"] = config.proxy

        if limits is not None:
            httpx_args["limits"] = limits

        # Positional arguments always bind first, so putting *args at the
        # front only changes the spelling (avoids ruff B026).
        super().__init__(
            *args,
            base_url=config.api_url,
            httpx_args=httpx_args,
            headers=headers,
            token=token,
            auth_header_name=auth_header_name,
            prefix=prefix,
            **kwargs,
        )

    def _log_request(self, request):
        """httpx request hook: log the method and URL at INFO."""
        logger.info("Request %s %s", request.method, request.url)

    def _log_response(self, response):
        """httpx response hook: log the status at ERROR when >= 400, else INFO.

        ``response`` is the object httpx passes to response event hooks (an
        httpx response), not ``api.client.types.Response``.
        """
        level = logging.ERROR if response.status_code >= 400 else logging.INFO
        logger.log(level, "Response %s", response.status_code)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# We need to override the logging hooks for the async usage
class AsyncTaskClient(TaskClient):
    """TaskClient variant whose httpx event hooks are coroutines.

    httpx's async client awaits its event hooks, so the synchronous hooks
    inherited from TaskClient are replaced with ``async def`` versions that
    log the same messages.
    """

    async def _log_request(self, request):
        """Async httpx request hook: log the method and URL at INFO."""
        logger.info("Request %s %s", request.method, request.url)

    async def _log_response(self, response):
        """Async httpx response hook: log the status at ERROR when >= 400, else INFO.

        ``response`` is the httpx response passed to the hook, not
        ``api.client.types.Response``.
        """
        if response.status_code >= 400:
            logger.error("Response %s", response.status_code)
        else:
            logger.info("Response %s", response.status_code)
|
|
146
|
+
|