diracx-client 0.0.1a33__py3-none-any.whl → 0.0.1a35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _diracx_client_importer.pth +1 -0
- diracx/_client_importer.py +384 -0
- diracx/client/__init__.py +12 -4
- diracx/client/{generated → _generated}/__init__.py +1 -1
- diracx/client/{generated → _generated}/_client.py +5 -5
- diracx/client/{generated → _generated}/_configuration.py +1 -1
- diracx/client/{generated/aio → _generated}/_patch.py +4 -5
- diracx/client/{generated → _generated}/_serialization.py +1 -1
- diracx/client/{generated/aio → _generated}/_vendor.py +1 -1
- diracx/client/{generated → _generated}/aio/__init__.py +1 -1
- diracx/client/{generated → _generated}/aio/_client.py +5 -5
- diracx/client/{generated → _generated}/aio/_configuration.py +1 -1
- diracx/client/{generated → _generated/aio}/_patch.py +4 -9
- diracx/client/{generated → _generated/aio}/_vendor.py +1 -1
- diracx/client/{generated → _generated}/aio/operations/__init__.py +1 -1
- diracx/client/{generated → _generated}/aio/operations/_operations.py +158 -28
- diracx/client/_generated/aio/operations/_patch.py +26 -0
- diracx/client/{generated → _generated}/models/__init__.py +11 -1
- diracx/client/{generated → _generated}/models/_enums.py +1 -1
- diracx/client/{generated → _generated}/models/_models.py +152 -39
- diracx/client/{generated → _generated}/models/_patch.py +15 -12
- diracx/client/{generated → _generated}/operations/__init__.py +1 -1
- diracx/client/{generated → _generated}/operations/_operations.py +178 -28
- diracx/client/_generated/operations/_patch.py +26 -0
- diracx/client/aio.py +12 -2
- diracx/client/models.py +3 -6
- diracx/client/patches/auth/aio.py +45 -0
- diracx/client/patches/auth/common.py +56 -0
- diracx/client/patches/auth/sync.py +41 -0
- diracx/client/patches/{aio/utils.py → client/aio.py} +22 -40
- diracx/client/patches/client/common.py +196 -0
- diracx/client/patches/client/sync.py +141 -0
- diracx/client/patches/jobs/aio.py +34 -0
- diracx/client/patches/jobs/common.py +85 -0
- diracx/client/patches/jobs/sync.py +34 -0
- diracx/client/py.typed +0 -1
- diracx/client/sync.py +13 -0
- {diracx_client-0.0.1a33.dist-info → diracx_client-0.0.1a35.dist-info}/METADATA +3 -4
- diracx_client-0.0.1a35.dist-info/RECORD +42 -0
- {diracx_client-0.0.1a33.dist-info → diracx_client-0.0.1a35.dist-info}/WHEEL +1 -2
- diracx/client/extensions.py +0 -90
- diracx/client/generated/aio/operations/_patch.py +0 -126
- diracx/client/generated/operations/_patch.py +0 -129
- diracx/client/patches/__init__.py +0 -19
- diracx/client/patches/aio/__init__.py +0 -18
- diracx_client-0.0.1a33.dist-info/RECORD +0 -36
- diracx_client-0.0.1a33.dist-info/entry_points.txt +0 -3
- diracx_client-0.0.1a33.dist-info/top_level.txt +0 -1
- /diracx/client/{generated → _generated}/py.typed +0 -0
@@ -1,38 +1,27 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
# Licensed under the MIT License.
|
4
|
-
# ------------------------------------
|
5
|
-
"""Customize generated code here.
|
6
|
-
|
7
|
-
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
|
8
|
-
"""
|
1
|
+
"""Patches for the autorest-generated client to enable authentication."""
|
2
|
+
|
9
3
|
from __future__ import annotations
|
10
4
|
|
11
|
-
import abc
|
12
5
|
from importlib.metadata import PackageNotFoundError, distribution
|
13
|
-
from types import TracebackType
|
14
6
|
from pathlib import Path
|
15
|
-
|
16
|
-
from typing import Any,
|
7
|
+
from types import TracebackType
|
8
|
+
from typing import Any, cast
|
17
9
|
|
18
10
|
from azure.core.credentials import AccessToken
|
19
11
|
from azure.core.credentials_async import AsyncTokenCredential
|
20
12
|
from azure.core.pipeline import PipelineRequest
|
21
13
|
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy
|
22
|
-
|
23
14
|
from diracx.core.preferences import get_diracx_preferences, DiracxPreferences
|
24
15
|
|
25
|
-
from ..utils import
|
26
|
-
|
27
|
-
get_token,
|
28
|
-
)
|
16
|
+
from ..utils import get_openid_configuration, get_token
|
17
|
+
from ..._generated.aio._client import Dirac as _Dirac
|
29
18
|
|
30
|
-
__all__
|
31
|
-
"
|
32
|
-
]
|
19
|
+
__all__ = [
|
20
|
+
"Dirac",
|
21
|
+
]
|
33
22
|
|
34
23
|
|
35
|
-
class
|
24
|
+
class AsyncDiracTokenCredential(AsyncTokenCredential):
|
36
25
|
"""Tailor get_token() for our context"""
|
37
26
|
|
38
27
|
def __init__(
|
@@ -51,8 +40,8 @@ class DiracTokenCredential(AsyncTokenCredential):
|
|
51
40
|
async def get_token(
|
52
41
|
self,
|
53
42
|
*scopes: str,
|
54
|
-
claims:
|
55
|
-
tenant_id:
|
43
|
+
claims: str | None = None,
|
44
|
+
tenant_id: str | None = None,
|
56
45
|
**kwargs: Any,
|
57
46
|
) -> AccessToken:
|
58
47
|
return get_token(
|
@@ -81,18 +70,17 @@ class DiracTokenCredential(AsyncTokenCredential):
|
|
81
70
|
pass
|
82
71
|
|
83
72
|
|
84
|
-
class
|
73
|
+
class AsyncDiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
|
85
74
|
"""Custom AsyncBearerTokenCredentialPolicy tailored for our use case.
|
86
75
|
|
87
76
|
* It does not ensure the connection is done through https.
|
88
77
|
* It does not ensure that an access token is available.
|
89
78
|
"""
|
90
79
|
|
91
|
-
|
92
|
-
_token: Optional[AccessToken] = None
|
80
|
+
_token: AccessToken | None = None
|
93
81
|
|
94
82
|
def __init__(
|
95
|
-
self, credential:
|
83
|
+
self, credential: AsyncDiracTokenCredential, *scopes: str, **kwargs: Any
|
96
84
|
) -> None:
|
97
85
|
super().__init__(credential, *scopes, **kwargs)
|
98
86
|
|
@@ -104,9 +92,10 @@ class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
|
|
104
92
|
:type request: ~azure.core.pipeline.PipelineRequest
|
105
93
|
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
|
106
94
|
"""
|
107
|
-
# Make mypy happy
|
108
95
|
if not isinstance(self._credential, AsyncTokenCredential):
|
109
|
-
|
96
|
+
raise NotImplementedError(
|
97
|
+
"AsyncDiracBearerTokenCredentialPolicy only supports AsyncTokenCredential"
|
98
|
+
)
|
110
99
|
|
111
100
|
self._token = await self._credential.get_token("", token=self._token)
|
112
101
|
if not self._token.token:
|
@@ -119,10 +108,7 @@ class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
|
|
119
108
|
)
|
120
109
|
|
121
110
|
|
122
|
-
class
|
123
|
-
"""This class inherits from the generated Dirac client and adds support for tokens,
|
124
|
-
so that the caller does not need to configure it by itself.
|
125
|
-
"""
|
111
|
+
class Dirac(_Dirac):
|
126
112
|
|
127
113
|
def __init__(
|
128
114
|
self,
|
@@ -156,14 +142,10 @@ class DiracClientMixin(metaclass=abc.ABCMeta):
|
|
156
142
|
"DiracX-Client-Version"
|
157
143
|
] = self.client_version
|
158
144
|
|
159
|
-
|
160
|
-
# We need to ignore types here because mypy complains that we give
|
161
|
-
# too many arguments to "object" constructor as this is a mixin
|
162
|
-
|
163
|
-
super().__init__( # type: ignore
|
145
|
+
super().__init__(
|
164
146
|
endpoint=self._endpoint,
|
165
|
-
authentication_policy=
|
166
|
-
|
147
|
+
authentication_policy=AsyncDiracBearerTokenCredentialPolicy(
|
148
|
+
AsyncDiracTokenCredential(
|
167
149
|
location=diracx_preferences.credentials_path,
|
168
150
|
token_endpoint=openid_configuration["token_endpoint"],
|
169
151
|
client_id=self._client_id,
|
@@ -0,0 +1,196 @@
|
|
1
|
+
"""Utilities which are common to the sync and async client patches."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
# Public helpers shared by the sync and async client patches. Every entry must
# name a definition in this module, otherwise ``from ... import *`` fails.
__all__ = [
    "TokenStatus",
    "TokenResult",
    "get_openid_configuration",
    "get_token",
    "refresh_token",
    "is_token_valid",
    "extract_token_from_credentials",
    "is_refresh_token_valid",
]
|
8
|
+
|
9
|
+
import fcntl
|
10
|
+
import json
|
11
|
+
import os
|
12
|
+
from datetime import datetime, timedelta, timezone
|
13
|
+
from dataclasses import dataclass
|
14
|
+
from enum import Enum
|
15
|
+
from pathlib import Path
|
16
|
+
from typing import TextIO
|
17
|
+
from urllib import parse
|
18
|
+
|
19
|
+
import httpx
|
20
|
+
import jwt
|
21
|
+
from azure.core.credentials import AccessToken
|
22
|
+
from diracx.core.utils import EXPIRES_GRACE_SECONDS, serialize_credentials
|
23
|
+
from diracx.core.models import TokenResponse
|
24
|
+
|
25
|
+
|
26
|
+
class TokenStatus(Enum):
    """Classification of the credentials found in the cached credentials file."""

    # The cached access token can be returned as-is.
    VALID = "valid"
    # The access token is stale, but the refresh token can be used to get a new one.
    REFRESH = "refresh"
    # Nothing usable was found (unreadable file, missing fields, expired refresh token).
    INVALID = "invalid"
|
30
|
+
|
31
|
+
|
32
|
+
@dataclass
class TokenResult:
    """Outcome of reading the credentials file: a status plus the relevant token."""

    status: TokenStatus
    # Set when ``status`` is ``TokenStatus.VALID``.
    access_token: AccessToken | None = None
    # Set when ``status`` is ``TokenStatus.REFRESH``.
    refresh_token: str | None = None
|
37
|
+
|
38
|
+
|
39
|
+
def get_openid_configuration(
    endpoint: str, *, verify: bool | str = True
) -> dict[str, str]:
    """Download and decode the OpenID discovery document served under *endpoint*.

    :param endpoint: base URL of the DiracX server.
    :param verify: TLS verification flag or CA bundle path, passed to httpx.
    :returns: the decoded ``.well-known/openid-configuration`` JSON document.
    :raises RuntimeError: if the server does not answer with a success status.
    """
    # NOTE(review): urljoin replaces the last path segment of *endpoint* when it
    # has no trailing slash; confirm endpoints never carry a sub-path.
    discovery_url = parse.urljoin(endpoint, ".well-known/openid-configuration")
    reply = httpx.get(url=discovery_url, verify=verify)
    if reply.is_success:
        return reply.json()
    raise RuntimeError("Cannot fetch any information from the .well-known endpoint")
|
50
|
+
|
51
|
+
|
52
|
+
def _expires_on_from_now(expires_in: int) -> int:
    """Return the POSIX timestamp at which a token issued now should be considered expired.

    ``EXPIRES_GRACE_SECONDS`` is subtracted from the server-provided lifetime.
    """
    deadline = datetime.now(tz=timezone.utc) + timedelta(
        seconds=expires_in - EXPIRES_GRACE_SECONDS
    )
    return int(deadline.timestamp())


def get_token(
    location: Path,
    token: AccessToken | None,
    token_endpoint: str,
    client_id: str,
    verify: bool | str,
) -> AccessToken:
    """Return a usable access token, refreshing it through *token_endpoint* if needed.

    Sources are tried in this order:

    1. *token* itself, if ``is_token_valid`` accepts it;
    2. the access token cached in the credentials file at *location*;
    3. a new access token obtained with the cached refresh token, which is then
       written back to *location*.

    An empty ``AccessToken`` (``token=""``, ``expires_on=0``) is returned when the
    file does not exist or contains nothing usable.

    The file is held under an exclusive ``fcntl.flock`` (POSIX only) while it is
    read and possibly rewritten.

    :raises RuntimeError: propagated from ``refresh_token`` when the refresh fails.
    """
    # Fast path: the caller's in-memory token is still good.
    if token and is_token_valid(token):
        return token

    # Without a credentials file the request is assumed not to need a token;
    # an empty token keeps the return type consistent.
    if not location.exists():
        return AccessToken(token="", expires_on=0)

    with open(location, "r+") as creds_file:
        fcntl.flock(creds_file, fcntl.LOCK_EX)
        try:
            cached = extract_token_from_credentials(creds_file, token)
            if cached.status == TokenStatus.VALID and cached.access_token:
                return cached.access_token

            if not (cached.status == TokenStatus.REFRESH and cached.refresh_token):
                # Neither the access token nor the refresh token is usable.
                return AccessToken(token="", expires_on=0)

            refreshed = refresh_token(
                token_endpoint,
                client_id,
                cached.refresh_token,
                verify=verify,
            )

            # Overwrite the credentials file in place and force it to disk.
            creds_file.seek(0)
            creds_file.truncate()
            creds_file.write(serialize_credentials(refreshed))
            creds_file.flush()
            os.fsync(creds_file.fileno())

            return AccessToken(
                token=refreshed.access_token,
                expires_on=_expires_on_from_now(refreshed.expires_in),
            )
        finally:
            fcntl.flock(creds_file, fcntl.LOCK_UN)
|
113
|
+
|
114
|
+
|
115
|
+
def refresh_token(
    token_endpoint: str,
    client_id: str,
    refresh_token: str,
    *,
    verify: bool | str = True,
) -> TokenResponse:
    """Exchange *refresh_token* for a new access token (``refresh_token`` grant).

    :param token_endpoint: URL of the OAuth token endpoint.
    :param client_id: client ID sent with the request.
    :param refresh_token: the refresh token to exchange.
    :param verify: TLS verification flag or CA bundle path, passed to httpx.
    :returns: the new tokens as a ``TokenResponse``.
    :raises RuntimeError: if the endpoint does not answer with HTTP 200.
    """
    response = httpx.post(
        url=token_endpoint,
        data={
            "client_id": client_id,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        },
        verify=verify,
    )

    if response.status_code != 200:
        # The error body is not guaranteed to be a JSON object with a "detail"
        # key (e.g. an HTML page from a proxy). Fall back to the raw text rather
        # than letting a JSONDecodeError/KeyError hide the real failure.
        try:
            detail = response.json()["detail"]
        except (ValueError, KeyError, TypeError):
            detail = response.text
        raise RuntimeError(
            f"An issue occured while refreshing your access token: {detail}"
        )

    res = response.json()
    return TokenResponse(
        access_token=res["access_token"],
        expires_in=res["expires_in"],
        token_type=res.get("token_type"),
        refresh_token=res.get("refresh_token"),
    )
|
145
|
+
|
146
|
+
|
147
|
+
def is_token_valid(token: AccessToken, min_remaining_seconds: float = 300) -> bool:
    """Return True if *token* stays valid for more than *min_remaining_seconds*.

    A False result means a new token should be obtained.

    :param token: token whose ``expires_on`` is a POSIX timestamp.
    :param min_remaining_seconds: safety margin before expiry (default 300s).
    """
    expiry = datetime.fromtimestamp(token.expires_on, tz=timezone.utc)
    remaining = expiry - datetime.now(tz=timezone.utc)
    return remaining.total_seconds() > min_remaining_seconds
|
153
|
+
|
154
|
+
|
155
|
+
def extract_token_from_credentials(
    token_file_descriptor: TextIO, token: AccessToken | None
) -> TokenResult:
    """Read the JSON credentials from an open file and classify them.

    :param token_file_descriptor: open text file positioned at the start of the
        credentials JSON (``access_token``, ``expires_on``, ``refresh_token``).
    :param token: unused; kept so existing callers keep working.
    :returns: ``VALID`` with the cached access token, ``REFRESH`` with the
        refresh token, or ``INVALID`` if nothing usable was found.
    """
    try:
        credentials = json.load(token_file_descriptor)
    except json.JSONDecodeError:
        return TokenResult(TokenStatus.INVALID)

    try:
        cached_token = AccessToken(
            token=credentials["access_token"],
            expires_on=credentials["expires_on"],
        )
        cached_refresh_token = credentials["refresh_token"]
    except (KeyError, TypeError):
        # KeyError: a field is missing.
        # TypeError: the document is valid JSON but not an object (e.g. a list).
        return TokenResult(TokenStatus.INVALID)

    if is_token_valid(cached_token):
        return TokenResult(TokenStatus.VALID, access_token=cached_token)

    if is_refresh_token_valid(cached_refresh_token):
        return TokenResult(TokenStatus.REFRESH, refresh_token=cached_refresh_token)

    # Both the access token and the refresh token are unusable.
    return TokenResult(TokenStatus.INVALID)
|
184
|
+
|
185
|
+
|
186
|
+
def is_refresh_token_valid(refresh_token: str | None) -> bool:
|
187
|
+
"""Check if the refresh token is still valid."""
|
188
|
+
if not refresh_token:
|
189
|
+
return False
|
190
|
+
# Decode the refresh token
|
191
|
+
refresh_payload = jwt.decode(refresh_token, options={"verify_signature": False})
|
192
|
+
if not refresh_payload or "exp" not in refresh_payload:
|
193
|
+
return False
|
194
|
+
|
195
|
+
# Check the expiration time
|
196
|
+
return refresh_payload["exp"] > datetime.now(tz=timezone.utc).timestamp()
|
@@ -0,0 +1,141 @@
|
|
1
|
+
"""Patches for the autorest-generated client to enable authentication."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
# Only the patched client is public; the credential and policy classes are helpers.
__all__ = ["Dirac"]
|
8
|
+
|
9
|
+
from importlib.metadata import PackageNotFoundError, distribution
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Any, Optional
|
12
|
+
|
13
|
+
from azure.core.credentials import AccessToken, TokenCredential
|
14
|
+
from azure.core.pipeline import PipelineRequest
|
15
|
+
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
|
16
|
+
from diracx.core.preferences import DiracxPreferences, get_diracx_preferences
|
17
|
+
|
18
|
+
from .common import get_openid_configuration, get_token
|
19
|
+
from ..._generated._client import Dirac as _Dirac
|
20
|
+
|
21
|
+
|
22
|
+
class SyncDiracTokenCredential(TokenCredential):
    """Synchronous credential backed by the DiracX credentials file.

    All the work is delegated to the shared ``get_token`` helper; this class
    only stores the configuration that helper needs.
    """

    def __init__(
        self,
        location: Path,
        token_endpoint: str,
        client_id: str,
        *,
        verify: bool | str = True,
    ) -> None:
        # Path of the JSON credentials file.
        self.location = location
        self.token_endpoint = token_endpoint
        self.client_id = client_id
        # TLS verification flag or CA bundle path, forwarded to httpx.
        self.verify = verify

    def get_token(
        self,
        *scopes: str,
        claims: str | None = None,
        tenant_id: str | None = None,
        **kwargs: Any,
    ) -> AccessToken:
        """Return an access token; *scopes*, *claims* and *tenant_id* are ignored.

        A previously obtained token may be passed as the ``token`` keyword
        argument so it can be reused while still valid.
        """
        previous_token = kwargs.get("token")
        return get_token(
            self.location,
            previous_token,
            self.token_endpoint,
            self.client_id,
            self.verify,
        )
|
52
|
+
|
53
|
+
|
54
|
+
class SyncDiracBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
    """BearerTokenCredentialPolicy variant adapted to DiracX.

    Compared with the azure-core base policy:

    * it does not ensure the connection is done through https;
    * it does not require an access token: without one, the request is sent
      without an ``Authorization`` header.
    """

    # Class-level default, declared so that mypy knows the attribute type.
    _token: AccessToken | None = None

    def __init__(
        self, credential: SyncDiracTokenCredential, *scopes: str, **kwargs: Any
    ) -> None:
        super().__init__(credential, *scopes, **kwargs)

    def on_request(self, request: PipelineRequest) -> None:
        """Attach a bearer token to *request* when one is available.

        :param request: The pipeline request object to be modified.
        :type request: ~azure.core.pipeline.PipelineRequest
        :raises NotImplementedError: if the configured credential is not a
            ``TokenCredential``.
        """
        credential = self._credential
        if not isinstance(credential, TokenCredential):
            raise NotImplementedError(
                "SyncDiracBearerTokenCredentialPolicy only supports TokenCredential"
            )

        # Handing back the cached token lets get_token() reuse it while valid.
        self._token = credential.get_token("", token=self._token)
        bearer = self._token.token
        if bearer:
            self._update_headers(request.http_request.headers, bearer)
        # With an empty token, the request is sent unauthenticated on the
        # assumption that it does not need credentials.
|
87
|
+
|
88
|
+
|
89
|
+
class Dirac(_Dirac):
    """Generated Dirac client with DiracX token authentication pre-configured.

    The endpoint, CA bundle and credentials location default to the user's
    DiracX preferences, so the caller does not need to configure them.
    """

    def __init__(
        self,
        endpoint: str | None = None,
        client_id: str | None = None,
        diracx_preferences: DiracxPreferences | None = None,
        verify: bool | str = True,
        **kwargs: Any,
    ) -> None:
        """Build the client.

        :param endpoint: DiracX URL; defaults to ``diracx_preferences.url``.
        :param client_id: OAuth client ID; defaults to ``"myDIRACClientID"``.
        :param diracx_preferences: defaults to ``get_diracx_preferences()``.
        :param verify: TLS verification flag or CA bundle path. When left at
            ``True`` and the preferences define ``ca_path``, that path is used.
        :param kwargs: forwarded to the generated client.
        :raises RuntimeError: if the .well-known configuration cannot be fetched.
        """
        diracx_preferences = diracx_preferences or get_diracx_preferences()
        self._endpoint = str(endpoint or diracx_preferences.url)
        # Only replace the default; an explicit False or CA path from the caller wins.
        if verify is True and diracx_preferences.ca_path:
            verify = str(diracx_preferences.ca_path)
        kwargs["connection_verify"] = verify
        self._client_id = client_id or "myDIRACClientID"

        # The server advertises its token endpoint in the .well-known configuration.
        openid_configuration = get_openid_configuration(self._endpoint, verify=verify)

        self.client_version = self._detect_client_version()

        # Send the client version as a default header.
        kwargs.setdefault("base_headers", {})[
            "DiracX-Client-Version"
        ] = self.client_version

        super().__init__(
            endpoint=self._endpoint,
            authentication_policy=SyncDiracBearerTokenCredentialPolicy(
                SyncDiracTokenCredential(
                    location=diracx_preferences.credentials_path,
                    token_endpoint=openid_configuration["token_endpoint"],
                    client_id=self._client_id,
                    verify=verify,
                ),
            ),
            **kwargs,
        )

    @staticmethod
    def _detect_client_version() -> str:
        """Return the installed version of ``diracx``, else ``diracx-client``, else "Unknown"."""
        import logging

        for dist_name in ("diracx", "diracx-client"):
            try:
                return distribution(dist_name).version
            except PackageNotFoundError:
                continue
        # Library code must not print to stdout; report through logging instead.
        logging.getLogger(__name__).warning("Error while getting client version")
        return "Unknown"

    @property
    def client_id(self) -> str:
        """OAuth client ID used by this client's credential."""
        return self._client_id
|
@@ -0,0 +1,34 @@
|
|
1
|
+
"""Patches for the autorest-generated jobs client.
|
2
|
+
|
3
|
+
This file can be used to customize the generated code for the jobs client.
|
4
|
+
When adding new classes to this file, make sure to also add them to the
|
5
|
+
__all__ list in the corresponding file in the patches directory.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from __future__ import annotations
|
9
|
+
|
10
|
+
# The patched operations class replaces the generated one.
__all__ = ["JobsOperations"]
|
13
|
+
|
14
|
+
from typing import Any, Unpack
|
15
|
+
|
16
|
+
from azure.core.tracing.decorator_async import distributed_trace_async
|
17
|
+
|
18
|
+
from ..._generated.aio.operations._operations import JobsOperations as _JobsOperations
|
19
|
+
from .common import make_search_body, make_summary_body, SearchKwargs, SummaryKwargs
|
20
|
+
|
21
|
+
# We're intentionally ignoring overrides here because we want to change the interface.
|
22
|
+
# mypy: disable-error-code=override
|
23
|
+
|
24
|
+
|
25
|
+
class JobsOperations(_JobsOperations):
    """Async jobs operations whose ``search``/``summary`` accept keyword arguments.

    The generated base class expects a serialized request ``body``; these
    overrides build it with ``make_search_body`` / ``make_summary_body``.
    """

    @distributed_trace_async
    async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]:
        """Search for jobs.

        ``parameters``, ``search`` and ``sort`` are JSON-encoded into the request
        body. Any other keyword argument (``page``, ``per_page``, ``headers``, ...)
        is forwarded unchanged to the generated operation.
        """
        underlying_kwargs = make_search_body(**kwargs)
        return await super().search(**underlying_kwargs)

    @distributed_trace_async
    async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]:
        """Summarize jobs.

        ``grouping`` and ``search`` are JSON-encoded into the request body. Any
        other keyword argument is forwarded unchanged to the generated operation.
        """
        underlying_kwargs = make_summary_body(**kwargs)
        return await super().summary(**underlying_kwargs)
|
@@ -0,0 +1,85 @@
|
|
1
|
+
"""Utilities which are common to the sync and async jobs operator patches."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
# Helpers imported by the sync and async JobsOperations patches.
__all__ = ["make_search_body", "SearchKwargs", "make_summary_body", "SummaryKwargs"]
|
11
|
+
|
12
|
+
import json
|
13
|
+
from io import BytesIO
|
14
|
+
from typing import Any, IO, TypedDict, Unpack, cast, Literal
|
15
|
+
|
16
|
+
from diracx.core.models import SearchSpec
|
17
|
+
|
18
|
+
|
19
|
+
class ResponseExtra(TypedDict, total=False):
    """Optional keyword arguments passed straight through to the generated operations."""

    content_type: str
    headers: dict[str, str]
    params: dict[str, str]
    # Presumably the autorest response-hook callable; TODO confirm.
    cls: Any
|
24
|
+
|
25
|
+
|
26
|
+
class SearchBody(TypedDict, total=False):
    """Fields JSON-encoded into the request body of a job search."""

    parameters: list[str] | None
    search: list[SearchSpec] | None
    sort: list[str] | None


class SearchExtra(ResponseExtra, total=False):
    """Search options passed as keyword arguments rather than in the body.

    ``page``/``per_page`` presumably control pagination.
    """

    page: int
    per_page: int


class SearchKwargs(SearchBody, SearchExtra):
    """All keyword arguments accepted by the patched ``search`` methods."""


class UnderlyingSearchArgs(ResponseExtra, total=False):
    """Keyword arguments actually handed to the generated ``search`` operation."""

    # FIXME: The autorest-generated code expects IO[bytes] even though it was
    # generated to support IO[bytes] | bytes.
    body: IO[bytes]
|
44
|
+
|
45
|
+
|
46
|
+
def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs:
    """Split search keyword arguments into a JSON request body and pass-through options.

    Every ``SearchBody`` key present in *kwargs* is removed from it. Values that
    are not ``None`` are JSON-encoded into an in-memory ``body`` stream, and all
    remaining keyword arguments are returned alongside it.
    """
    body: SearchBody = {}
    for field_name in SearchBody.__optional_keys__:
        if field_name in kwargs:
            field = cast(Literal["parameters", "search", "sort"], field_name)
            field_value = kwargs.pop(field)
            if field_value is not None:
                body[field] = field_value

    encoded = json.dumps(body).encode("utf-8")
    result: UnderlyingSearchArgs = {"body": BytesIO(encoded)}
    result.update(cast(SearchExtra, kwargs))
    return result
|
58
|
+
|
59
|
+
|
60
|
+
class SummaryBody(TypedDict, total=False):
    """Fields JSON-encoded into the request body of a job summary."""

    # Names to group the summary by (presumably job attribute names).
    grouping: list[str]
    # Filters, same element type as ``SearchBody.search`` (was wrongly ``list[str]``).
    search: list[SearchSpec]


class SummaryKwargs(SummaryBody, ResponseExtra):
    """All keyword arguments accepted by the patched ``summary`` methods."""


class UnderlyingSummaryArgs(ResponseExtra, total=False):
    """Keyword arguments actually handed to the generated ``summary`` operation."""

    # FIXME: The autorest-generated code expects IO[bytes] even though it was
    # generated to support IO[bytes] | bytes.
    body: IO[bytes]
|
72
|
+
|
73
|
+
|
74
|
+
def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs:
    """Split summary keyword arguments into a JSON request body and pass-through options.

    Every ``SummaryBody`` key present in *kwargs* is removed from it. Values that
    are not ``None`` are JSON-encoded into an in-memory ``body`` stream, and all
    remaining keyword arguments are returned alongside it.
    """
    body: SummaryBody = {}
    for field_name in SummaryBody.__optional_keys__:
        if field_name in kwargs:
            field = cast(Literal["grouping", "search"], field_name)
            field_value = kwargs.pop(field)
            if field_value is not None:
                body[field] = field_value

    encoded = json.dumps(body).encode("utf-8")
    result: UnderlyingSummaryArgs = {"body": BytesIO(encoded)}
    result.update(cast(ResponseExtra, kwargs))
    return result
|
@@ -0,0 +1,34 @@
|
|
1
|
+
"""Patches for the autorest-generated jobs client.
|
2
|
+
|
3
|
+
This file can be used to customize the generated code for the jobs client.
|
4
|
+
When adding new classes to this file, make sure to also add them to the
|
5
|
+
__all__ list in the corresponding file in the patches directory.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from __future__ import annotations
|
9
|
+
|
10
|
+
# The patched operations class replaces the generated one.
__all__ = ["JobsOperations"]
|
13
|
+
|
14
|
+
from typing import Any, Unpack
|
15
|
+
|
16
|
+
from azure.core.tracing.decorator import distributed_trace
|
17
|
+
|
18
|
+
from ..._generated.operations._operations import JobsOperations as _JobsOperations
|
19
|
+
from .common import make_search_body, make_summary_body, SearchKwargs, SummaryKwargs
|
20
|
+
|
21
|
+
# We're intentionally ignoring overrides here because we want to change the interface.
|
22
|
+
# mypy: disable-error-code=override
|
23
|
+
|
24
|
+
|
25
|
+
class JobsOperations(_JobsOperations):
    """Jobs operations whose ``search``/``summary`` accept keyword arguments.

    The generated base class expects a serialized request ``body``; these
    overrides build it with ``make_search_body`` / ``make_summary_body``.
    """

    @distributed_trace
    def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]:
        """Search for jobs.

        ``parameters``, ``search`` and ``sort`` are JSON-encoded into the request
        body. Any other keyword argument (``page``, ``per_page``, ``headers``, ...)
        is forwarded unchanged to the generated operation.
        """
        underlying_kwargs = make_search_body(**kwargs)
        return super().search(**underlying_kwargs)

    @distributed_trace
    def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]:
        """Summarize jobs.

        ``grouping`` and ``search`` are JSON-encoded into the request body. Any
        other keyword argument is forwarded unchanged to the generated operation.
        """
        underlying_kwargs = make_summary_body(**kwargs)
        return super().summary(**underlying_kwargs)
|
diracx/client/py.typed
CHANGED
@@ -1 +0,0 @@
|
|
1
|
-
# Marker file for PEP 561.
|
diracx/client/sync.py
ADDED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: diracx-client
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.1a35
|
4
4
|
Summary: TODO
|
5
5
|
License: GPL-3.0-only
|
6
6
|
Classifier: Intended Audience :: Science/Research
|
@@ -9,11 +9,10 @@ Classifier: Programming Language :: Python :: 3
|
|
9
9
|
Classifier: Topic :: Scientific/Engineering
|
10
10
|
Classifier: Topic :: System :: Distributed Computing
|
11
11
|
Requires-Python: >=3.11
|
12
|
-
Description-Content-Type: text/markdown
|
13
12
|
Requires-Dist: azure-core
|
14
13
|
Requires-Dist: diracx-core
|
15
|
-
Requires-Dist: isodate
|
16
14
|
Requires-Dist: httpx
|
15
|
+
Requires-Dist: isodate
|
17
16
|
Provides-Extra: testing
|
18
|
-
Requires-Dist: diracx-testing; extra ==
|
17
|
+
Requires-Dist: diracx-testing; extra == 'testing'
|
19
18
|
Provides-Extra: types
|