rapidata 2.17.1__py3-none-any.whl → 2.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/api_client/rest.py +143 -169
- rapidata/rapidata_client/order/rapidata_order.py +2 -2
- rapidata/rapidata_client/rapidata_client.py +27 -16
- rapidata/rapidata_client/selection/__init__.py +2 -0
- rapidata/rapidata_client/selection/static_selection.py +22 -0
- rapidata/rapidata_client/validation/validation_set_manager.py +1 -1
- rapidata/service/credential_manager.py +2 -2
- rapidata/service/openapi_service.py +56 -28
- {rapidata-2.17.1.dist-info → rapidata-2.18.0.dist-info}/METADATA +2 -1
- {rapidata-2.17.1.dist-info → rapidata-2.18.0.dist-info}/RECORD +12 -12
- rapidata/service/token_manager.py +0 -176
- {rapidata-2.17.1.dist-info → rapidata-2.18.0.dist-info}/LICENSE +0 -0
- {rapidata-2.17.1.dist-info → rapidata-2.18.0.dist-info}/WHEEL +0 -0
rapidata/api_client/rest.py
CHANGED
|
@@ -11,114 +11,86 @@
|
|
|
11
11
|
Do not edit the class manually.
|
|
12
12
|
""" # noqa: E501
|
|
13
13
|
|
|
14
|
-
|
|
15
14
|
import io
|
|
16
15
|
import json
|
|
17
16
|
import re
|
|
18
|
-
import
|
|
17
|
+
from typing import Dict, Optional
|
|
19
18
|
|
|
20
|
-
import
|
|
19
|
+
import requests
|
|
20
|
+
from authlib.integrations.requests_client import OAuth2Session
|
|
21
|
+
from requests.adapters import HTTPAdapter
|
|
22
|
+
from urllib3 import Retry
|
|
21
23
|
|
|
22
24
|
from rapidata.api_client.exceptions import ApiException, ApiValueError
|
|
23
25
|
|
|
24
|
-
SUPPORTED_SOCKS_PROXIES = {"socks5", "socks5h", "socks4", "socks4a"}
|
|
25
|
-
RESTResponseType = urllib3.HTTPResponse
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def is_socks_proxy_url(url):
|
|
29
|
-
if url is None:
|
|
30
|
-
return False
|
|
31
|
-
split_section = url.split("://")
|
|
32
|
-
if len(split_section) < 2:
|
|
33
|
-
return False
|
|
34
|
-
else:
|
|
35
|
-
return split_section[0].lower() in SUPPORTED_SOCKS_PROXIES
|
|
36
|
-
|
|
37
26
|
|
|
38
27
|
class RESTResponse(io.IOBase):
|
|
39
28
|
|
|
40
|
-
def __init__(self, resp) -> None:
|
|
29
|
+
def __init__(self, resp: requests.Response) -> None:
|
|
41
30
|
self.response = resp
|
|
42
|
-
self.status = resp.
|
|
31
|
+
self.status = resp.status_code
|
|
43
32
|
self.reason = resp.reason
|
|
44
33
|
self.data = None
|
|
45
34
|
|
|
46
35
|
def read(self):
|
|
47
36
|
if self.data is None:
|
|
48
|
-
self.data = self.response.
|
|
37
|
+
self.data = self.response.content
|
|
49
38
|
return self.data
|
|
50
39
|
|
|
51
|
-
def getheaders(self):
|
|
40
|
+
def getheaders(self) -> Dict[str, str]:
|
|
52
41
|
"""Returns a dictionary of the response headers."""
|
|
53
|
-
return self.response.headers
|
|
42
|
+
return dict(self.response.headers)
|
|
54
43
|
|
|
55
|
-
def getheader(self, name, default=None):
|
|
44
|
+
def getheader(self, name, default=None) -> Optional[str]:
|
|
56
45
|
"""Returns a given response header."""
|
|
57
46
|
return self.response.headers.get(name, default)
|
|
58
47
|
|
|
59
48
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
def __init__(self, configuration) -> None:
|
|
63
|
-
# urllib3.PoolManager will pass all kw parameters to connectionpool
|
|
64
|
-
# https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75 # noqa: E501
|
|
65
|
-
# https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680 # noqa: E501
|
|
66
|
-
# Custom SSL certificates and client certificates: http://urllib3.readthedocs.io/en/latest/advanced-usage.html # noqa: E501
|
|
67
|
-
|
|
68
|
-
# cert_reqs
|
|
69
|
-
if configuration.verify_ssl:
|
|
70
|
-
cert_reqs = ssl.CERT_REQUIRED
|
|
71
|
-
else:
|
|
72
|
-
cert_reqs = ssl.CERT_NONE
|
|
73
|
-
|
|
74
|
-
pool_args = {
|
|
75
|
-
"cert_reqs": cert_reqs,
|
|
76
|
-
"ca_certs": configuration.ssl_ca_cert,
|
|
77
|
-
"cert_file": configuration.cert_file,
|
|
78
|
-
"key_file": configuration.key_file,
|
|
79
|
-
}
|
|
80
|
-
if configuration.assert_hostname is not None:
|
|
81
|
-
pool_args['assert_hostname'] = (
|
|
82
|
-
configuration.assert_hostname
|
|
83
|
-
)
|
|
49
|
+
RESTResponseType = RESTResponse
|
|
84
50
|
|
|
85
|
-
if configuration.retries is not None:
|
|
86
|
-
pool_args['retries'] = configuration.retries
|
|
87
51
|
|
|
88
|
-
|
|
89
|
-
pool_args['server_hostname'] = configuration.tls_server_name
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
if configuration.socket_options is not None:
|
|
93
|
-
pool_args['socket_options'] = configuration.socket_options
|
|
52
|
+
class RESTClientObject:
|
|
94
53
|
|
|
95
|
-
|
|
96
|
-
|
|
54
|
+
def __init__(self, configuration) -> None:
|
|
55
|
+
self.configuration = configuration
|
|
97
56
|
|
|
98
|
-
|
|
99
|
-
self.pool_manager: urllib3.PoolManager
|
|
57
|
+
self.session: Optional[OAuth2Session] = None
|
|
100
58
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
59
|
+
def setup_oauth_client_credentials(
|
|
60
|
+
self, client_id: str, client_secret: str, token_endpoint: str, scope: str
|
|
61
|
+
):
|
|
62
|
+
self.session = OAuth2Session(
|
|
63
|
+
client_id=client_id,
|
|
64
|
+
client_secret=client_secret,
|
|
65
|
+
token_endpoint=token_endpoint,
|
|
66
|
+
scope=scope,
|
|
67
|
+
)
|
|
68
|
+
self._configure_session_defaults()
|
|
69
|
+
self.session.fetch_token()
|
|
70
|
+
|
|
71
|
+
def setup_oauth_with_token(self,
|
|
72
|
+
client_id: str | None,
|
|
73
|
+
client_secret: str | None,
|
|
74
|
+
token: dict,
|
|
75
|
+
token_endpoint: str,
|
|
76
|
+
leeway: int = 60):
|
|
77
|
+
self.session = OAuth2Session(
|
|
78
|
+
token=token,
|
|
79
|
+
token_endpoint=token_endpoint,
|
|
80
|
+
client_id=client_id,
|
|
81
|
+
client_secret=client_secret,
|
|
82
|
+
leeway=leeway,
|
|
83
|
+
)
|
|
84
|
+
self._configure_session_defaults()
|
|
113
85
|
|
|
114
86
|
def request(
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
87
|
+
self,
|
|
88
|
+
method,
|
|
89
|
+
url,
|
|
90
|
+
headers=None,
|
|
91
|
+
body=None,
|
|
92
|
+
post_params=None,
|
|
93
|
+
_request_timeout=None,
|
|
122
94
|
):
|
|
123
95
|
"""Perform requests.
|
|
124
96
|
|
|
@@ -135,15 +107,7 @@ class RESTClientObject:
|
|
|
135
107
|
(connection, read) timeouts.
|
|
136
108
|
"""
|
|
137
109
|
method = method.upper()
|
|
138
|
-
assert method in [
|
|
139
|
-
'GET',
|
|
140
|
-
'HEAD',
|
|
141
|
-
'DELETE',
|
|
142
|
-
'POST',
|
|
143
|
-
'PUT',
|
|
144
|
-
'PATCH',
|
|
145
|
-
'OPTIONS'
|
|
146
|
-
]
|
|
110
|
+
assert method in ["GET", "HEAD", "DELETE", "POST", "PUT", "PATCH", "OPTIONS"]
|
|
147
111
|
|
|
148
112
|
if post_params and body:
|
|
149
113
|
raise ApiValueError(
|
|
@@ -153,105 +117,115 @@ class RESTClientObject:
|
|
|
153
117
|
post_params = post_params or {}
|
|
154
118
|
headers = headers or {}
|
|
155
119
|
|
|
120
|
+
if not self.session:
|
|
121
|
+
raise ApiValueError(
|
|
122
|
+
"OAuth2 session is not initialized. Please initialize it before making requests."
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
session = self.session
|
|
126
|
+
|
|
156
127
|
timeout = None
|
|
157
128
|
if _request_timeout:
|
|
158
129
|
if isinstance(_request_timeout, (int, float)):
|
|
159
|
-
timeout =
|
|
160
|
-
elif (
|
|
161
|
-
|
|
162
|
-
and len(_request_timeout) == 2
|
|
163
|
-
):
|
|
164
|
-
timeout = urllib3.Timeout(
|
|
165
|
-
connect=_request_timeout[0],
|
|
166
|
-
read=_request_timeout[1]
|
|
167
|
-
)
|
|
130
|
+
timeout = _request_timeout
|
|
131
|
+
elif isinstance(_request_timeout, tuple) and len(_request_timeout) == 2:
|
|
132
|
+
timeout = _request_timeout
|
|
168
133
|
|
|
169
134
|
try:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
content_type = headers.get('Content-Type')
|
|
175
|
-
if (
|
|
176
|
-
not content_type
|
|
177
|
-
or re.search('json', content_type, re.IGNORECASE)
|
|
178
|
-
):
|
|
135
|
+
if method in ["POST", "PUT", "PATCH", "OPTIONS", "DELETE"]:
|
|
136
|
+
content_type = headers.get("Content-Type")
|
|
137
|
+
|
|
138
|
+
if not content_type or re.search("json", content_type, re.IGNORECASE):
|
|
179
139
|
request_body = None
|
|
180
140
|
if body is not None:
|
|
181
141
|
request_body = json.dumps(body)
|
|
182
|
-
r =
|
|
183
|
-
|
|
184
|
-
url,
|
|
185
|
-
body=request_body,
|
|
186
|
-
timeout=timeout,
|
|
187
|
-
headers=headers,
|
|
188
|
-
preload_content=False
|
|
189
|
-
)
|
|
142
|
+
r = session.request(method, url, data=request_body, timeout=timeout, headers=headers)
|
|
143
|
+
|
|
190
144
|
elif content_type == 'application/x-www-form-urlencoded':
|
|
191
|
-
r =
|
|
192
|
-
|
|
193
|
-
url,
|
|
194
|
-
fields=post_params,
|
|
195
|
-
encode_multipart=False,
|
|
196
|
-
timeout=timeout,
|
|
197
|
-
headers=headers,
|
|
198
|
-
preload_content=False
|
|
199
|
-
)
|
|
145
|
+
r = session.request(method, url, data=post_params, timeout=timeout, headers=headers)
|
|
146
|
+
|
|
200
147
|
elif content_type == 'multipart/form-data':
|
|
201
|
-
# must del headers['Content-Type'], or the correct
|
|
202
|
-
# Content-Type which generated by urllib3 will be
|
|
203
|
-
# overwritten.
|
|
204
148
|
del headers['Content-Type']
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
149
|
+
files = []
|
|
150
|
+
data = {}
|
|
151
|
+
|
|
152
|
+
for key, value in post_params:
|
|
153
|
+
if isinstance(value, tuple) and len(value) >= 2:
|
|
154
|
+
# This is a file tuple (filename, file_data, [content_type])
|
|
155
|
+
filename, file_data = value[0], value[1]
|
|
156
|
+
content_type = value[2] if len(value) > 2 else None
|
|
157
|
+
files.append((key, (filename, file_data, content_type)))
|
|
158
|
+
elif isinstance(value, dict):
|
|
159
|
+
# JSON-serialize dictionary values
|
|
160
|
+
if key in data:
|
|
161
|
+
# If we already have this key, handle as needed
|
|
162
|
+
# (convert to list or append to existing list)
|
|
163
|
+
if not isinstance(data[key], list):
|
|
164
|
+
data[key] = [data[key]]
|
|
165
|
+
data[key].append(json.dumps(value))
|
|
166
|
+
else:
|
|
167
|
+
data[key] = json.dumps(value)
|
|
168
|
+
else:
|
|
169
|
+
# Regular form data
|
|
170
|
+
if key in data:
|
|
171
|
+
if not isinstance(data[key], list):
|
|
172
|
+
data[key] = [data[key]]
|
|
173
|
+
data[key].append(value)
|
|
174
|
+
else:
|
|
175
|
+
data[key] = value
|
|
176
|
+
r = session.request(method, url, files=files, data=data, timeout=timeout, headers=headers)
|
|
177
|
+
|
|
219
178
|
elif isinstance(body, str) or isinstance(body, bytes):
|
|
220
|
-
r =
|
|
221
|
-
|
|
222
|
-
url,
|
|
223
|
-
body=body,
|
|
224
|
-
timeout=timeout,
|
|
225
|
-
headers=headers,
|
|
226
|
-
preload_content=False
|
|
227
|
-
)
|
|
179
|
+
r = session.request(method, url, data=body, timeout=timeout, headers=headers)
|
|
180
|
+
|
|
228
181
|
elif headers['Content-Type'].startswith('text/') and isinstance(body, bool):
|
|
229
|
-
request_body =
|
|
230
|
-
r =
|
|
231
|
-
|
|
232
|
-
url,
|
|
233
|
-
body=request_body,
|
|
234
|
-
preload_content=False,
|
|
235
|
-
timeout=timeout,
|
|
236
|
-
headers=headers)
|
|
182
|
+
request_body = 'true' if body else 'false'
|
|
183
|
+
r = session.request(method, url, data=request_body, timeout=timeout, headers=headers)
|
|
184
|
+
|
|
237
185
|
else:
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
arguments. Please check that your arguments match
|
|
241
|
-
declared content type."""
|
|
186
|
+
msg = '''Cannot prepare a request message for provided arguments.
|
|
187
|
+
Please check that your arguments match declared content type.'''
|
|
242
188
|
raise ApiException(status=0, reason=msg)
|
|
243
|
-
|
|
189
|
+
|
|
244
190
|
else:
|
|
245
|
-
r =
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
timeout=timeout,
|
|
250
|
-
headers=headers,
|
|
251
|
-
preload_content=False
|
|
252
|
-
)
|
|
253
|
-
except urllib3.exceptions.SSLError as e:
|
|
254
|
-
msg = "\n".join([type(e).__name__, str(e)])
|
|
191
|
+
r = session.request(method, url, params={}, timeout=timeout, headers=headers)
|
|
192
|
+
|
|
193
|
+
except requests.exceptions.SSLError as e:
|
|
194
|
+
msg = '\n'.join([type(e).__name__, str(e)])
|
|
255
195
|
raise ApiException(status=0, reason=msg)
|
|
256
196
|
|
|
257
197
|
return RESTResponse(r)
|
|
198
|
+
|
|
199
|
+
def _configure_session_defaults(self):
|
|
200
|
+
self.session.verify = (
|
|
201
|
+
self.configuration.ssl_ca_cert
|
|
202
|
+
if self.configuration.ssl_ca_cert
|
|
203
|
+
else self.configuration.verify_ssl
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
if self.configuration.cert_file and self.configuration.key_file:
|
|
207
|
+
self.session.cert = (
|
|
208
|
+
self.configuration.cert_file,
|
|
209
|
+
self.configuration.key_file,
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
if self.configuration.retries is not None:
|
|
213
|
+
retry = Retry(
|
|
214
|
+
total=self.configuration.retries,
|
|
215
|
+
backoff_factor=0.3,
|
|
216
|
+
status_forcelist=[429, 500, 502, 503, 504],
|
|
217
|
+
)
|
|
218
|
+
adapter = HTTPAdapter(max_retries=retry)
|
|
219
|
+
self.session.mount("https://", adapter)
|
|
220
|
+
# noinspection HttpUrlsUsage
|
|
221
|
+
self.session.mount("http://", adapter)
|
|
222
|
+
|
|
223
|
+
if self.configuration.proxy:
|
|
224
|
+
self.session.proxies = {
|
|
225
|
+
"http": self.configuration.proxy,
|
|
226
|
+
"https": self.configuration.proxy,
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
if self.configuration.proxy_headers:
|
|
230
|
+
for key, value in self.configuration.proxy_headers.items():
|
|
231
|
+
self.session.headers[key] = value
|
|
@@ -51,7 +51,7 @@ class RapidataOrder:
|
|
|
51
51
|
"""Runs the order to start collecting responses."""
|
|
52
52
|
self.__openapi_service.order_api.order_submit_post(self.order_id)
|
|
53
53
|
if print_link:
|
|
54
|
-
print(f"Order '{self.name}' is now viewable under: https://app.{self.__openapi_service.
|
|
54
|
+
print(f"Order '{self.name}' is now viewable under: https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}")
|
|
55
55
|
return self
|
|
56
56
|
|
|
57
57
|
def pause(self) -> None:
|
|
@@ -167,7 +167,7 @@ class RapidataOrder:
|
|
|
167
167
|
Exception: If the order is not in processing state.
|
|
168
168
|
"""
|
|
169
169
|
campaign_id = self.__get_campaign_id()
|
|
170
|
-
auth_url = f"https://app.{self.__openapi_service.
|
|
170
|
+
auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}/preview?campaignId={campaign_id}"
|
|
171
171
|
could_open_browser = webbrowser.open(auth_url)
|
|
172
172
|
if not could_open_browser:
|
|
173
173
|
encoded_url = urllib.parse.quote(auth_url, safe="%/:=&?~#+!$,;'@()*[]")
|
|
@@ -2,28 +2,37 @@ from rapidata.service.openapi_service import OpenAPIService
|
|
|
2
2
|
|
|
3
3
|
from rapidata.rapidata_client.order.rapidata_order_manager import RapidataOrderManager
|
|
4
4
|
|
|
5
|
-
from rapidata.rapidata_client.validation.validation_set_manager import
|
|
5
|
+
from rapidata.rapidata_client.validation.validation_set_manager import (
|
|
6
|
+
ValidationSetManager,
|
|
7
|
+
)
|
|
6
8
|
|
|
7
9
|
from rapidata.rapidata_client.demographic.demographic_manager import DemographicManager
|
|
8
10
|
|
|
11
|
+
|
|
9
12
|
class RapidataClient:
|
|
10
13
|
"""The Rapidata client is the main entry point for interacting with the Rapidata API. It allows you to create orders and validation sets."""
|
|
11
|
-
|
|
14
|
+
|
|
12
15
|
def __init__(
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
self,
|
|
17
|
+
client_id: str | None = None,
|
|
18
|
+
client_secret: str | None = None,
|
|
19
|
+
environment: str = "rapidata.ai",
|
|
20
|
+
oauth_scope: str = "openid",
|
|
21
|
+
cert_path: str | None = None,
|
|
22
|
+
token: dict | None = None,
|
|
23
|
+
leeway: int = 60
|
|
19
24
|
):
|
|
20
|
-
"""Initialize the RapidataClient. If both the client_id and client_secret are None, it will try using your credentials under "~/.config/rapidata/credentials.json".
|
|
21
|
-
If this is not successful, it will open a browser
|
|
25
|
+
"""Initialize the RapidataClient. If both the client_id and client_secret are None, it will try using your credentials under "~/.config/rapidata/credentials.json".
|
|
26
|
+
If this is not successful, it will open a browser window and ask you to log in, then save your new credentials in said json file.
|
|
22
27
|
|
|
23
28
|
Args:
|
|
24
29
|
client_id (str): The client ID for authentication.
|
|
25
30
|
client_secret (str): The client secret for authentication.
|
|
26
|
-
|
|
31
|
+
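environment (str, optional): The environment domain used to build the API and auth endpoints (https://api.<environment> and https://auth.<environment>). Defaults to "rapidata.ai".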
|
|
32
|
+
oauth_scope (str, optional): The scopes to use for authentication. In general this does not need to be changed.
|
|
33
|
+
cert_path (str, optional): An optional path to a certificate file useful for development.
|
|
34
|
+
token (dict, optional): An existing token that the client should use for authentication. Important: if set, this needs to be the complete token object containing the access token, token type and expiration time.
|
|
35
|
+
leeway (int, optional): An optional leeway to use to determine if a token is expired. Defaults to 60 seconds.
|
|
27
36
|
|
|
28
37
|
Attributes:
|
|
29
38
|
order (RapidataOrderManager): The RapidataOrderManager instance.
|
|
@@ -32,17 +41,19 @@ class RapidataClient:
|
|
|
32
41
|
self._openapi_service = OpenAPIService(
|
|
33
42
|
client_id=client_id,
|
|
34
43
|
client_secret=client_secret,
|
|
35
|
-
|
|
44
|
+
environment=environment,
|
|
36
45
|
oauth_scope=oauth_scope,
|
|
37
|
-
cert_path=cert_path
|
|
46
|
+
cert_path=cert_path,
|
|
47
|
+
token=token,
|
|
48
|
+
leeway=leeway,
|
|
38
49
|
)
|
|
39
|
-
|
|
50
|
+
|
|
40
51
|
self.order = RapidataOrderManager(openapi_service=self._openapi_service)
|
|
41
|
-
|
|
52
|
+
|
|
42
53
|
self.validation = ValidationSetManager(openapi_service=self._openapi_service)
|
|
43
54
|
|
|
44
55
|
self._demographic = DemographicManager(openapi_service=self._openapi_service)
|
|
45
56
|
|
|
46
57
|
def reset_credentials(self):
|
|
47
|
-
"""Reset the credentials saved in the configuration file for the current
|
|
58
|
+
"""Reset the credentials saved in the configuration file for the current environment."""
|
|
48
59
|
self._openapi_service.reset_credentials()
|
|
@@ -5,3 +5,5 @@ from .validation_selection import ValidationSelection
|
|
|
5
5
|
from .conditional_validation_selection import ConditionalValidationSelection
|
|
6
6
|
from .capped_selection import CappedSelection
|
|
7
7
|
from .shuffling_selection import ShufflingSelection
|
|
8
|
+
from .ab_test_selection import AbTestSelection
|
|
9
|
+
from .static_selection import StaticSelection
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
|
|
2
|
+
from rapidata.api_client.models.static_selection import StaticSelection as StaticSelectionModel
|
|
3
|
+
from rapidata.rapidata_client.selection._base_selection import RapidataSelection
|
|
4
|
+
|
|
5
|
+
class StaticSelection(RapidataSelection):
|
|
6
|
+
"""StaticSelection Class
|
|
7
|
+
|
|
8
|
+
Given a list of RapidIds, these specific rapids will be shown in order for every session.
|
|
9
|
+
|
|
10
|
+
Args:
|
|
11
|
+
rapid_ids (list[str]): List of rapid ids to show.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
def __init__(self, rapid_ids: list[str]):
|
|
15
|
+
self.rapid_ids = rapid_ids
|
|
16
|
+
|
|
17
|
+
def _to_model(self) -> StaticSelectionModel:
|
|
18
|
+
return StaticSelectionModel(
|
|
19
|
+
_t="StaticSelection",
|
|
20
|
+
rapidIds=self.rapid_ids
|
|
21
|
+
)
|
|
22
|
+
|
|
@@ -468,7 +468,7 @@ class ValidationSetManager:
|
|
|
468
468
|
if print_confirmation:
|
|
469
469
|
print()
|
|
470
470
|
print(f"Validation set '{name}' created with ID {validation_set_id}\n",
|
|
471
|
-
f"Now viewable under: https://app.{self.__openapi_service.
|
|
471
|
+
f"Now viewable under: https://app.{self.__openapi_service.environment}/validation-set/detail/{validation_set_id}",
|
|
472
472
|
sep="")
|
|
473
473
|
|
|
474
474
|
if dimensions:
|
|
@@ -133,9 +133,9 @@ class CredentialManager:
|
|
|
133
133
|
return credential
|
|
134
134
|
|
|
135
135
|
return self._create_new_credentials()
|
|
136
|
-
|
|
136
|
+
|
|
137
137
|
def reset_credentials(self) -> None:
|
|
138
|
-
"""Reset the stored credentials for current
|
|
138
|
+
"""Reset the stored credentials for current environment."""
|
|
139
139
|
credentials = self._read_credentials()
|
|
140
140
|
if self.endpoint in credentials:
|
|
141
141
|
del credentials[self.endpoint]
|
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
from importlib.metadata import version, PackageNotFoundError
|
|
3
|
+
|
|
1
4
|
from rapidata.api_client.api.campaign_api import CampaignApi
|
|
2
5
|
from rapidata.api_client.api.dataset_api import DatasetApi
|
|
3
6
|
from rapidata.api_client.api.order_api import OrderApi
|
|
@@ -7,31 +10,31 @@ from rapidata.api_client.api.validation_api import ValidationApi
|
|
|
7
10
|
from rapidata.api_client.api.workflow_api import WorkflowApi
|
|
8
11
|
from rapidata.api_client.api_client import ApiClient
|
|
9
12
|
from rapidata.api_client.configuration import Configuration
|
|
10
|
-
from rapidata.service.token_manager import TokenManager, TokenInfo
|
|
11
13
|
from rapidata.service.credential_manager import CredentialManager
|
|
12
14
|
|
|
13
|
-
from importlib.metadata import version, PackageNotFoundError
|
|
14
|
-
|
|
15
15
|
|
|
16
16
|
class OpenAPIService:
|
|
17
17
|
def __init__(
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
18
|
+
self,
|
|
19
|
+
client_id: str | None,
|
|
20
|
+
client_secret: str | None,
|
|
21
|
+
environment: str,
|
|
22
|
+
oauth_scope: str,
|
|
23
|
+
cert_path: str | None = None,
|
|
24
|
+
token: dict | None = None,
|
|
25
|
+
leeway: int = 60,
|
|
24
26
|
):
|
|
25
|
-
self.
|
|
26
|
-
endpoint = f"https://api.{
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
cert_path=cert_path
|
|
27
|
+
self.environment = environment
|
|
28
|
+
endpoint = f"https://api.{environment}"
|
|
29
|
+
auth_endpoint = f"https://auth.{environment}"
|
|
30
|
+
|
|
31
|
+
if environment == "rapidata.dev" and not cert_path:
|
|
32
|
+
cert_path = _get_local_certificate()
|
|
33
|
+
|
|
34
|
+
self.credential_manager = CredentialManager(
|
|
35
|
+
endpoint=auth_endpoint, cert_path=cert_path
|
|
34
36
|
)
|
|
37
|
+
|
|
35
38
|
client_configuration = Configuration(host=endpoint, ssl_ca_cert=cert_path)
|
|
36
39
|
self.api_client = ApiClient(
|
|
37
40
|
configuration=client_configuration,
|
|
@@ -39,16 +42,32 @@ class OpenAPIService:
|
|
|
39
42
|
header_value=f"RapidataPythonSDK/{self._get_rapidata_package_version()}",
|
|
40
43
|
)
|
|
41
44
|
|
|
42
|
-
|
|
43
|
-
|
|
45
|
+
if token:
|
|
46
|
+
self.api_client.rest_client.setup_oauth_with_token(
|
|
47
|
+
token=token,
|
|
48
|
+
token_endpoint=f"{auth_endpoint}/connect/token",
|
|
49
|
+
client_id=client_id,
|
|
50
|
+
client_secret=client_secret,
|
|
51
|
+
leeway=leeway,
|
|
52
|
+
)
|
|
53
|
+
return
|
|
54
|
+
|
|
55
|
+
if not client_id or not client_secret:
|
|
56
|
+
credentials = self.credential_manager.get_client_credentials()
|
|
57
|
+
if not credentials:
|
|
58
|
+
raise ValueError("Failed to fetch client credentials")
|
|
59
|
+
client_id = credentials.client_id
|
|
60
|
+
client_secret = credentials.client_secret
|
|
61
|
+
|
|
62
|
+
self.api_client.rest_client.setup_oauth_client_credentials(
|
|
63
|
+
client_id=client_id,
|
|
64
|
+
client_secret=client_secret,
|
|
65
|
+
token_endpoint=f"{auth_endpoint}/connect/token",
|
|
66
|
+
scope=oauth_scope,
|
|
44
67
|
)
|
|
45
68
|
|
|
46
|
-
self._cert_path = cert_path
|
|
47
|
-
|
|
48
|
-
token_manager.start_token_refresh(token_callback=self._set_token)
|
|
49
|
-
|
|
50
69
|
def reset_credentials(self):
|
|
51
|
-
|
|
70
|
+
self.credential_manager.reset_credentials()
|
|
52
71
|
|
|
53
72
|
@property
|
|
54
73
|
def order_api(self) -> OrderApi:
|
|
@@ -78,9 +97,6 @@ class OpenAPIService:
|
|
|
78
97
|
def workflow_api(self) -> WorkflowApi:
|
|
79
98
|
return WorkflowApi(self.api_client)
|
|
80
99
|
|
|
81
|
-
def _set_token(self, token: TokenInfo):
|
|
82
|
-
self.api_client.configuration.api_key["bearer"] = f"Bearer {token.access_token}"
|
|
83
|
-
|
|
84
100
|
def _get_rapidata_package_version(self):
|
|
85
101
|
"""
|
|
86
102
|
Returns the version of the currently installed rapidata package.
|
|
@@ -93,3 +109,15 @@ class OpenAPIService:
|
|
|
93
109
|
return version("rapidata")
|
|
94
110
|
except PackageNotFoundError:
|
|
95
111
|
return None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _get_local_certificate() -> str | None:
|
|
115
|
+
result = subprocess.run(["mkcert", "-CAROOT"], capture_output=True)
|
|
116
|
+
if result.returncode != 0:
|
|
117
|
+
return None
|
|
118
|
+
|
|
119
|
+
output = result.stdout.decode("utf-8").strip()
|
|
120
|
+
if not output:
|
|
121
|
+
return None
|
|
122
|
+
|
|
123
|
+
return f"{output}/rootCA.pem"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: rapidata
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.18.0
|
|
4
4
|
Summary: Rapidata package containing the Rapidata Python Client to interact with the Rapidata Web API in an easy way.
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Rapidata AG
|
|
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.13
|
|
15
|
+
Requires-Dist: authlib (>=1.5.1,<2.0.0)
|
|
15
16
|
Requires-Dist: colorama (==0.4.6)
|
|
16
17
|
Requires-Dist: deprecated (>=1.2.14,<2.0.0)
|
|
17
18
|
Requires-Dist: pandas (>=2.2.3,<3.0.0)
|
|
@@ -403,7 +403,7 @@ rapidata/api_client/models/workflow_labeling_step_model.py,sha256=iXeIb78bdMhGFj
|
|
|
403
403
|
rapidata/api_client/models/workflow_split_model.py,sha256=zthOSaUl8dbLhLymLK_lrPTBpeV1a4cODLxnHmNCAZw,4474
|
|
404
404
|
rapidata/api_client/models/workflow_split_model_filter_configs_inner.py,sha256=1Fx9uZtztiiAdMXkj7YeCqt7o6VkG9lKf7D7UP_h088,7447
|
|
405
405
|
rapidata/api_client/models/workflow_state.py,sha256=5LAK1se76RCoozeVB6oxMPb8p_5bhLZJqn7q5fFQWis,850
|
|
406
|
-
rapidata/api_client/rest.py,sha256=
|
|
406
|
+
rapidata/api_client/rest.py,sha256=Nnn1XE9sVUprPm_6AsUmetb_bd9dMjynDOob6y8NJNE,8775
|
|
407
407
|
rapidata/api_client_README.md,sha256=97mR2UeWNIhqNxgUAz87zH8kM9RsbrColX_FvIy_rYo,53784
|
|
408
408
|
rapidata/rapidata_client/__init__.py,sha256=yU0cRoX-RmOHQv0Qj3yJpHaDET4DHZWSO6w2cAApQhQ,920
|
|
409
409
|
rapidata/rapidata_client/assets/__init__.py,sha256=hKgrOSn8gJcBSULaf4auYhH1S1N5AfcwIhBSq1BOKwQ,323
|
|
@@ -438,15 +438,15 @@ rapidata/rapidata_client/metadata/_select_words_metadata.py,sha256=-MK5yQDi_G3BK
|
|
|
438
438
|
rapidata/rapidata_client/order/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
439
439
|
rapidata/rapidata_client/order/_rapidata_dataset.py,sha256=KyEDb4NvlY4uzN3GjiX331VSGOq_I6FhvTnyKrpo4Sg,19400
|
|
440
440
|
rapidata/rapidata_client/order/_rapidata_order_builder.py,sha256=De-gNvOnuXHz6QaRdNpr0_3zOkIEKu8hgtZkatzxZQ4,13155
|
|
441
|
-
rapidata/rapidata_client/order/rapidata_order.py,sha256=
|
|
441
|
+
rapidata/rapidata_client/order/rapidata_order.py,sha256=cz1bXkoS7y_MuJDmjdXao7nbCFd8UF4bb5SHZtOYDjI,10483
|
|
442
442
|
rapidata/rapidata_client/order/rapidata_order_manager.py,sha256=HQzZAnwKnwWl2wLcZPqxQRnth1xWhVoGNkMkzzYbjqA,30615
|
|
443
443
|
rapidata/rapidata_client/order/rapidata_results.py,sha256=0y8EOiqUV7XuwpRJyV53mfo-ddRV1cUPdWexbPEVOHM,8044
|
|
444
|
-
rapidata/rapidata_client/rapidata_client.py,sha256=
|
|
444
|
+
rapidata/rapidata_client/rapidata_client.py,sha256=fcWlmmNCZahK40Ox4aY153tEihIpwkUxYTuiypKF2SY,2857
|
|
445
445
|
rapidata/rapidata_client/referee/__init__.py,sha256=q0Hv9nmfEpyChejtyMLT8hWKL0vTTf_UgUXPYNJ-H6M,153
|
|
446
446
|
rapidata/rapidata_client/referee/_base_referee.py,sha256=MdFOhdxt3sRnWXLDKLJZKFdVpjBGn9jypPnWWQ6msQA,496
|
|
447
447
|
rapidata/rapidata_client/referee/_early_stopping_referee.py,sha256=ULbokQZ91wc9D_20qHUhe55D28D9eTY1J1cMp_-oIDc,2088
|
|
448
448
|
rapidata/rapidata_client/referee/_naive_referee.py,sha256=PVR8uy8hfRjr2DBzdOFyvou6S3swNc-4UvgjhO-09TU,1209
|
|
449
|
-
rapidata/rapidata_client/selection/__init__.py,sha256=
|
|
449
|
+
rapidata/rapidata_client/selection/__init__.py,sha256=OimK44ig39A3kHCR_JGNO4FiUYJ6JUY0ZT0J8dz32Rs,475
|
|
450
450
|
rapidata/rapidata_client/selection/_base_selection.py,sha256=tInbWOgxT_4CHkr5QHoG55ZcUi1ZmfcEGIwLKKCnN20,147
|
|
451
451
|
rapidata/rapidata_client/selection/ab_test_selection.py,sha256=fymubkVMawqJmYp9FKzWXTki9tgBgoj3cOP8rG9oOd0,1284
|
|
452
452
|
rapidata/rapidata_client/selection/capped_selection.py,sha256=iWhbM1LcayhgFm7oKADXCaKHGdiQIupI0jbYuuEVM2A,1184
|
|
@@ -455,6 +455,7 @@ rapidata/rapidata_client/selection/demographic_selection.py,sha256=l4vnNbzlf9ED6
|
|
|
455
455
|
rapidata/rapidata_client/selection/labeling_selection.py,sha256=v26QogjmraFfRoSIgWZl6NMIW_TqbGeuCI2p4HxCeOM,657
|
|
456
456
|
rapidata/rapidata_client/selection/rapidata_selections.py,sha256=Azh0ntBZp9EQNL19imIItotQ8QW3B1gEs5YmuTvUn6U,1526
|
|
457
457
|
rapidata/rapidata_client/selection/shuffling_selection.py,sha256=FzOp7mnBLxNzM5at_-935wd77IHyWnFR1f8uqokiMOg,1201
|
|
458
|
+
rapidata/rapidata_client/selection/static_selection.py,sha256=POhVLjzHcUIuU_GCvRxMuCb27m7CkLxaQPwgf20Xo9o,681
|
|
458
459
|
rapidata/rapidata_client/selection/validation_selection.py,sha256=sedeIa8lpXVXKtFJA9IDeRvo9A1Ne4ZGcepaWDUGhCU,851
|
|
459
460
|
rapidata/rapidata_client/settings/__init__.py,sha256=DTEAT_YykwodZJXqKYOtWRwimLCA-Jxo0F0d-H6A3vM,458
|
|
460
461
|
rapidata/rapidata_client/settings/_rapidata_setting.py,sha256=MD5JhhogSLLrjFKjvL3JhMszOMCygyqLF-st0EwMSkw,352
|
|
@@ -473,7 +474,7 @@ rapidata/rapidata_client/validation/rapids/__init__.py,sha256=WU5PPwtTJlte6U90MD
|
|
|
473
474
|
rapidata/rapidata_client/validation/rapids/box.py,sha256=t3_Kn6doKXdnJdtbwefXnYKPiTKHneJl9E2inkDSqL8,589
|
|
474
475
|
rapidata/rapidata_client/validation/rapids/rapids.py,sha256=uCKnoSn1RykNHgTFbrvCFlfzU8lF42cff-2I-Pd48w0,4620
|
|
475
476
|
rapidata/rapidata_client/validation/rapids/rapids_manager.py,sha256=F00lPYBUx5fTPRw50iZuobtdbjFo6ZHevPMk101JdaY,14271
|
|
476
|
-
rapidata/rapidata_client/validation/validation_set_manager.py,sha256=
|
|
477
|
+
rapidata/rapidata_client/validation/validation_set_manager.py,sha256=Ecm17xf4SFsKdWAGNPdjecuAqsEJ27cHPdns7f8GXhs,26699
|
|
477
478
|
rapidata/rapidata_client/workflow/__init__.py,sha256=7nXcY91xkxjHudBc9H0fP35eBBtgwHGWTQKbb-M4h7Y,477
|
|
478
479
|
rapidata/rapidata_client/workflow/_base_workflow.py,sha256=XyIZFKS_RxAuwIHS848S3AyLEHqd07oTD_5jm2oUbsw,762
|
|
479
480
|
rapidata/rapidata_client/workflow/_classify_workflow.py,sha256=9bT54wxVJgxC-zLk6MVNbseFpzYrvFPjt7DHvxqYfnk,1736
|
|
@@ -486,11 +487,10 @@ rapidata/rapidata_client/workflow/_ranking_workflow.py,sha256=XBIifokhu9Kaqo0qSw
|
|
|
486
487
|
rapidata/rapidata_client/workflow/_select_words_workflow.py,sha256=juTW4TPnnSeBWP3K2QfO092t0u1W8I1ksY6aAtPZOi0,1225
|
|
487
488
|
rapidata/rapidata_client/workflow/_timestamp_workflow.py,sha256=tPi2zu1-SlE_ppbGbMz6MM_2LUSWxM-GA0CZRlB0qFo,1176
|
|
488
489
|
rapidata/service/__init__.py,sha256=s9bS1AJZaWIhLtJX_ZA40_CK39rAAkwdAmymTMbeWl4,68
|
|
489
|
-
rapidata/service/credential_manager.py,sha256=
|
|
490
|
+
rapidata/service/credential_manager.py,sha256=3x-Fb6tyqmgtpjI1MSOtXWW_SkzTK8Lo7I0SSL2YD7E,8602
|
|
490
491
|
rapidata/service/local_file_service.py,sha256=pgorvlWcx52Uh3cEG6VrdMK_t__7dacQ_5AnfY14BW8,877
|
|
491
|
-
rapidata/service/openapi_service.py,sha256=
|
|
492
|
-
rapidata/
|
|
493
|
-
rapidata-2.
|
|
494
|
-
rapidata-2.
|
|
495
|
-
rapidata-2.
|
|
496
|
-
rapidata-2.17.1.dist-info/RECORD,,
|
|
492
|
+
rapidata/service/openapi_service.py,sha256=ORFPfHlb41zOUP5nDjYWZwO-ZcqNF_Mw2r71RitFtS0,4042
|
|
493
|
+
rapidata-2.18.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
494
|
+
rapidata-2.18.0.dist-info/METADATA,sha256=RmwN7x5_3aS68UGwC-BTB3FZg8S879NhC32ymMZ8S0A,1187
|
|
495
|
+
rapidata-2.18.0.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
|
496
|
+
rapidata-2.18.0.dist-info/RECORD,,
|
|
@@ -1,176 +0,0 @@
|
|
|
1
|
-
import json
|
|
2
|
-
import logging
|
|
3
|
-
import threading
|
|
4
|
-
from datetime import datetime, timedelta
|
|
5
|
-
from typing import Optional, Callable
|
|
6
|
-
|
|
7
|
-
import requests
|
|
8
|
-
from pydantic import BaseModel
|
|
9
|
-
|
|
10
|
-
from rapidata.service.credential_manager import CredentialManager
|
|
11
|
-
|
|
12
|
-
# Module-level logger named after this module; used below for token refresh diagnostics.
logger = logging.getLogger(__name__)
|
|
13
|
-
|
|
14
|
-
class TokenInfo(BaseModel):
    """An access token plus the timing data needed to judge how fresh it is."""

    access_token: str
    # Lifetime of the token in seconds, counted from ``issued_at``.
    expires_in: int
    # Moment the token was obtained; ``time_remaining`` measures against it.
    issued_at: datetime
    token_type: str = "Bearer"

    @property
    def auth_header(self) -> str:
        """Return the token formatted as ``"<token_type> <access_token>"``."""
        return f"{self.token_type} {self.access_token}"

    @property
    def time_remaining(self) -> float:
        """Return the number of seconds until the token expires (never negative)."""
        expires_at = self.issued_at + timedelta(seconds=self.expires_in)
        # Take "now" with the same tzinfo as ``issued_at``: a naive ``issued_at``
        # keeps the naive local-time comparison, while an aware ``issued_at`` no
        # longer raises TypeError from mixing naive and aware datetimes.
        now = datetime.now(self.issued_at.tzinfo)
        return max(0.0, (expires_at - now).total_seconds())
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class TokenManager:
    """Fetch client-credentials tokens and optionally keep one fresh in the background.

    A token is requested from ``<endpoint>/connect/token`` with the
    ``client_credentials`` grant. ``start_token_refresh`` runs a daemon thread
    that re-fetches the token once its remaining lifetime drops below
    ``expires_in * (1 - refresh_threshold)`` and hands every new token to a
    caller-supplied callback. The class is also a context manager: leaving the
    ``with`` block stops the refresh thread.

    Args:
        client_id: OAuth client id. If this or ``client_secret`` is falsy, both
            are obtained from ``CredentialManager`` instead.
        client_secret: OAuth client secret (see ``client_id``).
        endpoint: Base URL of the auth server.
        oauth_scope: Value sent as the ``scope`` form field.
        cert_path: Passed to ``requests`` as ``verify`` and to
            ``CredentialManager`` as ``cert_path``.
        refresh_threshold: Fraction of the token lifetime after which the token
            is refreshed.
        max_sleep_time: Upper bound, in seconds, on how long the refresh loop
            sleeps between checks.
        max_token_lifetime: Cap, in seconds, applied to the ``expires_in``
            reported by the server.
        request_timeout: Timeout, in seconds, for the token HTTP request;
            ``None`` disables the timeout.

    Raises:
        ValueError: If credentials are missing and ``CredentialManager``
            returns none.
    """

    def __init__(
        self,
        client_id: str | None = None,
        client_secret: str | None = None,
        endpoint: str = "https://auth.rapidata.ai",
        oauth_scope: str = "openid profile email",
        cert_path: str | None = None,
        refresh_threshold: float = 0.8,
        max_sleep_time: float = 30,
        max_token_lifetime: int = 60,
        request_timeout: float | None = 10,
    ):
        self._client_id = client_id
        self._client_secret = client_secret

        if not client_id or not client_secret:
            credential_manager = CredentialManager(
                endpoint=endpoint, cert_path=cert_path
            )
            credentials = credential_manager.get_client_credentials()
            if not credentials:
                raise ValueError("Failed to fetch client credentials")
            self._client_id = credentials.client_id
            self._client_secret = credentials.client_secret

        self._endpoint = endpoint
        self._oauth_scope = oauth_scope
        self._cert_path = cert_path
        self._refresh_threshold = refresh_threshold
        self._max_sleep_time = max_sleep_time
        self._max_token_lifetime = max_token_lifetime
        self._request_timeout = request_timeout

        self._token_lock = threading.Lock()
        self._current_token: Optional[TokenInfo] = None
        self._refresh_thread: Optional[threading.Thread] = None
        # Replaced by a fresh event on every start_token_refresh() so that a
        # loop which was asked to stop can never be revived by a later start.
        self._should_stop = threading.Event()

    @staticmethod
    def _extract_error_description(body: str) -> str:
        """Pull a human-readable error description out of a failed token response."""
        if "error_description" not in body:
            return "An unknown error occurred"

        try:
            payload = json.loads(body)
        except ValueError:
            payload = None

        if isinstance(payload, dict):
            description = payload.get("error_description")
            if isinstance(description, str):
                return description

        # Body is not a JSON object with a top-level description: fall back to
        # taking the rest of the line that follows the key.
        return body.split("error_description")[1].split("\n")[0].strip()

    def fetch_token(self):
        """Request a new token from the auth server.

        Returns:
            TokenInfo: The new token; ``expires_in`` is capped at
            ``max_token_lifetime`` and ``issued_at`` is the current local time.

        Raises:
            ValueError: If the request fails, the server answers with a
                non-success status, or the response cannot be parsed.
        """
        try:
            response = requests.post(
                f"{self._endpoint}/connect/token",
                data={
                    "grant_type": "client_credentials",
                    "client_id": self._client_id,
                    "client_secret": self._client_secret,
                    "scope": self._oauth_scope,
                },
                verify=self._cert_path,
                # Without a timeout a stalled server would hang the caller
                # (and the refresh thread) indefinitely.
                timeout=self._request_timeout,
            )

            if response.ok:
                data = response.json()
                return TokenInfo(
                    access_token=data["access_token"],
                    token_type=data["token_type"],
                    expires_in=min(self._max_token_lifetime, data["expires_in"]),
                    issued_at=datetime.now(),
                )

            error_description = self._extract_error_description(response.text)
        except requests.RequestException as e:
            raise ValueError(f"Failed to fetch token: {e}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse token response: {e}") from e
        except KeyError as e:
            raise ValueError(f"Failed to extract token from response: {e}") from e

        raise ValueError(f"Failed to fetch token: {error_description}")

    def start_token_refresh(self, token_callback: Callable[[TokenInfo], None]) -> None:
        """Start the background thread that keeps the current token fresh.

        Does nothing (besides logging an error) if a refresh thread is already
        alive.

        Args:
            token_callback: Called with every newly fetched token, from the
                refresh thread and without any internal lock held.
        """
        if self._refresh_thread and self._refresh_thread.is_alive():
            logger.error("Token refresh thread is already running")
            return

        # One event per loop: a previous loop that is still winding down keeps
        # seeing its own (set) event and exits instead of being revived.
        stop_event = threading.Event()
        self._should_stop = stop_event

        def refresh_loop():
            while not stop_event.is_set():
                try:
                    with self._token_lock:
                        token = self._current_token

                    if self._should_refresh_token(token):
                        logger.debug("Refreshing token")
                        # Fetch outside the lock so readers are not blocked
                        # for the duration of the HTTP request.
                        token = self.fetch_token()
                        if stop_event.is_set():
                            # Stopped while fetching: do not publish or call back.
                            break
                        with self._token_lock:
                            self._current_token = token
                        # Called without the lock so the callback may safely
                        # use get_current_token().
                        token_callback(token)

                    if token:
                        time_until_refresh_threshold = token.time_remaining - (
                            token.expires_in * (1 - self._refresh_threshold)
                        )
                        logger.debug(
                            "Time until refresh threshold: %s",
                            time_until_refresh_threshold,
                        )
                        sleep_time = min(
                            self._max_sleep_time, time_until_refresh_threshold
                        )
                        logger.debug(
                            "Sleeping for %s until checking the token again",
                            sleep_time,
                        )
                        # Never spin faster than once per second.
                        stop_event.wait(timeout=max(1.0, sleep_time))
                    else:
                        stop_event.wait(timeout=self._max_sleep_time)
                except Exception as e:
                    # Keep the loop alive on any failure and retry shortly.
                    logger.error("Failed to refresh token: %s", e)
                    stop_event.wait(timeout=5)

        self._refresh_thread = threading.Thread(target=refresh_loop, daemon=True)
        self._refresh_thread.start()

    def stop_token_refresh(self):
        """Signal the refresh thread to stop and wait up to one second for it."""
        self._should_stop.set()
        thread = self._refresh_thread
        if thread:
            # Joining the current thread raises RuntimeError, so skip the join
            # when this is called from inside the refresh thread (e.g. from
            # the token callback).
            if thread is not threading.current_thread():
                thread.join(timeout=1)
            self._refresh_thread = None

    def get_current_token(self) -> Optional[TokenInfo]:
        """Return the most recently stored token, or ``None`` if there is none yet."""
        with self._token_lock:
            return self._current_token

    def _should_refresh_token(self, token: TokenInfo | None) -> bool:
        """Return True if there is no token or its remaining time is below the limit."""
        if not token:
            return True

        limit = token.expires_in * (1 - self._refresh_threshold)

        logger.debug(
            "The token was issued at %s, it expires in %s. It has %s seconds remaining and we refresh the token when it has %s seconds remaining",
            token.issued_at,
            token.expires_in,
            token.time_remaining,
            limit,
        )
        return token.time_remaining < limit

    def __enter__(self):
        """Return the manager itself; no thread is started implicitly."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the refresh thread; exceptions from the ``with`` body propagate."""
        self.stop_token_refresh()
        return False
|
|
File without changes
|
|
File without changes
|