rapidata 2.17.0__py3-none-any.whl → 2.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rapidata has been flagged as potentially problematic; see the registry's advisory for this release for more details.

@@ -11,114 +11,86 @@
11
11
  Do not edit the class manually.
12
12
  """ # noqa: E501
13
13
 
14
-
15
14
  import io
16
15
  import json
17
16
  import re
18
- import ssl
17
+ from typing import Dict, Optional
19
18
 
20
- import urllib3
19
+ import requests
20
+ from authlib.integrations.requests_client import OAuth2Session
21
+ from requests.adapters import HTTPAdapter
22
+ from urllib3 import Retry
21
23
 
22
24
  from rapidata.api_client.exceptions import ApiException, ApiValueError
23
25
 
24
- SUPPORTED_SOCKS_PROXIES = {"socks5", "socks5h", "socks4", "socks4a"}
25
- RESTResponseType = urllib3.HTTPResponse
26
-
27
-
28
- def is_socks_proxy_url(url):
29
- if url is None:
30
- return False
31
- split_section = url.split("://")
32
- if len(split_section) < 2:
33
- return False
34
- else:
35
- return split_section[0].lower() in SUPPORTED_SOCKS_PROXIES
36
-
37
26
 
38
27
  class RESTResponse(io.IOBase):
39
28
 
40
- def __init__(self, resp) -> None:
29
+ def __init__(self, resp: requests.Response) -> None:
41
30
  self.response = resp
42
- self.status = resp.status
31
+ self.status = resp.status_code
43
32
  self.reason = resp.reason
44
33
  self.data = None
45
34
 
46
35
  def read(self):
47
36
  if self.data is None:
48
- self.data = self.response.data
37
+ self.data = self.response.content
49
38
  return self.data
50
39
 
51
- def getheaders(self):
40
+ def getheaders(self) -> Dict[str, str]:
52
41
  """Returns a dictionary of the response headers."""
53
- return self.response.headers
42
+ return dict(self.response.headers)
54
43
 
55
- def getheader(self, name, default=None):
44
+ def getheader(self, name, default=None) -> Optional[str]:
56
45
  """Returns a given response header."""
57
46
  return self.response.headers.get(name, default)
58
47
 
59
48
 
60
- class RESTClientObject:
61
-
62
- def __init__(self, configuration) -> None:
63
- # urllib3.PoolManager will pass all kw parameters to connectionpool
64
- # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/poolmanager.py#L75 # noqa: E501
65
- # https://github.com/shazow/urllib3/blob/f9409436f83aeb79fbaf090181cd81b784f1b8ce/urllib3/connectionpool.py#L680 # noqa: E501
66
- # Custom SSL certificates and client certificates: http://urllib3.readthedocs.io/en/latest/advanced-usage.html # noqa: E501
67
-
68
- # cert_reqs
69
- if configuration.verify_ssl:
70
- cert_reqs = ssl.CERT_REQUIRED
71
- else:
72
- cert_reqs = ssl.CERT_NONE
73
-
74
- pool_args = {
75
- "cert_reqs": cert_reqs,
76
- "ca_certs": configuration.ssl_ca_cert,
77
- "cert_file": configuration.cert_file,
78
- "key_file": configuration.key_file,
79
- }
80
- if configuration.assert_hostname is not None:
81
- pool_args['assert_hostname'] = (
82
- configuration.assert_hostname
83
- )
49
+ RESTResponseType = RESTResponse
84
50
 
85
- if configuration.retries is not None:
86
- pool_args['retries'] = configuration.retries
87
51
 
88
- if configuration.tls_server_name:
89
- pool_args['server_hostname'] = configuration.tls_server_name
90
-
91
-
92
- if configuration.socket_options is not None:
93
- pool_args['socket_options'] = configuration.socket_options
52
+ class RESTClientObject:
94
53
 
95
- if configuration.connection_pool_maxsize is not None:
96
- pool_args['maxsize'] = configuration.connection_pool_maxsize
54
+ def __init__(self, configuration) -> None:
55
+ self.configuration = configuration
97
56
 
98
- # https pool manager
99
- self.pool_manager: urllib3.PoolManager
57
+ self.session: Optional[OAuth2Session] = None
100
58
 
101
- if configuration.proxy:
102
- if is_socks_proxy_url(configuration.proxy):
103
- from urllib3.contrib.socks import SOCKSProxyManager
104
- pool_args["proxy_url"] = configuration.proxy
105
- pool_args["headers"] = configuration.proxy_headers
106
- self.pool_manager = SOCKSProxyManager(**pool_args)
107
- else:
108
- pool_args["proxy_url"] = configuration.proxy
109
- pool_args["proxy_headers"] = configuration.proxy_headers
110
- self.pool_manager = urllib3.ProxyManager(**pool_args)
111
- else:
112
- self.pool_manager = urllib3.PoolManager(**pool_args)
59
+ def setup_oauth_client_credentials(
60
+ self, client_id: str, client_secret: str, token_endpoint: str, scope: str
61
+ ):
62
+ self.session = OAuth2Session(
63
+ client_id=client_id,
64
+ client_secret=client_secret,
65
+ token_endpoint=token_endpoint,
66
+ scope=scope,
67
+ )
68
+ self._configure_session_defaults()
69
+ self.session.fetch_token()
70
+
71
+ def setup_oauth_with_token(self,
72
+ client_id: str | None,
73
+ client_secret: str | None,
74
+ token: dict,
75
+ token_endpoint: str,
76
+ leeway: int = 60):
77
+ self.session = OAuth2Session(
78
+ token=token,
79
+ token_endpoint=token_endpoint,
80
+ client_id=client_id,
81
+ client_secret=client_secret,
82
+ leeway=leeway,
83
+ )
84
+ self._configure_session_defaults()
113
85
 
114
86
  def request(
115
- self,
116
- method,
117
- url,
118
- headers=None,
119
- body=None,
120
- post_params=None,
121
- _request_timeout=None
87
+ self,
88
+ method,
89
+ url,
90
+ headers=None,
91
+ body=None,
92
+ post_params=None,
93
+ _request_timeout=None,
122
94
  ):
123
95
  """Perform requests.
124
96
 
@@ -135,15 +107,7 @@ class RESTClientObject:
135
107
  (connection, read) timeouts.
136
108
  """
137
109
  method = method.upper()
138
- assert method in [
139
- 'GET',
140
- 'HEAD',
141
- 'DELETE',
142
- 'POST',
143
- 'PUT',
144
- 'PATCH',
145
- 'OPTIONS'
146
- ]
110
+ assert method in ["GET", "HEAD", "DELETE", "POST", "PUT", "PATCH", "OPTIONS"]
147
111
 
148
112
  if post_params and body:
149
113
  raise ApiValueError(
@@ -153,105 +117,115 @@ class RESTClientObject:
153
117
  post_params = post_params or {}
154
118
  headers = headers or {}
155
119
 
120
+ if not self.session:
121
+ raise ApiValueError(
122
+ "OAuth2 session is not initialized. Please initialize it before making requests."
123
+ )
124
+
125
+ session = self.session
126
+
156
127
  timeout = None
157
128
  if _request_timeout:
158
129
  if isinstance(_request_timeout, (int, float)):
159
- timeout = urllib3.Timeout(total=_request_timeout)
160
- elif (
161
- isinstance(_request_timeout, tuple)
162
- and len(_request_timeout) == 2
163
- ):
164
- timeout = urllib3.Timeout(
165
- connect=_request_timeout[0],
166
- read=_request_timeout[1]
167
- )
130
+ timeout = _request_timeout
131
+ elif isinstance(_request_timeout, tuple) and len(_request_timeout) == 2:
132
+ timeout = _request_timeout
168
133
 
169
134
  try:
170
- # For `POST`, `PUT`, `PATCH`, `OPTIONS`, `DELETE`
171
- if method in ['POST', 'PUT', 'PATCH', 'OPTIONS', 'DELETE']:
172
-
173
- # no content type provided or payload is json
174
- content_type = headers.get('Content-Type')
175
- if (
176
- not content_type
177
- or re.search('json', content_type, re.IGNORECASE)
178
- ):
135
+ if method in ["POST", "PUT", "PATCH", "OPTIONS", "DELETE"]:
136
+ content_type = headers.get("Content-Type")
137
+
138
+ if not content_type or re.search("json", content_type, re.IGNORECASE):
179
139
  request_body = None
180
140
  if body is not None:
181
141
  request_body = json.dumps(body)
182
- r = self.pool_manager.request(
183
- method,
184
- url,
185
- body=request_body,
186
- timeout=timeout,
187
- headers=headers,
188
- preload_content=False
189
- )
142
+ r = session.request(method, url, data=request_body, timeout=timeout, headers=headers)
143
+
190
144
  elif content_type == 'application/x-www-form-urlencoded':
191
- r = self.pool_manager.request(
192
- method,
193
- url,
194
- fields=post_params,
195
- encode_multipart=False,
196
- timeout=timeout,
197
- headers=headers,
198
- preload_content=False
199
- )
145
+ r = session.request(method, url, data=post_params, timeout=timeout, headers=headers)
146
+
200
147
  elif content_type == 'multipart/form-data':
201
- # must del headers['Content-Type'], or the correct
202
- # Content-Type which generated by urllib3 will be
203
- # overwritten.
204
148
  del headers['Content-Type']
205
- # Ensures that dict objects are serialized
206
- post_params = [(a, json.dumps(b)) if isinstance(b, dict) else (a,b) for a, b in post_params]
207
- r = self.pool_manager.request(
208
- method,
209
- url,
210
- fields=post_params,
211
- encode_multipart=True,
212
- timeout=timeout,
213
- headers=headers,
214
- preload_content=False
215
- )
216
- # Pass a `string` parameter directly in the body to support
217
- # other content types than JSON when `body` argument is
218
- # provided in serialized form.
149
+ files = []
150
+ data = {}
151
+
152
+ for key, value in post_params:
153
+ if isinstance(value, tuple) and len(value) >= 2:
154
+ # This is a file tuple (filename, file_data, [content_type])
155
+ filename, file_data = value[0], value[1]
156
+ content_type = value[2] if len(value) > 2 else None
157
+ files.append((key, (filename, file_data, content_type)))
158
+ elif isinstance(value, dict):
159
+ # JSON-serialize dictionary values
160
+ if key in data:
161
+ # If we already have this key, handle as needed
162
+ # (convert to list or append to existing list)
163
+ if not isinstance(data[key], list):
164
+ data[key] = [data[key]]
165
+ data[key].append(json.dumps(value))
166
+ else:
167
+ data[key] = json.dumps(value)
168
+ else:
169
+ # Regular form data
170
+ if key in data:
171
+ if not isinstance(data[key], list):
172
+ data[key] = [data[key]]
173
+ data[key].append(value)
174
+ else:
175
+ data[key] = value
176
+ r = session.request(method, url, files=files, data=data, timeout=timeout, headers=headers)
177
+
219
178
  elif isinstance(body, str) or isinstance(body, bytes):
220
- r = self.pool_manager.request(
221
- method,
222
- url,
223
- body=body,
224
- timeout=timeout,
225
- headers=headers,
226
- preload_content=False
227
- )
179
+ r = session.request(method, url, data=body, timeout=timeout, headers=headers)
180
+
228
181
  elif headers['Content-Type'].startswith('text/') and isinstance(body, bool):
229
- request_body = "true" if body else "false"
230
- r = self.pool_manager.request(
231
- method,
232
- url,
233
- body=request_body,
234
- preload_content=False,
235
- timeout=timeout,
236
- headers=headers)
182
+ request_body = 'true' if body else 'false'
183
+ r = session.request(method, url, data=request_body, timeout=timeout, headers=headers)
184
+
237
185
  else:
238
- # Cannot generate the request from given parameters
239
- msg = """Cannot prepare a request message for provided
240
- arguments. Please check that your arguments match
241
- declared content type."""
186
+ msg = '''Cannot prepare a request message for provided arguments.
187
+ Please check that your arguments match declared content type.'''
242
188
  raise ApiException(status=0, reason=msg)
243
- # For `GET`, `HEAD`
189
+
244
190
  else:
245
- r = self.pool_manager.request(
246
- method,
247
- url,
248
- fields={},
249
- timeout=timeout,
250
- headers=headers,
251
- preload_content=False
252
- )
253
- except urllib3.exceptions.SSLError as e:
254
- msg = "\n".join([type(e).__name__, str(e)])
191
+ r = session.request(method, url, params={}, timeout=timeout, headers=headers)
192
+
193
+ except requests.exceptions.SSLError as e:
194
+ msg = '\n'.join([type(e).__name__, str(e)])
255
195
  raise ApiException(status=0, reason=msg)
256
196
 
257
197
  return RESTResponse(r)
198
+
199
+ def _configure_session_defaults(self):
200
+ self.session.verify = (
201
+ self.configuration.ssl_ca_cert
202
+ if self.configuration.ssl_ca_cert
203
+ else self.configuration.verify_ssl
204
+ )
205
+
206
+ if self.configuration.cert_file and self.configuration.key_file:
207
+ self.session.cert = (
208
+ self.configuration.cert_file,
209
+ self.configuration.key_file,
210
+ )
211
+
212
+ if self.configuration.retries is not None:
213
+ retry = Retry(
214
+ total=self.configuration.retries,
215
+ backoff_factor=0.3,
216
+ status_forcelist=[429, 500, 502, 503, 504],
217
+ )
218
+ adapter = HTTPAdapter(max_retries=retry)
219
+ self.session.mount("https://", adapter)
220
+ # noinspection HttpUrlsUsage
221
+ self.session.mount("http://", adapter)
222
+
223
+ if self.configuration.proxy:
224
+ self.session.proxies = {
225
+ "http": self.configuration.proxy,
226
+ "https": self.configuration.proxy,
227
+ }
228
+
229
+ if self.configuration.proxy_headers:
230
+ for key, value in self.configuration.proxy_headers.items():
231
+ self.session.headers[key] = value
@@ -51,7 +51,7 @@ class RapidataOrder:
51
51
  """Runs the order to start collecting responses."""
52
52
  self.__openapi_service.order_api.order_submit_post(self.order_id)
53
53
  if print_link:
54
- print(f"Order '{self.name}' is now viewable under: https://app.{self.__openapi_service.enviroment}/order/detail/{self.order_id}")
54
+ print(f"Order '{self.name}' is now viewable under: https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}")
55
55
  return self
56
56
 
57
57
  def pause(self) -> None:
@@ -167,7 +167,7 @@ class RapidataOrder:
167
167
  Exception: If the order is not in processing state.
168
168
  """
169
169
  campaign_id = self.__get_campaign_id()
170
- auth_url = f"https://app.{self.__openapi_service.enviroment}/order/detail/{self.order_id}/preview?campaignId={campaign_id}"
170
+ auth_url = f"https://app.{self.__openapi_service.environment}/order/detail/{self.order_id}/preview?campaignId={campaign_id}"
171
171
  could_open_browser = webbrowser.open(auth_url)
172
172
  if not could_open_browser:
173
173
  encoded_url = urllib.parse.quote(auth_url, safe="%/:=&?~#+!$,;'@()*[]")
@@ -2,28 +2,37 @@ from rapidata.service.openapi_service import OpenAPIService
2
2
 
3
3
  from rapidata.rapidata_client.order.rapidata_order_manager import RapidataOrderManager
4
4
 
5
- from rapidata.rapidata_client.validation.validation_set_manager import ValidationSetManager
5
+ from rapidata.rapidata_client.validation.validation_set_manager import (
6
+ ValidationSetManager,
7
+ )
6
8
 
7
9
  from rapidata.rapidata_client.demographic.demographic_manager import DemographicManager
8
10
 
11
+
9
12
  class RapidataClient:
10
13
  """The Rapidata client is the main entry point for interacting with the Rapidata API. It allows you to create orders and validation sets."""
11
-
14
+
12
15
  def __init__(
13
- self,
14
- client_id: str | None = None,
15
- client_secret: str | None = None,
16
- enviroment: str = "rapidata.ai",
17
- oauth_scope: str = "openid",
18
- cert_path: str | None = None,
16
+ self,
17
+ client_id: str | None = None,
18
+ client_secret: str | None = None,
19
+ environment: str = "rapidata.ai",
20
+ oauth_scope: str = "openid",
21
+ cert_path: str | None = None,
22
+ token: dict | None = None,
23
+ leeway: int = 60
19
24
  ):
20
- """Initialize the RapidataClient. If both the client_id and client_secret are None, it will try using your credentials under "~/.config/rapidata/credentials.json".
21
- If this is not successful, it will open a browser windown and ask you to log in, then save your new credentials in said json file.
25
+ """Initialize the RapidataClient. If both the client_id and client_secret are None, it will try using your credentials under "~/.config/rapidata/credentials.json".
26
+ If this is not successful, it will open a browser window and ask you to log in, then save your new credentials in said json file.
22
27
 
23
28
  Args:
24
29
  client_id (str): The client ID for authentication.
25
30
  client_secret (str): The client secret for authentication.
26
- enviroment (str, optional): The API endpoint.
31
+ environment (str, optional): The API endpoint.
32
+ oauth_scope (str, optional): The scopes to use for authentication. In general this does not need to be changed.
33
+ cert_path (str, optional): An optional path to a certificate file useful for development.
34
+ token (dict, optional): If you already have a token that the client should use for authentication. Important, if set, this needs to be the complete token object containing the access token, token type and expiration time.
35
+ leeway (int, optional): An optional leeway to use to determine if a token is expired. Defaults to 60 seconds.
27
36
 
28
37
  Attributes:
29
38
  order (RapidataOrderManager): The RapidataOrderManager instance.
@@ -32,17 +41,19 @@ class RapidataClient:
32
41
  self._openapi_service = OpenAPIService(
33
42
  client_id=client_id,
34
43
  client_secret=client_secret,
35
- enviroment=enviroment,
44
+ environment=environment,
36
45
  oauth_scope=oauth_scope,
37
- cert_path=cert_path
46
+ cert_path=cert_path,
47
+ token=token,
48
+ leeway=leeway,
38
49
  )
39
-
50
+
40
51
  self.order = RapidataOrderManager(openapi_service=self._openapi_service)
41
-
52
+
42
53
  self.validation = ValidationSetManager(openapi_service=self._openapi_service)
43
54
 
44
55
  self._demographic = DemographicManager(openapi_service=self._openapi_service)
45
56
 
46
57
  def reset_credentials(self):
47
- """Reset the credentials saved in the configuration file for the current enviroment."""
58
+ """Reset the credentials saved in the configuration file for the current environment."""
48
59
  self._openapi_service.reset_credentials()
@@ -5,3 +5,5 @@ from .validation_selection import ValidationSelection
5
5
  from .conditional_validation_selection import ConditionalValidationSelection
6
6
  from .capped_selection import CappedSelection
7
7
  from .shuffling_selection import ShufflingSelection
8
+ from .ab_test_selection import AbTestSelection
9
+ from .static_selection import StaticSelection
@@ -20,7 +20,7 @@ class AbTestSelection(RapidataSelection):
20
20
  b_selections (Sequence[RapidataSelection]): List of selections for group B.
21
21
  """
22
22
 
23
- def __init__(self, a_selections: Sequence[RapidataSelection], b_selections: Sequence[RapidataSelection], max_rapids: int):
23
+ def __init__(self, a_selections: Sequence[RapidataSelection], b_selections: Sequence[RapidataSelection]):
24
24
  self.a_selections = a_selections
25
25
  self.b_selections = b_selections
26
26
 
@@ -0,0 +1,22 @@
1
+
2
+ from rapidata.api_client.models.static_selection import StaticSelection as StaticSelectionModel
3
+ from rapidata.rapidata_client.selection._base_selection import RapidataSelection
4
+
5
+ class StaticSelection(RapidataSelection):
6
+ """StaticSelection Class
7
+
8
+ Given a list of RapidIds, theses specific rapids will be shown in order for every session.
9
+
10
+ Args:
11
+ selections (list[str]): List of rapid ids to show.
12
+ """
13
+
14
+ def __init__(self, rapid_ids: list[str]):
15
+ self.rapid_ids = rapid_ids
16
+
17
+ def _to_model(self) -> StaticSelectionModel:
18
+ return StaticSelectionModel(
19
+ _t="StaticSelection",
20
+ rapidIds=self.rapid_ids
21
+ )
22
+
@@ -31,13 +31,13 @@ class RapidataValidationSet:
31
31
  rapid._add_to_validation_set(self.id, self.__openapi_service, self.__session)
32
32
  return self
33
33
 
34
- def update_dimensions(self, dimensions: list[str] | None):
34
+ def update_dimensions(self, dimensions: list[str]):
35
35
  """Update the dimensions of the validation set.
36
36
 
37
37
  Args:
38
38
  dimensions (list[str]): The new dimensions of the validation set.
39
39
  """
40
- self.__openapi_service.validation_api.validation_validation_set_id_dimensions_patch(self.id, UpdateDimensionsModel(dimensions=dimensions) if dimensions else None)
40
+ self.__openapi_service.validation_api.validation_validation_set_id_dimensions_patch(self.id, UpdateDimensionsModel(dimensions=dimensions))
41
41
  return self
42
42
 
43
43
  def _get_session(self, max_retries: int = 5, max_workers: int = 10) -> requests.Session:
@@ -39,7 +39,7 @@ class ValidationSetManager:
39
39
  data_type: str = RapidataDataTypes.MEDIA,
40
40
  contexts: list[str] | None = None,
41
41
  explanations: list[str | None] | None = None,
42
- dimensions: list[str] | None = None,
42
+ dimensions: list[str] = [],
43
43
  print_confirmation: bool = True,
44
44
  ) -> RapidataValidationSet:
45
45
  """Create a classification validation set.
@@ -59,7 +59,7 @@ class ValidationSetManager:
59
59
  If provided has to be the same length as datapoints and will be shown in addition to the instruction and answer options. (Therefore will be different for each datapoint)
60
60
  Will be match up with the datapoints using the list index.
61
61
  explanations (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
62
- dimensions (list[str] | None, optional): The dimensions of the validation set. If not provided will be set to the default dimensions.
62
+ dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
63
63
  print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
64
64
 
65
65
  Example:
@@ -107,7 +107,7 @@ class ValidationSetManager:
107
107
  data_type: str = RapidataDataTypes.MEDIA,
108
108
  contexts: list[str] | None = None,
109
109
  explanation: list[str | None] | None = None,
110
- dimensions: list[str] | None = None,
110
+ dimensions: list[str] = [],
111
111
  print_confirmation: bool = True,
112
112
  ) -> RapidataValidationSet:
113
113
  """Create a comparison validation set.
@@ -127,7 +127,7 @@ class ValidationSetManager:
127
127
  If provided has to be the same length as datapoints and will be shown in addition to the instruction and truth. (Therefore will be different for each datapoint)
128
128
  Will be match up with the datapoints using the list index.
129
129
  explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
130
- dimensions (list[str] | None, optional): The dimensions of the validation set. If not provided will be set to the default dimensions.
130
+ dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
131
131
  print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
132
132
 
133
133
  Example:
@@ -175,7 +175,7 @@ class ValidationSetManager:
175
175
  required_precision: float = 1.0,
176
176
  required_completeness: float = 1.0,
177
177
  explanation: list[str | None] | None = None,
178
- dimensions: list[str] | None = None,
178
+ dimensions: list[str] = [],
179
179
  print_confirmation: bool = True,
180
180
  ) -> RapidataValidationSet:
181
181
  """Create a select words validation set.
@@ -194,7 +194,7 @@ class ValidationSetManager:
194
194
  required_precision (float, optional): The required precision for the labeler to get the rapid correct (minimum ratio of the words selected that need to be correct). Defaults to 1.0 (no wrong word can be selected).
195
195
  required_completeness (float, optional): The required completeness for the labeler to get the rapid correct (miminum ratio of total correct words selected). Defaults to 1.0 (all correct words need to be selected).
196
196
  explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
197
- dimensions (list[str] | None, optional): The dimensions of the validation set. If not provided will be set to the default dimensions.
197
+ dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
198
198
  print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
199
199
 
200
200
  Example:
@@ -238,7 +238,7 @@ class ValidationSetManager:
238
238
  datapoints: list[str],
239
239
  contexts: list[str] | None = None,
240
240
  explanation: list[str | None] | None = None,
241
- dimensions: list[str] | None = None,
241
+ dimensions: list[str] = [],
242
242
  print_confirmation: bool = True,
243
243
  ) -> RapidataValidationSet:
244
244
  """Create a locate validation set.
@@ -253,7 +253,7 @@ class ValidationSetManager:
253
253
  datapoints (list[str]): The datapoints that will be used for validation.
254
254
  contexts (list[str], optional): The contexts for each datapoint. Defaults to None.
255
255
  explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
256
- dimensions (list[str] | None, optional): The dimensions of the validation set. If not provided will be set to the default dimensions.
256
+ dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
257
257
  print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
258
258
 
259
259
  Example:
@@ -299,7 +299,7 @@ class ValidationSetManager:
299
299
  datapoints: list[str],
300
300
  contexts: list[str] | None = None,
301
301
  explanation: list[str | None] | None = None,
302
- dimensions: list[str] | None = None,
302
+ dimensions: list[str] = [],
303
303
  print_confirmation: bool = True,
304
304
  ) -> RapidataValidationSet:
305
305
  """Create a draw validation set.
@@ -314,7 +314,7 @@ class ValidationSetManager:
314
314
  datapoints (list[str]): The datapoints that will be used for validation.
315
315
  contexts (list[str], optional): The contexts for each datapoint. Defaults to None.
316
316
  explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
317
- dimensions (list[str] | None, optional): The dimensions of the validation set. If not provided will be set to the default dimensions.
317
+ dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
318
318
  print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
319
319
 
320
320
  Example:
@@ -359,7 +359,7 @@ class ValidationSetManager:
359
359
  datapoints: list[str],
360
360
  contexts: list[str] | None = None,
361
361
  explanation: list[str | None] | None = None,
362
- dimensions: list[str] | None = None,
362
+ dimensions: list[str] = [],
363
363
  print_confirmation: bool = True,
364
364
  ) -> RapidataValidationSet:
365
365
  """Create a timestamp validation set.
@@ -375,7 +375,7 @@ class ValidationSetManager:
375
375
  datapoints (list[str]): The datapoints that will be used for validation.
376
376
  contexts (list[str], optional): The contexts for each datapoint. Defaults to None.
377
377
  explanation (list[str | None], optional): The explanations for each datapoint. Will be given to the annotators in case the answer is wrong. Defaults to None.
378
- dimensions (list[str] | None, optional): The dimensions of the validation set. If not provided will be set to the default dimensions.
378
+ dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
379
379
  print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
380
380
 
381
381
  Example:
@@ -416,7 +416,7 @@ class ValidationSetManager:
416
416
  def create_mixed_set(self,
417
417
  name: str,
418
418
  rapids: list[Rapid],
419
- dimensions: list[str] | None = None,
419
+ dimensions: list[str] = [],
420
420
  print_confirmation: bool = True
421
421
  ) -> RapidataValidationSet:
422
422
  """Create a validation set with a list of rapids.
@@ -424,7 +424,7 @@ class ValidationSetManager:
424
424
  Args:
425
425
  name (str): The name of the validation set. (will not be shown to the labeler)
426
426
  rapids (list[Rapid]): The list of rapids to add to the validation set.
427
- dimensions (list[str] | None, optional): The dimensions of the validation set. If not provided will be set to the default dimensions.
427
+ dimensions (list[str], optional): The dimensions to add to the validation set accross which users will be tracked. Defaults to [] which is the default dimension.
428
428
  print_confirmation (bool, optional): Whether to print a confirmation message that validation set has been created. Defaults to True.
429
429
  """
430
430
 
@@ -468,10 +468,12 @@ class ValidationSetManager:
468
468
  if print_confirmation:
469
469
  print()
470
470
  print(f"Validation set '{name}' created with ID {validation_set_id}\n",
471
- f"Now viewable under: https://app.{self.__openapi_service.enviroment}/validation-set/detail/{validation_set_id}",
471
+ f"Now viewable under: https://app.{self.__openapi_service.environment}/validation-set/detail/{validation_set_id}",
472
472
  sep="")
473
473
 
474
- validation_set.update_dimensions(dimensions)
474
+ if dimensions:
475
+ validation_set.update_dimensions(dimensions)
476
+
475
477
  return validation_set
476
478
 
477
479
 
@@ -133,9 +133,9 @@ class CredentialManager:
133
133
  return credential
134
134
 
135
135
  return self._create_new_credentials()
136
-
136
+
137
137
  def reset_credentials(self) -> None:
138
- """Reset the stored credentials for current enviroment."""
138
+ """Reset the stored credentials for current environment."""
139
139
  credentials = self._read_credentials()
140
140
  if self.endpoint in credentials:
141
141
  del credentials[self.endpoint]
@@ -1,3 +1,6 @@
1
+ import subprocess
2
+ from importlib.metadata import version, PackageNotFoundError
3
+
1
4
  from rapidata.api_client.api.campaign_api import CampaignApi
2
5
  from rapidata.api_client.api.dataset_api import DatasetApi
3
6
  from rapidata.api_client.api.order_api import OrderApi
@@ -7,31 +10,31 @@ from rapidata.api_client.api.validation_api import ValidationApi
7
10
  from rapidata.api_client.api.workflow_api import WorkflowApi
8
11
  from rapidata.api_client.api_client import ApiClient
9
12
  from rapidata.api_client.configuration import Configuration
10
- from rapidata.service.token_manager import TokenManager, TokenInfo
11
13
  from rapidata.service.credential_manager import CredentialManager
12
14
 
13
- from importlib.metadata import version, PackageNotFoundError
14
-
15
15
 
16
16
  class OpenAPIService:
17
17
  def __init__(
18
- self,
19
- client_id: str | None,
20
- client_secret: str | None,
21
- enviroment: str,
22
- oauth_scope: str,
23
- cert_path: str | None = None,
18
+ self,
19
+ client_id: str | None,
20
+ client_secret: str | None,
21
+ environment: str,
22
+ oauth_scope: str,
23
+ cert_path: str | None = None,
24
+ token: dict | None = None,
25
+ leeway: int = 60,
24
26
  ):
25
- self.enviroment = enviroment
26
- endpoint = f"https://api.{enviroment}"
27
- self._token_url = f"https://auth.{enviroment}"
28
- token_manager = TokenManager(
29
- client_id=client_id,
30
- client_secret=client_secret,
31
- endpoint=self._token_url,
32
- oauth_scope=oauth_scope,
33
- cert_path=cert_path,
27
+ self.environment = environment
28
+ endpoint = f"https://api.{environment}"
29
+ auth_endpoint = f"https://auth.{environment}"
30
+
31
+ if environment == "rapidata.dev" and not cert_path:
32
+ cert_path = _get_local_certificate()
33
+
34
+ self.credential_manager = CredentialManager(
35
+ endpoint=auth_endpoint, cert_path=cert_path
34
36
  )
37
+
35
38
  client_configuration = Configuration(host=endpoint, ssl_ca_cert=cert_path)
36
39
  self.api_client = ApiClient(
37
40
  configuration=client_configuration,
@@ -39,16 +42,32 @@ class OpenAPIService:
39
42
  header_value=f"RapidataPythonSDK/{self._get_rapidata_package_version()}",
40
43
  )
41
44
 
42
- self.api_client.configuration.api_key["bearer"] = (
43
- f"Bearer {token_manager.fetch_token().access_token}"
45
+ if token:
46
+ self.api_client.rest_client.setup_oauth_with_token(
47
+ token=token,
48
+ token_endpoint=f"{auth_endpoint}/connect/token",
49
+ client_id=client_id,
50
+ client_secret=client_secret,
51
+ leeway=leeway,
52
+ )
53
+ return
54
+
55
+ if not client_id or not client_secret:
56
+ credentials = self.credential_manager.get_client_credentials()
57
+ if not credentials:
58
+ raise ValueError("Failed to fetch client credentials")
59
+ client_id = credentials.client_id
60
+ client_secret = credentials.client_secret
61
+
62
+ self.api_client.rest_client.setup_oauth_client_credentials(
63
+ client_id=client_id,
64
+ client_secret=client_secret,
65
+ token_endpoint=f"{auth_endpoint}/connect/token",
66
+ scope=oauth_scope,
44
67
  )
45
68
 
46
- self._cert_path = cert_path
47
-
48
- token_manager.start_token_refresh(token_callback=self._set_token)
49
-
50
69
  def reset_credentials(self):
51
- CredentialManager(endpoint=self._token_url, cert_path=self._cert_path).reset_credentials()
70
+ self.credential_manager.reset_credentials()
52
71
 
53
72
  @property
54
73
  def order_api(self) -> OrderApi:
@@ -78,9 +97,6 @@ class OpenAPIService:
78
97
  def workflow_api(self) -> WorkflowApi:
79
98
  return WorkflowApi(self.api_client)
80
99
 
81
- def _set_token(self, token: TokenInfo):
82
- self.api_client.configuration.api_key["bearer"] = f"Bearer {token.access_token}"
83
-
84
100
  def _get_rapidata_package_version(self):
85
101
  """
86
102
  Returns the version of the currently installed rapidata package.
@@ -93,3 +109,15 @@ class OpenAPIService:
93
109
  return version("rapidata")
94
110
  except PackageNotFoundError:
95
111
  return None
112
+
113
+
114
+ def _get_local_certificate() -> str | None:
115
+ result = subprocess.run(["mkcert", "-CAROOT"], capture_output=True)
116
+ if result.returncode != 0:
117
+ return None
118
+
119
+ output = result.stdout.decode("utf-8").strip()
120
+ if not output:
121
+ return None
122
+
123
+ return f"{output}/rootCA.pem"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: rapidata
3
- Version: 2.17.0
3
+ Version: 2.18.0
4
4
  Summary: Rapidata package containing the Rapidata Python Client to interact with the Rapidata Web API in an easy way.
5
5
  License: Apache-2.0
6
6
  Author: Rapidata AG
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
13
  Classifier: Programming Language :: Python :: 3.12
14
14
  Classifier: Programming Language :: Python :: 3.13
15
+ Requires-Dist: authlib (>=1.5.1,<2.0.0)
15
16
  Requires-Dist: colorama (==0.4.6)
16
17
  Requires-Dist: deprecated (>=1.2.14,<2.0.0)
17
18
  Requires-Dist: pandas (>=2.2.3,<3.0.0)
@@ -403,7 +403,7 @@ rapidata/api_client/models/workflow_labeling_step_model.py,sha256=iXeIb78bdMhGFj
403
403
  rapidata/api_client/models/workflow_split_model.py,sha256=zthOSaUl8dbLhLymLK_lrPTBpeV1a4cODLxnHmNCAZw,4474
404
404
  rapidata/api_client/models/workflow_split_model_filter_configs_inner.py,sha256=1Fx9uZtztiiAdMXkj7YeCqt7o6VkG9lKf7D7UP_h088,7447
405
405
  rapidata/api_client/models/workflow_state.py,sha256=5LAK1se76RCoozeVB6oxMPb8p_5bhLZJqn7q5fFQWis,850
406
- rapidata/api_client/rest.py,sha256=zmCIFQC2l1t-KZcq-TgEm3vco3y_LK6vRm3Q07K-xRI,9423
406
+ rapidata/api_client/rest.py,sha256=Nnn1XE9sVUprPm_6AsUmetb_bd9dMjynDOob6y8NJNE,8775
407
407
  rapidata/api_client_README.md,sha256=97mR2UeWNIhqNxgUAz87zH8kM9RsbrColX_FvIy_rYo,53784
408
408
  rapidata/rapidata_client/__init__.py,sha256=yU0cRoX-RmOHQv0Qj3yJpHaDET4DHZWSO6w2cAApQhQ,920
409
409
  rapidata/rapidata_client/assets/__init__.py,sha256=hKgrOSn8gJcBSULaf4auYhH1S1N5AfcwIhBSq1BOKwQ,323
@@ -438,23 +438,24 @@ rapidata/rapidata_client/metadata/_select_words_metadata.py,sha256=-MK5yQDi_G3BK
438
438
  rapidata/rapidata_client/order/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
439
439
  rapidata/rapidata_client/order/_rapidata_dataset.py,sha256=KyEDb4NvlY4uzN3GjiX331VSGOq_I6FhvTnyKrpo4Sg,19400
440
440
  rapidata/rapidata_client/order/_rapidata_order_builder.py,sha256=De-gNvOnuXHz6QaRdNpr0_3zOkIEKu8hgtZkatzxZQ4,13155
441
- rapidata/rapidata_client/order/rapidata_order.py,sha256=1uhU19uo2xGypd2aJClLRFLWqon-JQN8RTAkTGIg1iQ,10481
441
+ rapidata/rapidata_client/order/rapidata_order.py,sha256=cz1bXkoS7y_MuJDmjdXao7nbCFd8UF4bb5SHZtOYDjI,10483
442
442
  rapidata/rapidata_client/order/rapidata_order_manager.py,sha256=HQzZAnwKnwWl2wLcZPqxQRnth1xWhVoGNkMkzzYbjqA,30615
443
443
  rapidata/rapidata_client/order/rapidata_results.py,sha256=0y8EOiqUV7XuwpRJyV53mfo-ddRV1cUPdWexbPEVOHM,8044
444
- rapidata/rapidata_client/rapidata_client.py,sha256=oZrO8KpXQxK4FXDS20FkgdVbE_rBeTztyNYDKbf4ihM,2137
444
+ rapidata/rapidata_client/rapidata_client.py,sha256=fcWlmmNCZahK40Ox4aY153tEihIpwkUxYTuiypKF2SY,2857
445
445
  rapidata/rapidata_client/referee/__init__.py,sha256=q0Hv9nmfEpyChejtyMLT8hWKL0vTTf_UgUXPYNJ-H6M,153
446
446
  rapidata/rapidata_client/referee/_base_referee.py,sha256=MdFOhdxt3sRnWXLDKLJZKFdVpjBGn9jypPnWWQ6msQA,496
447
447
  rapidata/rapidata_client/referee/_early_stopping_referee.py,sha256=ULbokQZ91wc9D_20qHUhe55D28D9eTY1J1cMp_-oIDc,2088
448
448
  rapidata/rapidata_client/referee/_naive_referee.py,sha256=PVR8uy8hfRjr2DBzdOFyvou6S3swNc-4UvgjhO-09TU,1209
449
- rapidata/rapidata_client/selection/__init__.py,sha256=LbafUzvKgKbykbvHZJ7S9aYU82HQl71Y7jAbj_HTZ8c,382
449
+ rapidata/rapidata_client/selection/__init__.py,sha256=OimK44ig39A3kHCR_JGNO4FiUYJ6JUY0ZT0J8dz32Rs,475
450
450
  rapidata/rapidata_client/selection/_base_selection.py,sha256=tInbWOgxT_4CHkr5QHoG55ZcUi1ZmfcEGIwLKKCnN20,147
451
- rapidata/rapidata_client/selection/ab_test_selection.py,sha256=yNVRpYbRBBde_vg_Pf_xSKMemmJSrbHiL7XXKZW6kWw,1301
451
+ rapidata/rapidata_client/selection/ab_test_selection.py,sha256=fymubkVMawqJmYp9FKzWXTki9tgBgoj3cOP8rG9oOd0,1284
452
452
  rapidata/rapidata_client/selection/capped_selection.py,sha256=iWhbM1LcayhgFm7oKADXCaKHGdiQIupI0jbYuuEVM2A,1184
453
453
  rapidata/rapidata_client/selection/conditional_validation_selection.py,sha256=4etkO5p-wBoI8Wh8vBhNrXm7a_ioFvVmCANJmP8kIwI,2561
454
454
  rapidata/rapidata_client/selection/demographic_selection.py,sha256=l4vnNbzlf9ED6BKqN4k5cZXShkXu9L1C5DtO78Vwr5M,1454
455
455
  rapidata/rapidata_client/selection/labeling_selection.py,sha256=v26QogjmraFfRoSIgWZl6NMIW_TqbGeuCI2p4HxCeOM,657
456
456
  rapidata/rapidata_client/selection/rapidata_selections.py,sha256=Azh0ntBZp9EQNL19imIItotQ8QW3B1gEs5YmuTvUn6U,1526
457
457
  rapidata/rapidata_client/selection/shuffling_selection.py,sha256=FzOp7mnBLxNzM5at_-935wd77IHyWnFR1f8uqokiMOg,1201
458
+ rapidata/rapidata_client/selection/static_selection.py,sha256=POhVLjzHcUIuU_GCvRxMuCb27m7CkLxaQPwgf20Xo9o,681
458
459
  rapidata/rapidata_client/selection/validation_selection.py,sha256=sedeIa8lpXVXKtFJA9IDeRvo9A1Ne4ZGcepaWDUGhCU,851
459
460
  rapidata/rapidata_client/settings/__init__.py,sha256=DTEAT_YykwodZJXqKYOtWRwimLCA-Jxo0F0d-H6A3vM,458
460
461
  rapidata/rapidata_client/settings/_rapidata_setting.py,sha256=MD5JhhogSLLrjFKjvL3JhMszOMCygyqLF-st0EwMSkw,352
@@ -468,12 +469,12 @@ rapidata/rapidata_client/settings/play_video_until_the_end.py,sha256=LLHx2_72k5Z
468
469
  rapidata/rapidata_client/settings/rapidata_settings.py,sha256=r6eDGo5YHMekOtWqPHD50uI8vEE9VoBJfaWEDFZ78RU,1430
469
470
  rapidata/rapidata_client/settings/translation_behaviour.py,sha256=i9n_H0eKJyKW6m3MKH_Cm1XEKWVEWsAV_79xGmGIC-4,742
470
471
  rapidata/rapidata_client/validation/__init__.py,sha256=s5wHVtcJkncXSFuL9I0zNwccNOKpWAqxqUjkeohzi2E,24
471
- rapidata/rapidata_client/validation/rapidata_validation_set.py,sha256=Px_tpFOc5rSvHaDRtN3prNTuXUqz7Y_ZjmbEFLzDciY,2680
472
+ rapidata/rapidata_client/validation/rapidata_validation_set.py,sha256=GaatGGuJCHRvdPbjzI0NRQhxzRftKzf-FkaFv25iB78,2649
472
473
  rapidata/rapidata_client/validation/rapids/__init__.py,sha256=WU5PPwtTJlte6U90MDakzx4I8Y0laj7siw9teeXj5R0,21
473
474
  rapidata/rapidata_client/validation/rapids/box.py,sha256=t3_Kn6doKXdnJdtbwefXnYKPiTKHneJl9E2inkDSqL8,589
474
475
  rapidata/rapidata_client/validation/rapids/rapids.py,sha256=uCKnoSn1RykNHgTFbrvCFlfzU8lF42cff-2I-Pd48w0,4620
475
476
  rapidata/rapidata_client/validation/rapids/rapids_manager.py,sha256=F00lPYBUx5fTPRw50iZuobtdbjFo6ZHevPMk101JdaY,14271
476
- rapidata/rapidata_client/validation/validation_set_manager.py,sha256=OeiJ_WX8qSGdQAOVejnSYxZJUmttNsW71E0MLWUQHXY,26529
477
+ rapidata/rapidata_client/validation/validation_set_manager.py,sha256=Ecm17xf4SFsKdWAGNPdjecuAqsEJ27cHPdns7f8GXhs,26699
477
478
  rapidata/rapidata_client/workflow/__init__.py,sha256=7nXcY91xkxjHudBc9H0fP35eBBtgwHGWTQKbb-M4h7Y,477
478
479
  rapidata/rapidata_client/workflow/_base_workflow.py,sha256=XyIZFKS_RxAuwIHS848S3AyLEHqd07oTD_5jm2oUbsw,762
479
480
  rapidata/rapidata_client/workflow/_classify_workflow.py,sha256=9bT54wxVJgxC-zLk6MVNbseFpzYrvFPjt7DHvxqYfnk,1736
@@ -486,11 +487,10 @@ rapidata/rapidata_client/workflow/_ranking_workflow.py,sha256=XBIifokhu9Kaqo0qSw
486
487
  rapidata/rapidata_client/workflow/_select_words_workflow.py,sha256=juTW4TPnnSeBWP3K2QfO092t0u1W8I1ksY6aAtPZOi0,1225
487
488
  rapidata/rapidata_client/workflow/_timestamp_workflow.py,sha256=tPi2zu1-SlE_ppbGbMz6MM_2LUSWxM-GA0CZRlB0qFo,1176
488
489
  rapidata/service/__init__.py,sha256=s9bS1AJZaWIhLtJX_ZA40_CK39rAAkwdAmymTMbeWl4,68
489
- rapidata/service/credential_manager.py,sha256=_DIP665fpl4fkqj1l-wjRrBp-8fy2Db7tnXDDx4kWyc,8605
490
+ rapidata/service/credential_manager.py,sha256=3x-Fb6tyqmgtpjI1MSOtXWW_SkzTK8Lo7I0SSL2YD7E,8602
490
491
  rapidata/service/local_file_service.py,sha256=pgorvlWcx52Uh3cEG6VrdMK_t__7dacQ_5AnfY14BW8,877
491
- rapidata/service/openapi_service.py,sha256=fUQGLQzezjJbLqHVq7o6pQZUrK2Y12sfV23OstV4lOk,3234
492
- rapidata/service/token_manager.py,sha256=C-8dN6P5TXCLANZCHWusmwAful5YBpKjKg0StQtajF0,6547
493
- rapidata-2.17.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
494
- rapidata-2.17.0.dist-info/METADATA,sha256=5yD0-X5uZr9D4YPNk_KKwhXPOVAkcUpGubS61oz3Yfg,1147
495
- rapidata-2.17.0.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
496
- rapidata-2.17.0.dist-info/RECORD,,
492
+ rapidata/service/openapi_service.py,sha256=ORFPfHlb41zOUP5nDjYWZwO-ZcqNF_Mw2r71RitFtS0,4042
493
+ rapidata-2.18.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
494
+ rapidata-2.18.0.dist-info/METADATA,sha256=RmwN7x5_3aS68UGwC-BTB3FZg8S879NhC32ymMZ8S0A,1187
495
+ rapidata-2.18.0.dist-info/WHEEL,sha256=XbeZDeTWKc1w7CSIyre5aMDU_-PohRwTQceYnisIYYY,88
496
+ rapidata-2.18.0.dist-info/RECORD,,
@@ -1,176 +0,0 @@
1
- import json
2
- import logging
3
- import threading
4
- from datetime import datetime, timedelta
5
- from typing import Optional, Callable
6
-
7
- import requests
8
- from pydantic import BaseModel
9
-
10
- from rapidata.service.credential_manager import CredentialManager
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
- class TokenInfo(BaseModel):
15
- access_token: str
16
- expires_in: int
17
- issued_at: datetime
18
- token_type: str = "Bearer"
19
-
20
- @property
21
- def auth_header(self):
22
- return f"{self.token_type} {self.access_token}"
23
-
24
- @property
25
- def time_remaining(self):
26
- remaining = (
27
- (self.issued_at + timedelta(seconds=self.expires_in)) - datetime.now()
28
- ).total_seconds()
29
- return max(0.0, remaining)
30
-
31
-
32
- class TokenManager:
33
- def __init__(
34
- self,
35
- client_id: str | None = None,
36
- client_secret: str | None = None,
37
- endpoint: str = "https://auth.rapidata.ai",
38
- oauth_scope: str = "openid profile email",
39
- cert_path: str | None = None,
40
- refresh_threshold: float = 0.8,
41
- max_sleep_time: float = 30,
42
- max_token_lifetime: int = 60,
43
- ):
44
- self._client_id = client_id
45
- self._client_secret = client_secret
46
-
47
- if not client_id or not client_secret:
48
- credential_manager = CredentialManager(
49
- endpoint=endpoint, cert_path=cert_path
50
- )
51
- credentials = credential_manager.get_client_credentials()
52
- if not credentials:
53
- raise ValueError("Failed to fetch client credentials")
54
- self._client_id = credentials.client_id
55
- self._client_secret = credentials.client_secret
56
-
57
- self._endpoint = endpoint
58
- self._oauth_scope = oauth_scope
59
- self._cert_path = cert_path
60
- self._refresh_threshold = refresh_threshold
61
- self._max_sleep_time = max_sleep_time
62
- self._max_token_lifetime = max_token_lifetime
63
-
64
- self._token_lock = threading.Lock()
65
- self._current_token: Optional[TokenInfo] = None
66
- self._refresh_thread: Optional[threading.Thread] = None
67
- self._should_stop = threading.Event()
68
-
69
- def fetch_token(self):
70
- try:
71
- response = requests.post(
72
- f"{self._endpoint}/connect/token",
73
- data={
74
- "grant_type": "client_credentials",
75
- "client_id": self._client_id,
76
- "client_secret": self._client_secret,
77
- "scope": self._oauth_scope,
78
- },
79
- verify=self._cert_path,
80
- )
81
-
82
- if response.ok:
83
- data = response.json()
84
- return TokenInfo(
85
- access_token=data["access_token"],
86
- token_type=data["token_type"],
87
- expires_in=min(self._max_token_lifetime, data["expires_in"]),
88
- issued_at=datetime.now(),
89
- )
90
-
91
- else:
92
- data = response.text
93
- error_description = "An unknown error occurred"
94
- if "error_description" in data:
95
- error_description = (
96
- data.split("error_description")[1].split("\n")[0].strip()
97
- )
98
- raise ValueError(f"Failed to fetch token: {error_description}")
99
- except requests.RequestException as e:
100
- raise ValueError(f"Failed to fetch token: {e}")
101
- except json.JSONDecodeError as e:
102
- raise ValueError(f"Failed to parse token response: {e}")
103
- except KeyError as e:
104
- raise ValueError(f"Failed to extract token from response: {e}")
105
-
106
- def start_token_refresh(self, token_callback: Callable[[TokenInfo], None]) -> None:
107
- if self._refresh_thread and self._refresh_thread.is_alive():
108
- logger.error("Token refresh thread is already running")
109
- return
110
-
111
- def refresh_loop():
112
- while not self._should_stop.is_set():
113
- try:
114
- with self._token_lock:
115
- if self._should_refresh_token(self._current_token):
116
- logger.debug("Refreshing token")
117
- self._current_token = self.fetch_token()
118
- token_callback(self._current_token)
119
-
120
- if self._current_token:
121
- time_until_refresh_threshold = (
122
- self._current_token.time_remaining
123
- - (
124
- self._current_token.expires_in
125
- * (1 - self._refresh_threshold)
126
- )
127
- )
128
- logger.debug("Time until refresh threshold: %s", time_until_refresh_threshold)
129
- sleep_time = min(
130
- self._max_sleep_time, time_until_refresh_threshold
131
- )
132
- logger.debug(
133
- f"Sleeping for {sleep_time} until checking the token again"
134
- )
135
- self._should_stop.wait(timeout=max(1.0, sleep_time))
136
- else:
137
- self._should_stop.wait(timeout=self._max_sleep_time)
138
- except Exception as e:
139
- logger.error("Failed to refresh token: %s", e)
140
- self._should_stop.wait(timeout=5)
141
-
142
- self._should_stop.clear()
143
- self._refresh_thread = threading.Thread(target=refresh_loop, daemon=True)
144
- self._refresh_thread.start()
145
-
146
- def stop_token_refresh(self):
147
- self._should_stop.set()
148
- if self._refresh_thread:
149
- self._refresh_thread.join(timeout=1)
150
- self._refresh_thread = None
151
-
152
- def get_current_token(self) -> Optional[TokenInfo]:
153
- with self._token_lock:
154
- return self._current_token
155
-
156
- def _should_refresh_token(self, token: TokenInfo | None) -> bool:
157
- if not token:
158
- return True
159
-
160
- limit = token.expires_in * (1 - self._refresh_threshold)
161
-
162
- logger.debug(
163
- "The token was issued at %s, it expires in %s. It has %s seconds remaining and we refresh the token when it has %s seconds remaining",
164
- token.issued_at,
165
- token.expires_in,
166
- token.time_remaining,
167
- limit,
168
- )
169
- return token.time_remaining < limit
170
-
171
- def __enter__(self):
172
- return self
173
-
174
- def __exit__(self, exc_type, exc_val, exc_tb):
175
- self.stop_token_refresh()
176
- return False