rapidata 2.17.0__py3-none-any.whl → 2.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rapidata might be problematic. Click here for more details.
- rapidata/api_client/rest.py +143 -169
- rapidata/rapidata_client/order/rapidata_order.py +2 -2
- rapidata/rapidata_client/rapidata_client.py +27 -16
- rapidata/rapidata_client/selection/__init__.py +2 -0
- rapidata/rapidata_client/selection/ab_test_selection.py +1 -1
- rapidata/rapidata_client/selection/static_selection.py +22 -0
- rapidata/rapidata_client/validation/rapidata_validation_set.py +2 -2
- rapidata/rapidata_client/validation/validation_set_manager.py +18 -16
- rapidata/service/credential_manager.py +2 -2
- rapidata/service/openapi_service.py +56 -28
- {rapidata-2.17.0.dist-info → rapidata-2.18.0.dist-info}/METADATA +2 -1
- {rapidata-2.17.0.dist-info → rapidata-2.18.0.dist-info}/RECORD +14 -14
- rapidata/service/token_manager.py +0 -176
- {rapidata-2.17.0.dist-info → rapidata-2.18.0.dist-info}/LICENSE +0 -0
- {rapidata-2.17.0.dist-info → rapidata-2.18.0.dist-info}/WHEEL +0 -0
rapidata/api_client/rest.py
CHANGED
|
@@ -11,114 +11,86 @@
|
|
|
11
11
|
Do not edit the class manually.
|
|
12
12
|
""" # noqa: E501
|
|
13
13
|
|
|
14
|
-
|
|
15
14
|
import io
|
|
16
15
|
import json
|
|
17
16
|
import re
|
|
18
|
-
import
|
|
17
|
+
from typing import Dict, Optional
|
|
19
18
|
|
|
20
|
-
import
|
|
19
|
+
import requests
|
|
20
|
+
from authlib.integrations.requests_client import OAuth2Session
|
|
21
|
+
from requests.adapters import HTTPAdapter
|
|
22
|
+
from urllib3 import Retry
|
|
21
23
|
|
|
22
24
|
from rapidata.api_client.exceptions import ApiException, ApiValueError
|
|
23
25
|
|
|
24
|
-
SUPPORTED_SOCKS_PROXIES = {"socks5", "socks5h", "socks4", "socks4a"}
|
|
25
|
-
RESTResponseType = urllib3.HTTPResponse
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def is_socks_proxy_url(url):
|
|
29
|
-
if url is None:
|
|
30
|
-
return False
|
|
31
|
-
split_section = url.split("://")
|
|
32
|
-
if len(split_section) < 2:
|
|
33
|
-
return False
|
|
34
|
-
else:
|
|
35
|
-
return split_section[0].lower() in SUPPORTED_SOCKS_PROXIES
|
|
36
|
-
|
|
37
26
|
|
|
38
27
|
class RESTResponse(io.IOBase):
|
|
39
28
|
|
|
40
|
-
def __init__(self, resp) -> None:
|
|
29
|
+
def __init__(self, resp: requests.Response) -> None:
|
|
41
30
|
self.response = resp
|
|
42
|
-
self.status = resp.
|
|
31
|
+
self.status = resp.status_code
|
|
43
32
|
self.reason = resp.reason
|
|
44
33
|
self.data = None
|
|
45
34
|
|
|
46
35
|
def read(self):
|
|
47
36
|
if self.data is None:
|
|
48
|
-
self.data = self.response.
|
|
37
|
+
self.data = self.response.content
|
|
49
38
|
return self.data
|
|
50
39
|
|
|
51
|
-
def getheaders(self):
|
|
40
|
+
def getheaders(self) -> Dict[str, str]:
|
|
52
41
|
"""Returns a dictionary of the response headers."""
|
|
53
|
-
return self.response.headers
|
|
42
|
+
return dict(self.response.headers)
|
|
54
43
|
|
|
55
|
-
def getheader(self, name, default=None):
|
|
44
|
+
def getheader(self, name, default=None) -> Optional[str]:
|
|
56
45
|
"""Returns a given response header."""
|
|
57
46
|
return self.response.headers.get(name, default)
|
|
58
47
|
|
|
59
48
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
def __init__(self, configuration) -> None:
|
|
63
|
-
# urllib3.PoolManager will pass all kw parameters to connectionpool
|
|
64
|
-
# https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75 # noqa: E501
|
|
65
|
-
# https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680 # noqa: E501
|
|
66
|
-
# Custom SSL certificates and client certificates: http://urllib3.readthedocs.io/en/latest/advanced-usage.html # noqa: E501
|
|
67
|
-
|
|
68
|
-
# cert_reqs
|
|
69
|
-
if configuration.verify_ssl:
|
|
70
|
-
cert_reqs = ssl.CERT_REQUIRED
|
|
71
|
-
else:
|
|
72
|
-
cert_reqs = ssl.CERT_NONE
|
|
73
|
-
|
|
74
|
-
pool_args = {
|
|
75
|
-
"cert_reqs": cert_reqs,
|
|
76
|
-
"ca_certs": configuration.ssl_ca_cert,
|
|
77
|
-
"cert_file": configuration.cert_file,
|
|
78
|
-
"key_file": configuration.key_file,
|
|
79
|
-
}
|
|
80
|
-
if configuration.assert_hostname is not None:
|
|
81
|
-
pool_args['assert_hostname'] = (
|
|
82
|
-
configuration.assert_hostname
|
|
83
|
-
)
|
|
49
|
+
RESTResponseType = RESTResponse
|
|
84
50
|
|
|
85
|
-
if configuration.retries is not None:
|
|
86
|
-
pool_args['retries'] = configuration.retries
|
|
87
51
|
|
|
88
|
-
|
|
89
|
-
pool_args['server_hostname'] = configuration.tls_server_name
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
if configuration.socket_options is not None:
|
|
93
|
-
pool_args['socket_options'] = configuration.socket_options
|
|
52
|
+
class RESTClientObject:
|
|
94
53
|
|
|
95
|
-
|
|
96
|
-
|
|
54
|
+
def __init__(self, configuration) -> None:
|
|
55
|
+
self.configuration = configuration
|
|
97
56
|
|
|
98
|
-
|
|
99
|
-
self.pool_manager: urllib3.PoolManager
|
|
57
|
+
self.session: Optional[OAuth2Session] = None
|
|
100
58
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
59
|
+
def setup_oauth_client_credentials(
|
|
60
|
+
self, client_id: str, client_secret: str, token_endpoint: str, scope: str
|
|
61
|
+
):
|
|
62
|
+
self.session = OAuth2Session(
|
|
63
|
+
client_id=client_id,
|
|
64
|
+
client_secret=client_secret,
|
|
65
|
+
token_endpoint=token_endpoint,
|
|
66
|
+
scope=scope,
|
|
67
|
+
)
|
|
68
|
+
self._configure_session_defaults()
|
|
69
|
+
self.session.fetch_token()
|
|
70
|
+
|
|
71
|
+
def setup_oauth_with_token(self,
|
|
72
|
+
client_id: str | None,
|
|
73
|
+
client_secret: str | None,
|
|
74
|
+
token: dict,
|
|
75
|
+
token_endpoint: str,
|
|
76
|
+
leeway: int = 60):
|
|
77
|
+
self.session = OAuth2Session(
|
|
78
|
+
token=token,
|
|
79
|
+
token_endpoint=token_endpoint,
|
|
80
|
+
client_id=client_id,
|
|
81
|
+
client_secret=client_secret,
|
|
82
|
+
leeway=leeway,
|
|
83
|
+
)
|
|
84
|
+
self._configure_session_defaults()
|
|
113
85
|
|
|
114
86
|
def request(
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
87
|
+
self,
|
|
88
|
+
method,
|
|
89
|
+
url,
|
|
90
|
+
headers=None,
|
|
91
|
+
body=None,
|
|
92
|
+
post_params=None,
|
|
93
|
+
_request_timeout=None,
|
|
122
94
|
):
|
|
123
95
|
"""Perform requests.
|
|
124
96
|
|
|
@@ -135,15 +107,7 @@ class RESTClientObject:
|
|
|
135
107
|
(connection, read) timeouts.
|
|
136
108
|
"""
|
|
137
109
|
method = method.upper()
|
|
138
|
-
assert method in [
|
|
139
|
-
'GET',
|
|
140
|
-
'HEAD',
|
|
141
|
-
'DELETE',
|
|
142
|
-
'POST',
|
|
143
|
-
'PUT',
|
|
144
|
-
'PATCH',
|
|
145
|
-
'OPTIONS'
|
|
146
|
-
]
|
|
110
|
+
assert method in ["GET", "HEAD", "DELETE", "POST", "PUT", "PATCH", "OPTIONS"]
|
|
147
111
|
|
|
148
112
|
if post_params and body:
|
|
149
113
|
raise ApiValueError(
|
|
@@ -153,105 +117,115 @@ class RESTClientObject:
|
|
|
153
117
|
post_params = post_params or {}
|
|
154
118
|
headers = headers or {}
|
|
155
119
|
|
|
120
|
+
if not self.session:
|
|
121
|
+
raise ApiValueError(
|
|
122
|
+
"OAuth2 session is not initialized. Please initialize it before making requests."
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
session = self.session
|
|
126
|
+
|
|
156
127
|
timeout = None
|
|
157
128
|
if _request_timeout:
|
|
158
129
|
if isinstance(_request_timeout, (int, float)):
|
|
159
|
-
timeout =
|
|
160
|
-
elif (
|
|
161
|
-
|
|
162
|
-
and len(_request_timeout) == 2
|
|
163
|
-
):
|
|
164
|
-
timeout = urllib3.Timeout(
|
|
165
|
-
connect=_request_timeout[0],
|
|
166
|
-
read=_request_timeout[1]
|
|
167
|
-
)
|
|
130
|
+
timeout = _request_timeout
|
|
131
|
+
elif isinstance(_request_timeout, tuple) and len(_request_timeout) == 2:
|
|
132
|
+
timeout = _request_timeout
|
|
168
133
|
|
|
169
134
|
try:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
content_type = headers.get('Content-Type')
|
|
175
|
-
if (
|
|
176
|
-
not content_type
|
|
177
|
-
or re.search('json', content_type, re.IGNORECASE)
|
|
178
|
-
):
|
|
135
|
+
if method in ["POST", "PUT", "PATCH", "OPTIONS", "DELETE"]:
|
|
136
|
+
content_type = headers.get("Content-Type")
|
|
137
|
+
|
|
138
|
+
if not content_type or re.search("json", content_type, re.IGNORECASE):
|
|
179
139
|
request_body = None
|
|
180
140
|
if body is not None:
|
|
181
141
|
request_body = json.dumps(body)
|
|
182
|
-
r =
|
|
183
|
-
|
|
184
|
-
url,
|
|
185
|
-
body=request_body,
|
|
186
|
-
timeout=timeout,
|
|
187
|
-
headers=headers,
|
|
188
|
-
preload_content=False
|
|
189
|
-
)
|
|
142
|
+
r = session.request(method, url, data=request_body, timeout=timeout, headers=headers)
|
|
143
|
+
|
|
190
144
|
elif content_type == 'application/x-www-form-urlencoded':
|
|
191
|
-
r =
|
|
192
|
-
|
|
193
|
-
url,
|
|
194
|
-
fields=post_params,
|
|
195
|
-
encode_multipart=False,
|
|
196
|
-
timeout=timeout,
|
|
197
|
-
headers=headers,
|
|
198
|
-
preload_content=False
|
|
199
|
-
)
|
|
145
|
+
r = session.request(method, url, data=post_params, timeout=timeout, headers=headers)
|
|
146
|
+
|
|
200
147
|
elif content_type == 'multipart/form-data':
|
|
201
|
-
# must del headers['Content-Type'], or the correct
|
|
202
|
-
# Content-Type which generated by urllib3 will be
|
|
203
|
-
# overwritten.
|
|
204
148
|
del headers['Content-Type']
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
149
|
+
files = []
|
|
150
|
+
data = {}
|
|
151
|
+
|
|
152
|
+
for key, value in post_params:
|
|
153
|
+
if isinstance(value, tuple) and len(value) >= 2:
|
|
154
|
+
# This is a file tuple (filename, file_data, [content_type])
|
|
155
|
+
filename, file_data = value[0], value[1]
|
|
156
|
+
content_type = value[2] if len(value) > 2 else None
|
|
157
|
+
files.append((key, (filename, file_data, content_type)))
|
|
158
|
+
elif isinstance(value, dict):
|
|
159
|
+
# JSON-serialize dictionary values
|
|
160
|
+
if key in data:
|
|
161
|
+
# If we already have this key, handle as needed
|
|
162
|
+
# (convert to list or append to existing list)
|
|
163
|
+
if not isinstance(data[key], list):
|
|
164
|
+
data[key] = [data[key]]
|
|
165
|
+
data[key].append(json.dumps(value))
|
|
166
|
+
else:
|
|
167
|
+
data[key] = json.dumps(value)
|
|
168
|
+
else:
|
|
169
|
+
# Regular form data
|
|
170
|
+
if key in data:
|
|
171
|
+
if not isinstance(data[key], list):
|
|
172
|
+
data[key] = [data[key]]
|
|
173
|
+
data[key].append(value)
|
|
174
|
+
else:
|
|
175
|
+
data[key] = value
|
|
176
|
+
r = session.request(method, url, files=files, data=data, timeout=timeout, headers=headers)
|
|
177
|
+
|
|
219
178
|
elif isinstance(body, str) or isinstance(body, bytes):
|
|
220
|
-
r =
|
|
221
|
-
|
|
222
|
-
url,
|
|
223
|
-
body=body,
|
|
224
|
-
timeout=timeout,
|
|
225
|
-
headers=headers,
|
|
226
|
-
preload_content=False
|
|
227
|
-
)
|
|
179
|
+
r = session.request(method, url, data=body, timeout=timeout, headers=headers)
|
|
180
|
+
|
|
228
181
|
elif headers['Content-Type'].startswith('text/') and isinstance(body, bool):
|
|
229
|
-
request_body =
|
|
230
|
-
r =
|
|
231
|
-
|
|
232
|
-
url,
|
|
233
|
-
body=request_body,
|
|
234
|
-
preload_content=False,
|
|
235
|
-
timeout=timeout,
|
|
236
|
-
headers=headers)
|
|
182
|
+
request_body = 'true' if body else 'false'
|
|
183
|
+
r = session.request(method, url, data=request_body, timeout=timeout, headers=headers)
|
|
184
|
+
|
|
237
185
|
else:
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
arguments. Please check that your arguments match
|
|
241
|
-
declared content type."""
|
|
186
|
+
msg = '''Cannot prepare a request message for provided arguments.
|
|
187
|
+
Please check that your arguments match declared content type.'''
|
|
242
188
|
raise ApiException(status=0, reason=msg)
|
|
243
|
-
|
|
189
|
+
|
|
244
190
|
else:
|
|
245
|
-
r =
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
timeout=timeout,
|
|
250
|
-
headers=headers,
|
|
251
|
-
preload_content=False
|
|
252
|
-
)
|
|
253
|
-
except urllib3.exceptions.SSLError as e:
|
|
254
|
-
msg = "\n".join([type(e).__name__, str(e)])
|
|
191
|
+
r = session.request(method, url, params={}, timeout=timeout, headers=headers)
|
|
192
|
+
|
|
193
|
+
except requests.exceptions.SSLError as e:
|
|
194
|
+
msg = '\n'.join([type(e).__name__, str(e)])
|
|
255
195
|
raise ApiException(status=0, reason=msg)
|
|
256
196
|
|
|
257
197
|
return RESTResponse(r)
|
|
198
|
+
|
|
199
|
+
def _configure_session_defaults(self):
|
|
200
|
+
self.session.verify = (
|
|
201
|
+
self.configuration.ssl_ca_cert
|
|
202
|
+
if self.configuration.ssl_ca_cert
|
|
203
|
+
else self.configuration.verify_ssl
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
if self.configuration.cert_file and self.configuration.key_file:
|
|
207
|
+
self.session.cert = (
|
|
208
|
+
self.configuration.cert_file,
|
|
209
|
+
self.configuration.key_file,
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
if self.configuration.retries is not None:
|
|
213
|
+
retry = Retry(
|
|
214
|
+
total=self.configuration.retries,
|
|
215
|
+
backoff_factor=0.3,
|
|
216
|
+
status_forcelist=[429, 500, 502, 503, 504],
|
|
217
|
+
)
|
|
218
|
+
adapter = HTTPAdapter(max_retries=retry)
|
|
219
|
+
self.session.mount("https://", adapter)
|
|
220
|
+
# noinspection HttpUrlsUsage
|
|
221
|
+
self.session.mount("http://", adapter)
|
|
222
|
+
|
|
223
|
+
if self.configuration.proxy:
|
|
224
|
+
self.session.proxies = {
|
|
225
|
+
"http": self.configuration.proxy,
|
|
226
|
+
"https": self.configuration.proxy,
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
if self.configuration.proxy_headers:
|
|
230
|
+
for key, value in self.configuration.proxy_headers.items():
|
|
231
|
+
self.session.headers[key] = value
|
|
@@ -51,7 +51,7 @@ class RapidataOrder:
|
|
|
51
51
|
"""Runs the order to start collecting responses."""
|
|
52
52
|
self.__openapi_service.order_api.order_submit_post(self.order_id)
|
|
53
53
|
if print_link:
|
|
54
|
-
print(f"Order '{self.name}' is now viewable under: https://app.{self.__openapi_service.
|
|
54
|
+
print(f"Order '{self.name}' is now viewable under: https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}")
|
|
55
55
|
return self
|
|
56
56
|
|
|
57
57
|
def pause(self) -> None:
|
|
@@ -167,7 +167,7 @@ class RapidataOrder:
|
|
|
167
167
|
Exception: If the order is not in processing state.
|
|
168
168
|
"""
|
|
169
169
|
campaign_id = self.__get_campaign_id()
|
|
170
|
-
auth_url = f"https://app.{self.__openapi_service.
|
|
170
|
+
auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}/preview?campaignId={campaign_id}"
|
|
171
171
|
could_open_browser = webbrowser.open(auth_url)
|
|
172
172
|
if not could_open_browser:
|
|
173
173
|
encoded_url = urllib.parse.quote(auth_url, safe="%/:=&?~#+!$,;'@()*[]")
|
|
@@ -2,28 +2,37 @@ from rapidata.service.openapi_service import OpenAPIService
|
|
|
2
2
|
|
|
3
3
|
from rapidata.rapidata_client.order.rapidata_order_manager import RapidataOrderManager
|
|
4
4
|
|
|
5
|
-
from rapidata.rapidata_client.validation.validation_set_manager import
|
|
5
|
+
from rapidata.rapidata_client.validation.validation_set_manager import (
|
|
6
|
+
ValidationSetManager,
|
|
7
|
+
)
|
|
6
8
|
|
|
7
9
|
from rapidata.rapidata_client.demographic.demographic_manager import DemographicManager
|
|
8
10
|
|
|
11
|
+
|
|
9
12
|
class RapidataClient:
|
|
10
13
|
"""The Rapidata client is the main entry point for interacting with the Rapidata API. It allows you to create orders and validation sets."""
|
|
11
|
-
|
|
14
|
+
|
|
12
15
|
def __init__(
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
self,
|
|
17
|
+
client_id: str | None = None,
|
|
18
|
+
client_secret: str | None = None,
|
|
19
|
+
environment: str = "rapidata.ai",
|
|
20
|
+
oauth_scope: str = "openid",
|
|
21
|
+
cert_path: str | None = None,
|
|
22
|
+
token: dict | None = None,
|
|
23
|
+
leeway: int = 60
|
|
19
24
|
):
|
|
20
|
-
"""Initialize the RapidataClient. If both the client_id and client_secret are None, it will try using your credentials under "~/.config/rapidata/credentials.json".
|
|
21
|
-
If this is not successful, it will open a browser
|
|
25
|
+
"""Initialize the RapidataClient. If both the client_id and client_secret are None, it will try using your credentials under "~/.config/rapidata/credentials.json".
|
|
26
|
+
If this is not successful, it will open a browser window and ask you to log in, then save your new credentials in said json file.
|
|
22
27
|
|
|
23
28
|
Args:
|
|
24
29
|
client_id (str): The client ID for authentication.
|
|
25
30
|
client_secret (str): The client secret for authentication.
|
|
26
|
-
|
|
31
|
+
environment (str, optional): The API endpoint.
|
|
32
|
+
oauth_scope (str, optional): The scopes to use for authentication. In general this does not need to be changed.
|
|
33
|
+
cert_path (str, optional): An optional path to a certificate file useful for development.
|
|
34
|
+
token (dict, optional): If you already have a token that the client should use for authentication. Important, if set, this needs to be the complete token object containing the access token, token type and expiration time.
|
|
35
|
+
leeway (int, optional): An optional leeway to use to determine if a token is expired. Defaults to 60 seconds.
|
|
27
36
|
|
|
28
37
|
Attributes:
|
|
29
38
|
order (RapidataOrderManager): The RapidataOrderManager instance.
|
|
@@ -32,17 +41,19 @@ class RapidataClient:
|
|
|
32
41
|
self._openapi_service = OpenAPIService(
|
|
33
42
|
client_id=client_id,
|
|
34
43
|
client_secret=client_secret,
|
|
35
|
-
|
|
44
|
+
environment=environment,
|
|
36
45
|
oauth_scope=oauth_scope,
|
|
37
|
-
cert_path=cert_path
|
|
46
|
+
cert_path=cert_path,
|
|
47
|
+
token=token,
|
|
48
|
+
leeway=leeway,
|
|
38
49
|
)
|
|
39
|
-
|
|
50
|
+
|
|
40
51
|
self.order = RapidataOrderManager(openapi_service=self._openapi_service)
|
|
41
|
-
|
|
52
|
+
|
|
42
53
|
self.validation = ValidationSetManager(openapi_service=self._openapi_service)
|
|
43
54
|
|
|
44
55
|
self._demographic = DemographicManager(openapi_service=self._openapi_service)
|
|
45
56
|
|
|
46
57
|
def reset_credentials(self):
|
|
47
|
-
"""Reset the credentials saved in the configuration file for the current
|
|
58
|
+
"""Reset the credentials saved in the configuration file for the current environment."""
|
|
48
59
|
self._openapi_service.reset_credentials()
|
|
@@ -5,3 +5,5 @@ from .validation_selection import ValidationSelection
|
|
|
5
5
|
from .conditional_validation_selection import ConditionalValidationSelection
|
|
6
6
|
from .capped_selection import CappedSelection
|
|
7
7
|
from .shuffling_selection import ShufflingSelection
|
|
8
|
+
from .ab_test_selection import AbTestSelection
|
|
9
|
+
from .static_selection import StaticSelection
|
|
@@ -20,7 +20,7 @@ class AbTestSelection(RapidataSelection):
|
|
|
20
20
|
b_selections (Sequence[RapidataSelection]): List of selections for group B.
|
|
21
21
|
"""
|
|
22
22
|
|
|
23
|
-
def __init__(self, a_selections: Sequence[RapidataSelection], b_selections: Sequence[RapidataSelection]
|
|
23
|
+
def __init__(self, a_selections: Sequence[RapidataSelection], b_selections: Sequence[RapidataSelection]):
|
|
24
24
|
self.a_selections = a_selections
|
|
25
25
|
self.b_selections = b_selections
|
|
26
26
|
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
|
|
2
|
+
from rapidata.api_client.models.static_selection import StaticSelection as StaticSelectionModel
|
|
3
|
+
from rapidata.rapidata_client.selection._base_selection import RapidataSelection
|
|
4
|
+
|
|
5
|
+
class StaticSelection(RapidataSelection):
|
|
6
|
+
"""StaticSelection Class
|
|
7
|
+
|
|
8
|
+
Given a list of RapidIds, theses specific rapids will be shown in order for every session.
|
|
9
|
+
|
|
10
|
+
Args:
|
|
11
|
+
selections (list[str]): List of rapid ids to show.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
def __init__(self, rapid_ids: list[str]):
|
|
15
|
+
self.rapid_ids = rapid_ids
|
|
16
|
+
|
|
17
|
+
def _to_model(self) -> StaticSelectionModel:
|
|
18
|
+
return StaticSelectionModel(
|
|
19
|
+
_t="StaticSelection",
|
|
20
|
+
rapidIds=self.rapid_ids
|
|
21
|
+
)
|
|
22
|
+
|
|
@@ -31,13 +31,13 @@ class RapidataValidationSet:
|
|
|
31
31
|
rapid._add_to_validation_set(self.id, self.__openapi_service, self.__session)
|
|
32
32
|
return self
|
|
33
33
|
|
|
34
|
-
def update_dimensions(self, dimensions: list[str]
|
|
34
|
+
def update_dimensions(self, dimensions: list[str]):
|
|
35
35
|
"""Update the dimensions of the validation set.
|
|
36
36
|
|
|
37
37
|
Args:
|
|
38
38
|
dimensions (list[str]): The new dimensions of the validation set.
|
|
39
39
|
"""
|
|
40
|
-
self.__openapi_service.validation_api.validation_validation_set_id_dimensions_patch(self.id, UpdateDimensionsModel(dimensions=dimensions)
|
|
40
|
+
self.__openapi_service.validation_api.validation_validation_set_id_dimensions_patch(self.id, UpdateDimensionsModel(dimensions=dimensions))
|
|
41
41
|
return self
|
|
42
42
|
|
|
43
43
|
def _get_session(self, max_retries: int = 5, max_workers: int = 10) -> requests.Session:
|
|
@@ -39,7 +39,7 @@ class ValidationSetManager:
|
|
|
39
39
|
data_type: str = RapidataDataTypes.MEDIA,
|
|
40
40
|
contexts: list[str] | None = None,
|
|
41
41
|
explanations: list[str | None] | None = None,
|
|
42
|
-
dimensions: list[str]
|
|
42
|
+
dimensions: list[str] = [],
|
|
43
43
|
print_confirmation: bool = True,
|
|
44
44
|
) -> RapidataValidationSet:
|
|
45
45
|
"""Create a classification validation set.
|
|
@@ -59,7 +59,7 @@ class ValidationSetManager:
|
|
|
59
59
|
If provided has to be the same length as datapoints and will be shown in addition to the instruction and answer options. (Therefore will be different for each datapoint)
|
|
60
60
|
Will be match up with the datapoints using the list index.
|
|
61
61
|
explanations (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
|
|
62
|
-
dimensions (list[str]
|
|
62
|
+
dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
|
|
63
63
|
print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
|
|
64
64
|
|
|
65
65
|
Example:
|
|
@@ -107,7 +107,7 @@ class ValidationSetManager:
|
|
|
107
107
|
data_type: str = RapidataDataTypes.MEDIA,
|
|
108
108
|
contexts: list[str] | None = None,
|
|
109
109
|
explanation: list[str | None] | None = None,
|
|
110
|
-
dimensions: list[str]
|
|
110
|
+
dimensions: list[str] = [],
|
|
111
111
|
print_confirmation: bool = True,
|
|
112
112
|
) -> RapidataValidationSet:
|
|
113
113
|
"""Create a comparison validation set.
|
|
@@ -127,7 +127,7 @@ class ValidationSetManager:
|
|
|
127
127
|
If provided has to be the same length as datapoints and will be shown in addition to the instruction and truth. (Therefore will be different for each datapoint)
|
|
128
128
|
Will be match up with the datapoints using the list index.
|
|
129
129
|
explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
|
|
130
|
-
dimensions (list[str]
|
|
130
|
+
dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
|
|
131
131
|
print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
|
|
132
132
|
|
|
133
133
|
Example:
|
|
@@ -175,7 +175,7 @@ class ValidationSetManager:
|
|
|
175
175
|
required_precision: float = 1.0,
|
|
176
176
|
required_completeness: float = 1.0,
|
|
177
177
|
explanation: list[str | None] | None = None,
|
|
178
|
-
dimensions: list[str]
|
|
178
|
+
dimensions: list[str] = [],
|
|
179
179
|
print_confirmation: bool = True,
|
|
180
180
|
) -> RapidataValidationSet:
|
|
181
181
|
"""Create a select words validation set.
|
|
@@ -194,7 +194,7 @@ class ValidationSetManager:
|
|
|
194
194
|
required_precision (float, optional): The required precision for the labeler to get the rapid correct (minimum ratio of the words selected that need to be correct). Defaults to 1.0 (no wrong word can be selected).
|
|
195
195
|
required_completeness (float, optional): The required completeness for the labeler to get the rapid correct (miminum ratio of total correct words selected). Defaults to 1.0 (all correct words need to be selected).
|
|
196
196
|
explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
|
|
197
|
-
dimensions (list[str]
|
|
197
|
+
dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
|
|
198
198
|
print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
|
|
199
199
|
|
|
200
200
|
Example:
|
|
@@ -238,7 +238,7 @@ class ValidationSetManager:
|
|
|
238
238
|
datapoints: list[str],
|
|
239
239
|
contexts: list[str] | None = None,
|
|
240
240
|
explanation: list[str | None] | None = None,
|
|
241
|
-
dimensions: list[str]
|
|
241
|
+
dimensions: list[str] = [],
|
|
242
242
|
print_confirmation: bool = True,
|
|
243
243
|
) -> RapidataValidationSet:
|
|
244
244
|
"""Create a locate validation set.
|
|
@@ -253,7 +253,7 @@ class ValidationSetManager:
|
|
|
253
253
|
datapoints (list[str]): The datapoints that will be used for validation.
|
|
254
254
|
contexts (list[str], optional): The contexts for each datapoint. Defaults to None.
|
|
255
255
|
explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
|
|
256
|
-
dimensions (list[str]
|
|
256
|
+
dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
|
|
257
257
|
print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
|
|
258
258
|
|
|
259
259
|
Example:
|
|
@@ -299,7 +299,7 @@ class ValidationSetManager:
|
|
|
299
299
|
datapoints: list[str],
|
|
300
300
|
contexts: list[str] | None = None,
|
|
301
301
|
explanation: list[str | None] | None = None,
|
|
302
|
-
dimensions: list[str]
|
|
302
|
+
dimensions: list[str] = [],
|
|
303
303
|
print_confirmation: bool = True,
|
|
304
304
|
) -> RapidataValidationSet:
|
|
305
305
|
"""Create a draw validation set.
|
|
@@ -314,7 +314,7 @@ class ValidationSetManager:
|
|
|
314
314
|
datapoints (list[str]): The datapoints that will be used for validation.
|
|
315
315
|
contexts (list[str], optional): The contexts for each datapoint. Defaults to None.
|
|
316
316
|
explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
|
|
317
|
-
dimensions (list[str]
|
|
317
|
+
dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
|
|
318
318
|
print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
|
|
319
319
|
|
|
320
320
|
Example:
|
|
@@ -359,7 +359,7 @@ class ValidationSetManager:
|
|
|
359
359
|
datapoints: list[str],
|
|
360
360
|
contexts: list[str] | None = None,
|
|
361
361
|
explanation: list[str | None] | None = None,
|
|
362
|
-
dimensions: list[str]
|
|
362
|
+
dimensions: list[str] = [],
|
|
363
363
|
print_confirmation: bool = True,
|
|
364
364
|
) -> RapidataValidationSet:
|
|
365
365
|
"""Create a timestamp validation set.
|
|
@@ -375,7 +375,7 @@ class ValidationSetManager:
|
|
|
375
375
|
datapoints (list[str]): The datapoints that will be used for validation.
|
|
376
376
|
contexts (list[str], optional): The contexts for each datapoint. Defaults to None.
|
|
377
377
|
explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
|
|
378
|
-
dimensions (list[str]
|
|
378
|
+
dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
|
|
379
379
|
print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
|
|
380
380
|
|
|
381
381
|
Example:
|
|
@@ -416,7 +416,7 @@ class ValidationSetManager:
|
|
|
416
416
|
def create_mixed_set(self,
|
|
417
417
|
name: str,
|
|
418
418
|
rapids: list[Rapid],
|
|
419
|
-
dimensions: list[str]
|
|
419
|
+
dimensions: list[str] = [],
|
|
420
420
|
print_confirmation: bool = True
|
|
421
421
|
) -> RapidataValidationSet:
|
|
422
422
|
"""Create a validation set with a list of rapids.
|
|
@@ -424,7 +424,7 @@ class ValidationSetManager:
|
|
|
424
424
|
Args:
|
|
425
425
|
name (str): The name of the validation set. (will not be shown to the labeler)
|
|
426
426
|
rapids (list[Rapid]): The list of rapids to add to the validation set.
|
|
427
|
-
dimensions (list[str]
|
|
427
|
+
dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
|
|
428
428
|
print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
|
|
429
429
|
"""
|
|
430
430
|
|
|
@@ -468,10 +468,12 @@ class ValidationSetManager:
|
|
|
468
468
|
if print_confirmation:
|
|
469
469
|
print()
|
|
470
470
|
print(f"Validation set '{name}' created with ID {validation_set_id}\n",
|
|
471
|
-
f"Now viewable under: https://app.{self.__openapi_service.
|
|
471
|
+
f"Now viewable under: https://app.{self.__openapi_service.environment}/validation-set/detail/{validation_set_id}",
|
|
472
472
|
sep="")
|
|
473
473
|
|
|
474
|
-
|
|
474
|
+
if dimensions:
|
|
475
|
+
validation_set.update_dimensions(dimensions)
|
|
476
|
+
|
|
475
477
|
return validation_set
|
|
476
478
|
|
|
477
479
|
|
|
@@ -133,9 +133,9 @@ class CredentialManager:
|
|
|
133
133
|
return credential
|
|
134
134
|
|
|
135
135
|
return self._create_new_credentials()
|
|
136
|
-
|
|
136
|
+
|
|
137
137
|
def reset_credentials(self) -> None:
|
|
138
|
-
"""Reset the stored credentials for current
|
|
138
|
+
"""Reset the stored credentials for current environment."""
|
|
139
139
|
credentials = self._read_credentials()
|
|
140
140
|
if self.endpoint in credentials:
|
|
141
141
|
del credentials[self.endpoint]
|
|
@@ -1,3 +1,6 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
from importlib.metadata import version, PackageNotFoundError
|
|
3
|
+
|
|
1
4
|
from rapidata.api_client.api.campaign_api import CampaignApi
|
|
2
5
|
from rapidata.api_client.api.dataset_api import DatasetApi
|
|
3
6
|
from rapidata.api_client.api.order_api import OrderApi
|
|
@@ -7,31 +10,31 @@ from rapidata.api_client.api.validation_api import ValidationApi
|
|
|
7
10
|
from rapidata.api_client.api.workflow_api import WorkflowApi
|
|
8
11
|
from rapidata.api_client.api_client import ApiClient
|
|
9
12
|
from rapidata.api_client.configuration import Configuration
|
|
10
|
-
from rapidata.service.token_manager import TokenManager, TokenInfo
|
|
11
13
|
from rapidata.service.credential_manager import CredentialManager
|
|
12
14
|
|
|
13
|
-
from importlib.metadata import version, PackageNotFoundError
|
|
14
|
-
|
|
15
15
|
|
|
16
16
|
class OpenAPIService:
|
|
17
17
|
def __init__(
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
18
|
+
self,
|
|
19
|
+
client_id: str | None,
|
|
20
|
+
client_secret: str | None,
|
|
21
|
+
environment: str,
|
|
22
|
+
oauth_scope: str,
|
|
23
|
+
cert_path: str | None = None,
|
|
24
|
+
token: dict | None = None,
|
|
25
|
+
leeway: int = 60,
|
|
24
26
|
):
|
|
25
|
-
self.
|
|
26
|
-
endpoint = f"https://api.{
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
cert_path=cert_path
|
|
27
|
+
self.environment = environment
|
|
28
|
+
endpoint = f"https://api.{environment}"
|
|
29
|
+
auth_endpoint = f"https://auth.{environment}"
|
|
30
|
+
|
|
31
|
+
if environment == "rapidata.dev" and not cert_path:
|
|
32
|
+
cert_path = _get_local_certificate()
|
|
33
|
+
|
|
34
|
+
self.credential_manager = CredentialManager(
|
|
35
|
+
endpoint=auth_endpoint, cert_path=cert_path
|
|
34
36
|
)
|
|
37
|
+
|
|
35
38
|
client_configuration = Configuration(host=endpoint, ssl_ca_cert=cert_path)
|
|
36
39
|
self.api_client = ApiClient(
|
|
37
40
|
configuration=client_configuration,
|
|
@@ -39,16 +42,32 @@ class OpenAPIService:
|
|
|
39
42
|
header_value=f"RapidataPythonSDK/{self._get_rapidata_package_version()}",
|
|
40
43
|
)
|
|
41
44
|
|
|
42
|
-
|
|
43
|
-
|
|
45
|
+
if token:
|
|
46
|
+
self.api_client.rest_client.setup_oauth_with_token(
|
|
47
|
+
token=token,
|
|
48
|
+
token_endpoint=f"{auth_endpoint}/connect/token",
|
|
49
|
+
client_id=client_id,
|
|
50
|
+
client_secret=client_secret,
|
|
51
|
+
leeway=leeway,
|
|
52
|
+
)
|
|
53
|
+
return
|
|
54
|
+
|
|
55
|
+
if not client_id or not client_secret:
|
|
56
|
+
credentials = self.credential_manager.get_client_credentials()
|
|
57
|
+
if not credentials:
|
|
58
|
+
raise ValueError("Failed to fetch client credentials")
|
|
59
|
+
client_id = credentials.client_id
|
|
60
|
+
client_secret = credentials.client_secret
|
|
61
|
+
|
|
62
|
+
self.api_client.rest_client.setup_oauth_client_credentials(
|
|
63
|
+
client_id=client_id,
|
|
64
|
+
client_secret=client_secret,
|
|
65
|
+
token_endpoint=f"{auth_endpoint}/connect/token",
|
|
66
|
+
scope=oauth_scope,
|
|
44
67
|
)
|
|
45
68
|
|
|
46
|
-
self._cert_path = cert_path
|
|
47
|
-
|
|
48
|
-
token_manager.start_token_refresh(token_callback=self._set_token)
|
|
49
|
-
|
|
50
69
|
def reset_credentials(self):
|
|
51
|
-
|
|
70
|
+
self.credential_manager.reset_credentials()
|
|
52
71
|
|
|
53
72
|
@property
|
|
54
73
|
def order_api(self) -> OrderApi:
|
|
@@ -78,9 +97,6 @@ class OpenAPIService:
|
|
|
78
97
|
def workflow_api(self) -> WorkflowApi:
|
|
79
98
|
return WorkflowApi(self.api_client)
|
|
80
99
|
|
|
81
|
-
def _set_token(self, token: TokenInfo):
|
|
82
|
-
self.api_client.configuration.api_key["bearer"] = f"Bearer {token.access_token}"
|
|
83
|
-
|
|
84
100
|
def _get_rapidata_package_version(self):
|
|
85
101
|
"""
|
|
86
102
|
Returns the version of the currently installed rapidata package.
|
|
@@ -93,3 +109,15 @@ class OpenAPIService:
|
|
|
93
109
|
return version("rapidata")
|
|
94
110
|
except PackageNotFoundError:
|
|
95
111
|
return None
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _get_local_certificate() -> str | None:
|
|
115
|
+
result = subprocess.run(["mkcert", "-CAROOT"], capture_output=True)
|
|
116
|
+
if result.returncode != 0:
|
|
117
|
+
return None
|
|
118
|
+
|
|
119
|
+
output = result.stdout.decode("utf-8").strip()
|
|
120
|
+
if not output:
|
|
121
|
+
return None
|
|
122
|
+
|
|
123
|
+
return f"{output}/rootCA.pem"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: rapidata
|
|
3
|
-
Version: 2.
|
|
3
|
+
Version: 2.18.0
|
|
4
4
|
Summary: Rapidata package containing the Rapidata Python Client to interact with the Rapidata Web API in an easy way.
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Rapidata AG
|
|
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.13
|
|
15
|
+
Requires-Dist: authlib (>=1.5.1,<2.0.0)
|
|
15
16
|
Requires-Dist: colorama (==0.4.6)
|
|
16
17
|
Requires-Dist: deprecated (>=1.2.14,<2.0.0)
|
|
17
18
|
Requires-Dist: pandas (>=2.2.3,<3.0.0)
|
|
@@ -403,7 +403,7 @@ rapidata/api_client/models/workflow_labeling_step_model.py,sha256=iXeIb78bdMhGFj
|
|
|
403
403
|
rapidata/api_client/models/workflow_split_model.py,sha256=zthOSaUl8dbLhLymLK_lrPTBpeV1a4cODLxnHmNCAZw,4474
|
|
404
404
|
rapidata/api_client/models/workflow_split_model_filter_configs_inner.py,sha256=1Fx9uZtztiiAdMXkj7YeCqt7o6VkG9lKf7D7UP_h088,7447
|
|
405
405
|
rapidata/api_client/models/workflow_state.py,sha256=5LAK1se76RCoozeVB6oxMPb8p_5bhLZJqn7q5fFQWis,850
|
|
406
|
-
rapidata/api_client/rest.py,sha256=
|
|
406
|
+
rapidata/api_client/rest.py,sha256=Nnn1XE9sVUprPm_6AsUmetb_bd9dMjynDOob6y8NJNE,8775
|
|
407
407
|
rapidata/api_client_README.md,sha256=97mR2UeWNIhqNxgUAz87zH8kM9RsbrColX_FvIy_rYo,53784
|
|
408
408
|
rapidata/rapidata_client/__init__.py,sha256=yU0cRoX-RmOHQv0Qj3yJpHaDET4DHZWSO6w2cAApQhQ,920
|
|
409
409
|
rapidata/rapidata_client/assets/__init__.py,sha256=hKgrOSn8gJcBSULaf4auYhH1S1N5AfcwIhBSq1BOKwQ,323
|
|
@@ -438,23 +438,24 @@ rapidata/rapidata_client/metadata/_select_words_metadata.py,sha256=-MK5yQDi_G3BK
|
|
|
438
438
|
rapidata/rapidata_client/order/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
439
439
|
rapidata/rapidata_client/order/_rapidata_dataset.py,sha256=KyEDb4NvlY4uzN3GjiX331VSGOq_I6FhvTnyKrpo4Sg,19400
|
|
440
440
|
rapidata/rapidata_client/order/_rapidata_order_builder.py,sha256=De-gNvOnuXHz6QaRdNpr0_3zOkIEKu8hgtZkatzxZQ4,13155
|
|
441
|
-
rapidata/rapidata_client/order/rapidata_order.py,sha256=
|
|
441
|
+
rapidata/rapidata_client/order/rapidata_order.py,sha256=cz1bXkoS7y_MuJDmjdXao7nbCFd8UF4bb5SHZtOYDjI,10483
|
|
442
442
|
rapidata/rapidata_client/order/rapidata_order_manager.py,sha256=HQzZAnwKnwWl2wLcZPqxQRnth1xWhVoGNkMkzzYbjqA,30615
|
|
443
443
|
rapidata/rapidata_client/order/rapidata_results.py,sha256=0y8EOiqUV7XuwpRJyV53mfo-ddRV1cUPdWexbPEVOHM,8044
|
|
444
|
-
rapidata/rapidata_client/rapidata_client.py,sha256=
|
|
444
|
+
rapidata/rapidata_client/rapidata_client.py,sha256=fcWlmmNCZahK40Ox4aY153tEihIpwkUxYTuiypKF2SY,2857
|
|
445
445
|
rapidata/rapidata_client/referee/__init__.py,sha256=q0Hv9nmfEpyChejtyMLT8hWKL0vTTf_UgUXPYNJ-H6M,153
|
|
446
446
|
rapidata/rapidata_client/referee/_base_referee.py,sha256=MdFOhdxt3sRnWXLDKLJZKFdVpjBGn9jypPnWWQ6msQA,496
|
|
447
447
|
rapidata/rapidata_client/referee/_early_stopping_referee.py,sha256=ULbokQZ91wc9D_20qHUhe55D28D9eTY1J1cMp_-oIDc,2088
|
|
448
448
|
rapidata/rapidata_client/referee/_naive_referee.py,sha256=PVR8uy8hfRjr2DBzdOFyvou6S3swNc-4UvgjhO-09TU,1209
|
|
449
|
-
rapidata/rapidata_client/selection/__init__.py,sha256=
|
|
449
|
+
rapidata/rapidata_client/selection/__init__.py,sha256=OimK44ig39A3kHCR_JGNO4FiUYJ6JUY0ZT0J8dz32Rs,475
|
|
450
450
|
rapidata/rapidata_client/selection/_base_selection.py,sha256=tInbWOgxT_4CHkr5QHoG55ZcUi1ZmfcEGIwLKKCnN20,147
|
|
451
|
-
rapidata/rapidata_client/selection/ab_test_selection.py,sha256=
|
|
451
|
+
rapidata/rapidata_client/selection/ab_test_selection.py,sha256=fymubkVMawqJmYp9FKzWXTki9tgBgoj3cOP8rG9oOd0,1284
|
|
452
452
|
rapidata/rapidata_client/selection/capped_selection.py,sha256=iWhbM1LcayhgFm7oKADXCaKHGdiQIupI0jbYuuEVM2A,1184
|
|
453
453
|
rapidata/rapidata_client/selection/conditional_validation_selection.py,sha256=4etkO5p-wBoI8Wh8vBhNrXm7a_ioFvVmCANJmP8kIwI,2561
|
|
454
454
|
rapidata/rapidata_client/selection/demographic_selection.py,sha256=l4vnNbzlf9ED6BKqN4k5cZXShkXu9L1C5DtO78Vwr5M,1454
|
|
455
455
|
rapidata/rapidata_client/selection/labeling_selection.py,sha256=v26QogjmraFfRoSIgWZl6NMIW_TqbGeuCI2p4HxCeOM,657
|
|
456
456
|
rapidata/rapidata_client/selection/rapidata_selections.py,sha256=Azh0ntBZp9EQNL19imIItotQ8QW3B1gEs5YmuTvUn6U,1526
|
|
457
457
|
rapidata/rapidata_client/selection/shuffling_selection.py,sha256=FzOp7mnBLxNzM5at_-935wd77IHyWnFR1f8uqokiMOg,1201
|
|
458
|
+
rapidata/rapidata_client/selection/static_selection.py,sha256=POhVLjzHcUIuU_GCvRxMuCb27m7CkLxaQPwgf20Xo9o,681
|
|
458
459
|
rapidata/rapidata_client/selection/validation_selection.py,sha256=sedeIa8lpXVXKtFJA9IDeRvo9A1Ne4ZGcepaWDUGhCU,851
|
|
459
460
|
rapidata/rapidata_client/settings/__init__.py,sha256=DTEAT_YykwodZJXqKYOtWRwimLCA-Jxo0F0d-H6A3vM,458
|
|
460
461
|
rapidata/rapidata_client/settings/_rapidata_setting.py,sha256=MD5JhhogSLLrjFKjvL3JhMszOMCygyqLF-st0EwMSkw,352
|
|
@@ -468,12 +469,12 @@ rapidata/rapidata_client/settings/play_video_until_the_end.py,sha256=LLHx2_72k5Z
|
|
|
468
469
|
rapidata/rapidata_client/settings/rapidata_settings.py,sha256=r6eDGo5YHMekOtWqPHD50uI8vEE9VoBJfaWEDFZ78RU,1430
|
|
469
470
|
rapidata/rapidata_client/settings/translation_behaviour.py,sha256=i9n_H0eKJyKW6m3MKH_Cm1XEKWVEWsAV_79xGmGIC-4,742
|
|
470
471
|
rapidata/rapidata_client/validation/__init__.py,sha256=s5wHVtcJkncXSFuL9I0zNwccNOKpWAqxqUjkeohzi2E,24
|
|
471
|
-
rapidata/rapidata_client/validation/rapidata_validation_set.py,sha256=
|
|
472
|
+
rapidata/rapidata_client/validation/rapidata_validation_set.py,sha256=GaatGGuJCHRvdPbjzI0NRQhxzRftKzf-FkaFv25iB78,2649
|
|
472
473
|
rapidata/rapidata_client/validation/rapids/__init__.py,sha256=WU5PPwtTJlte6U90MDakzx4I8Y0laj7siw9teeXj5R0,21
|
|
473
474
|
rapidata/rapidata_client/validation/rapids/box.py,sha256=t3_Kn6doKXdnJdtbwefXnYKPiTKHneJl9E2inkDSqL8,589
|
|
474
475
|
rapidata/rapidata_client/validation/rapids/rapids.py,sha256=uCKnoSn1RykNHgTFbrvCFlfzU8lF42cff-2I-Pd48w0,4620
|
|
475
476
|
rapidata/rapidata_client/validation/rapids/rapids_manager.py,sha256=F00lPYBUx5fTPRw50iZuobtdbjFo6ZHevPMk101JdaY,14271
|
|
476
|
-
rapidata/rapidata_client/validation/validation_set_manager.py,sha256=
|
|
477
|
+
rapidata/rapidata_client/validation/validation_set_manager.py,sha256=Ecm17xf4SFsKdWAGNPdjecuAqsEJ27cHPdns7f8GXhs,26699
|
|
477
478
|
rapidata/rapidata_client/workflow/__init__.py,sha256=7nXcY91xkxjHudBc9H0fP35eBBtgwHGWTQKbb-M4h7Y,477
|
|
478
479
|
rapidata/rapidata_client/workflow/_base_workflow.py,sha256=XyIZFKS_RxAuwIHS848S3AyLEHqd07oTD_5jm2oUbsw,762
|
|
479
480
|
rapidata/rapidata_client/workflow/_classify_workflow.py,sha256=9bT54wxVJgxC-zLk6MVNbseFpzYrvFPjt7DHvxqYfnk,1736
|
|
@@ -486,11 +487,10 @@ rapidata/rapidata_client/workflow/_ranking_workflow.py,sha256=XBIifokhu9Kaqo0qSw
|
|
|
486
487
|
rapidata/rapidata_client/workflow/_select_words_workflow.py,sha256=juTW4TPnnSeBWP3K2QfO092t0u1W8I1ksY6aAtPZOi0,1225
|
|
487
488
|
rapidata/rapidata_client/workflow/_timestamp_workflow.py,sha256=tPi2zu1-SlE_ppbGbMz6MM_2LUSWxM-GA0CZRlB0qFo,1176
|
|
488
489
|
rapidata/service/__init__.py,sha256=s9bS1AJZaWIhLtJX_ZA40_CK39rAAkwdAmymTMbeWl4,68
|
|
489
|
-
rapidata/service/credential_manager.py,sha256=
|
|
490
|
+
rapidata/service/credential_manager.py,sha256=3x-Fb6tyqmgtpjI1MSOtXWW_SkzTK8Lo7I0SSL2YD7E,8602
|
|
490
491
|
rapidata/service/local_file_service.py,sha256=pgorvlWcx52Uh3cEG6VrdMK_t__7dacQ_5AnfY14BW8,877
|
|
491
|
-
rapidata/service/openapi_service.py,sha256=
|
|
492
|
-
rapidata/
|
|
493
|
-
rapidata-2.
|
|
494
|
-
rapidata-2.
|
|
495
|
-
rapidata-2.
|
|
496
|
-
rapidata-2.17.0.dist-info/RECORD,,
|
|
492
|
+
rapidata/service/openapi_service.py,sha256=ORFPfHlb41zOUP5nDjYWZwO-ZcqNF_Mw2r71RitFtS0,4042
|
|
493
|
+
rapidata-2.18.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
494
|
+
rapidata-2.18.0.dist-info/METADATA,sha256=RmwN7x5_3aS68UGwC-BTB3FZg8S879NhC32ymMZ8S0A,1187
|
|
495
|
+
rapidata-2.18.0.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
|
|
496
|
+
rapidata-2.18.0.dist-info/RECORD,,
|
|
@@ -1,176 +0,0 @@
|
|
|
1
|
-
import json
|
|
2
|
-
import logging
|
|
3
|
-
import threading
|
|
4
|
-
from datetime import datetime, timedelta
|
|
5
|
-
from typing import Optional, Callable
|
|
6
|
-
|
|
7
|
-
import requests
|
|
8
|
-
from pydantic import BaseModel
|
|
9
|
-
|
|
10
|
-
from rapidata.service.credential_manager import CredentialManager
|
|
11
|
-
|
|
12
|
-
logger = logging.getLogger(__name__)
|
|
13
|
-
|
|
14
|
-
class TokenInfo(BaseModel):
|
|
15
|
-
access_token: str
|
|
16
|
-
expires_in: int
|
|
17
|
-
issued_at: datetime
|
|
18
|
-
token_type: str = "Bearer"
|
|
19
|
-
|
|
20
|
-
@property
|
|
21
|
-
def auth_header(self):
|
|
22
|
-
return f"{self.token_type} {self.access_token}"
|
|
23
|
-
|
|
24
|
-
@property
|
|
25
|
-
def time_remaining(self):
|
|
26
|
-
remaining = (
|
|
27
|
-
(self.issued_at + timedelta(seconds=self.expires_in)) - datetime.now()
|
|
28
|
-
).total_seconds()
|
|
29
|
-
return max(0.0, remaining)
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
class TokenManager:
|
|
33
|
-
def __init__(
|
|
34
|
-
self,
|
|
35
|
-
client_id: str | None = None,
|
|
36
|
-
client_secret: str | None = None,
|
|
37
|
-
endpoint: str = "https://auth.rapidata.ai",
|
|
38
|
-
oauth_scope: str = "openid profile email",
|
|
39
|
-
cert_path: str | None = None,
|
|
40
|
-
refresh_threshold: float = 0.8,
|
|
41
|
-
max_sleep_time: float = 30,
|
|
42
|
-
max_token_lifetime: int = 60,
|
|
43
|
-
):
|
|
44
|
-
self._client_id = client_id
|
|
45
|
-
self._client_secret = client_secret
|
|
46
|
-
|
|
47
|
-
if not client_id or not client_secret:
|
|
48
|
-
credential_manager = CredentialManager(
|
|
49
|
-
endpoint=endpoint, cert_path=cert_path
|
|
50
|
-
)
|
|
51
|
-
credentials = credential_manager.get_client_credentials()
|
|
52
|
-
if not credentials:
|
|
53
|
-
raise ValueError("Failed to fetch client credentials")
|
|
54
|
-
self._client_id = credentials.client_id
|
|
55
|
-
self._client_secret = credentials.client_secret
|
|
56
|
-
|
|
57
|
-
self._endpoint = endpoint
|
|
58
|
-
self._oauth_scope = oauth_scope
|
|
59
|
-
self._cert_path = cert_path
|
|
60
|
-
self._refresh_threshold = refresh_threshold
|
|
61
|
-
self._max_sleep_time = max_sleep_time
|
|
62
|
-
self._max_token_lifetime = max_token_lifetime
|
|
63
|
-
|
|
64
|
-
self._token_lock = threading.Lock()
|
|
65
|
-
self._current_token: Optional[TokenInfo] = None
|
|
66
|
-
self._refresh_thread: Optional[threading.Thread] = None
|
|
67
|
-
self._should_stop = threading.Event()
|
|
68
|
-
|
|
69
|
-
def fetch_token(self):
|
|
70
|
-
try:
|
|
71
|
-
response = requests.post(
|
|
72
|
-
f"{self._endpoint}/connect/token",
|
|
73
|
-
data={
|
|
74
|
-
"grant_type": "client_credentials",
|
|
75
|
-
"client_id": self._client_id,
|
|
76
|
-
"client_secret": self._client_secret,
|
|
77
|
-
"scope": self._oauth_scope,
|
|
78
|
-
},
|
|
79
|
-
verify=self._cert_path,
|
|
80
|
-
)
|
|
81
|
-
|
|
82
|
-
if response.ok:
|
|
83
|
-
data = response.json()
|
|
84
|
-
return TokenInfo(
|
|
85
|
-
access_token=data["access_token"],
|
|
86
|
-
token_type=data["token_type"],
|
|
87
|
-
expires_in=min(self._max_token_lifetime, data["expires_in"]),
|
|
88
|
-
issued_at=datetime.now(),
|
|
89
|
-
)
|
|
90
|
-
|
|
91
|
-
else:
|
|
92
|
-
data = response.text
|
|
93
|
-
error_description = "An unknown error occurred"
|
|
94
|
-
if "error_description" in data:
|
|
95
|
-
error_description = (
|
|
96
|
-
data.split("error_description")[1].split("\n")[0].strip()
|
|
97
|
-
)
|
|
98
|
-
raise ValueError(f"Failed to fetch token: {error_description}")
|
|
99
|
-
except requests.RequestException as e:
|
|
100
|
-
raise ValueError(f"Failed to fetch token: {e}")
|
|
101
|
-
except json.JSONDecodeError as e:
|
|
102
|
-
raise ValueError(f"Failed to parse token response: {e}")
|
|
103
|
-
except KeyError as e:
|
|
104
|
-
raise ValueError(f"Failed to extract token from response: {e}")
|
|
105
|
-
|
|
106
|
-
def start_token_refresh(self, token_callback: Callable[[TokenInfo], None]) -> None:
|
|
107
|
-
if self._refresh_thread and self._refresh_thread.is_alive():
|
|
108
|
-
logger.error("Token refresh thread is already running")
|
|
109
|
-
return
|
|
110
|
-
|
|
111
|
-
def refresh_loop():
|
|
112
|
-
while not self._should_stop.is_set():
|
|
113
|
-
try:
|
|
114
|
-
with self._token_lock:
|
|
115
|
-
if self._should_refresh_token(self._current_token):
|
|
116
|
-
logger.debug("Refreshing token")
|
|
117
|
-
self._current_token = self.fetch_token()
|
|
118
|
-
token_callback(self._current_token)
|
|
119
|
-
|
|
120
|
-
if self._current_token:
|
|
121
|
-
time_until_refresh_threshold = (
|
|
122
|
-
self._current_token.time_remaining
|
|
123
|
-
- (
|
|
124
|
-
self._current_token.expires_in
|
|
125
|
-
* (1 - self._refresh_threshold)
|
|
126
|
-
)
|
|
127
|
-
)
|
|
128
|
-
logger.debug("Time until refresh threshold: %s", time_until_refresh_threshold)
|
|
129
|
-
sleep_time = min(
|
|
130
|
-
self._max_sleep_time, time_until_refresh_threshold
|
|
131
|
-
)
|
|
132
|
-
logger.debug(
|
|
133
|
-
f"Sleeping for {sleep_time} until checking the token again"
|
|
134
|
-
)
|
|
135
|
-
self._should_stop.wait(timeout=max(1.0, sleep_time))
|
|
136
|
-
else:
|
|
137
|
-
self._should_stop.wait(timeout=self._max_sleep_time)
|
|
138
|
-
except Exception as e:
|
|
139
|
-
logger.error("Failed to refresh token: %s", e)
|
|
140
|
-
self._should_stop.wait(timeout=5)
|
|
141
|
-
|
|
142
|
-
self._should_stop.clear()
|
|
143
|
-
self._refresh_thread = threading.Thread(target=refresh_loop, daemon=True)
|
|
144
|
-
self._refresh_thread.start()
|
|
145
|
-
|
|
146
|
-
def stop_token_refresh(self):
|
|
147
|
-
self._should_stop.set()
|
|
148
|
-
if self._refresh_thread:
|
|
149
|
-
self._refresh_thread.join(timeout=1)
|
|
150
|
-
self._refresh_thread = None
|
|
151
|
-
|
|
152
|
-
def get_current_token(self) -> Optional[TokenInfo]:
|
|
153
|
-
with self._token_lock:
|
|
154
|
-
return self._current_token
|
|
155
|
-
|
|
156
|
-
def _should_refresh_token(self, token: TokenInfo | None) -> bool:
|
|
157
|
-
if not token:
|
|
158
|
-
return True
|
|
159
|
-
|
|
160
|
-
limit = token.expires_in * (1 - self._refresh_threshold)
|
|
161
|
-
|
|
162
|
-
logger.debug(
|
|
163
|
-
"The token was issued at %s, it expires in %s. It has %s seconds remaining and we refresh the token when it has %s seconds remaining",
|
|
164
|
-
token.issued_at,
|
|
165
|
-
token.expires_in,
|
|
166
|
-
token.time_remaining,
|
|
167
|
-
limit,
|
|
168
|
-
)
|
|
169
|
-
return token.time_remaining < limit
|
|
170
|
-
|
|
171
|
-
def __enter__(self):
|
|
172
|
-
return self
|
|
173
|
-
|
|
174
|
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
175
|
-
self.stop_token_refresh()
|
|
176
|
-
return False
|
|
File without changes
|
|
File without changes
|