label-studio-sdk 1.0.10__py3-none-any.whl → 1.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of label-studio-sdk might be problematic. Click here for more details.
- label_studio_sdk/__init__.py +17 -1
- label_studio_sdk/_extensions/label_studio_tools/core/utils/json_schema.py +5 -0
- label_studio_sdk/base_client.py +8 -0
- label_studio_sdk/core/client_wrapper.py +34 -15
- label_studio_sdk/errors/__init__.py +3 -1
- label_studio_sdk/errors/not_found_error.py +9 -0
- label_studio_sdk/errors/unauthorized_error.py +9 -0
- label_studio_sdk/jwt_settings/__init__.py +2 -0
- label_studio_sdk/jwt_settings/client.py +259 -0
- label_studio_sdk/label_interface/control_tags.py +15 -2
- label_studio_sdk/label_interface/interface.py +80 -1
- label_studio_sdk/label_interface/object_tags.py +2 -2
- label_studio_sdk/projects/__init__.py +2 -1
- label_studio_sdk/projects/client.py +4 -0
- label_studio_sdk/projects/exports/client_ext.py +106 -40
- label_studio_sdk/projects/pauses/__init__.py +2 -0
- label_studio_sdk/projects/pauses/client.py +704 -0
- label_studio_sdk/projects/types/projects_update_response.py +10 -0
- label_studio_sdk/tokens/__init__.py +2 -0
- label_studio_sdk/tokens/client.py +470 -0
- label_studio_sdk/tokens/client_ext.py +94 -0
- label_studio_sdk/types/__init__.py +10 -0
- label_studio_sdk/types/access_token_response.py +22 -0
- label_studio_sdk/types/api_token_response.py +32 -0
- label_studio_sdk/types/jwt_settings_response.py +32 -0
- label_studio_sdk/types/model_provider_connection_provider.py +1 -1
- label_studio_sdk/types/pause.py +34 -0
- label_studio_sdk/types/pause_paused_by.py +5 -0
- label_studio_sdk/types/project.py +10 -0
- label_studio_sdk/types/prompt_version_provider.py +1 -1
- {label_studio_sdk-1.0.10.dist-info → label_studio_sdk-1.0.11.dist-info}/METADATA +2 -1
- {label_studio_sdk-1.0.10.dist-info → label_studio_sdk-1.0.11.dist-info}/RECORD +34 -20
- {label_studio_sdk-1.0.10.dist-info → label_studio_sdk-1.0.11.dist-info}/WHEEL +1 -1
- {label_studio_sdk-1.0.10.dist-info → label_studio_sdk-1.0.11.dist-info}/LICENSE +0 -0
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
|
2
|
+
|
|
3
|
+
import typing
|
|
4
|
+
from ..core.client_wrapper import SyncClientWrapper
|
|
5
|
+
from ..core.request_options import RequestOptions
|
|
6
|
+
from ..errors.not_found_error import NotFoundError
|
|
7
|
+
from ..core.pydantic_utilities import parse_obj_as
|
|
8
|
+
from json.decoder import JSONDecodeError
|
|
9
|
+
from ..core.api_error import ApiError
|
|
10
|
+
from ..types.api_token_response import ApiTokenResponse
|
|
11
|
+
from ..types.access_token_response import AccessTokenResponse
|
|
12
|
+
from ..errors.unauthorized_error import UnauthorizedError
|
|
13
|
+
from ..core.client_wrapper import AsyncClientWrapper
|
|
14
|
+
|
|
15
|
+
# Sentinel passed as ``omit=OMIT`` to the HTTP client for optional request
# parameters. It is the Ellipsis object, cast to Any so it type-checks
# wherever an optional parameter's default is expected.
OMIT = typing.cast(typing.Any, Ellipsis)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TokensClient:
    # Synchronous client for the ``api/token`` endpoints: listing and creating
    # API tokens, exchanging a refresh token for an access token, and
    # blacklisting a refresh token. Every request is sent through
    # ``client_wrapper.httpx_client``.

    def __init__(self, *, client_wrapper: SyncClientWrapper):
        self._client_wrapper = client_wrapper

    @staticmethod
    def _json_or_api_error(http_response: typing.Any) -> typing.Any:
        # Decode the response body as JSON. A body that is not valid JSON is
        # reported as an ApiError carrying the raw response text.
        try:
            return http_response.json()
        except JSONDecodeError:
            raise ApiError(status_code=http_response.status_code, body=http_response.text)

    @classmethod
    def _parse_success(cls, http_response: typing.Any, result_type: typing.Any) -> typing.Any:
        # Validate a 2xx JSON body against ``result_type``.
        return parse_obj_as(type_=result_type, object_=cls._json_or_api_error(http_response))  # type: ignore

    @classmethod
    def _raise_for_error(
        cls,
        http_response: typing.Any,
        typed_errors: typing.Optional[typing.Dict[int, typing.Type[Exception]]] = None,
    ) -> typing.NoReturn:
        # Raise the exception for a non-2xx response. If ``typed_errors`` maps
        # the status code to an exception class, that class is raised with the
        # parsed body. Otherwise a plain ApiError carrying the decoded JSON
        # body is raised.
        payload = cls._json_or_api_error(http_response)
        error_cls = (typed_errors or {}).get(http_response.status_code)
        if error_cls is not None:
            raise error_cls(
                typing.cast(
                    typing.Optional[typing.Any],
                    parse_obj_as(type_=typing.Optional[typing.Any], object_=payload),  # type: ignore
                )
            )
        raise ApiError(status_code=http_response.status_code, body=payload)

    def blacklist(self, *, refresh: str, request_options: typing.Optional[RequestOptions] = None) -> None:
        """
        Blacklist a refresh token so that it can no longer be used.

        Parameters
        ----------
        refresh : str
            JWT refresh token

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        None

        Raises
        ------
        NotFoundError
            If the server answers HTTP 404 with a JSON body.
        ApiError
            For any other non-2xx status, or when an error body is not valid JSON.

        Examples
        --------
        from label_studio_sdk import LabelStudio

        client = LabelStudio(
            api_key="YOUR_API_KEY",
        )
        client.tokens.blacklist(
            refresh="refresh",
        )
        """
        http_response = self._client_wrapper.httpx_client.request(
            "api/token/blacklist",
            method="POST",
            json={"refresh": refresh},
            headers={"content-type": "application/json"},
            request_options=request_options,
            omit=OMIT,
        )
        if not 200 <= http_response.status_code < 300:
            self._raise_for_error(http_response, {404: NotFoundError})

    def get(self, *, request_options: typing.Optional[RequestOptions] = None) -> typing.List[ApiTokenResponse]:
        """
        Fetch every API token that belongs to the current user.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        typing.List[ApiTokenResponse]
            List of API tokens retrieved successfully

        Raises
        ------
        ApiError
            For any non-2xx status, or when the body is not valid JSON.

        Examples
        --------
        from label_studio_sdk import LabelStudio

        client = LabelStudio(
            api_key="YOUR_API_KEY",
        )
        client.tokens.get()
        """
        http_response = self._client_wrapper.httpx_client.request(
            "api/token",
            method="GET",
            request_options=request_options,
        )
        if 200 <= http_response.status_code < 300:
            return self._parse_success(http_response, typing.List[ApiTokenResponse])
        self._raise_for_error(http_response)

    def create(self, *, request_options: typing.Optional[RequestOptions] = None) -> ApiTokenResponse:
        """
        Issue a new API token for the current user.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ApiTokenResponse
            Token created successfully

        Raises
        ------
        ApiError
            For any non-2xx status, or when the body is not valid JSON.

        Examples
        --------
        from label_studio_sdk import LabelStudio

        client = LabelStudio(
            api_key="YOUR_API_KEY",
        )
        client.tokens.create()
        """
        http_response = self._client_wrapper.httpx_client.request(
            "api/token",
            method="POST",
            request_options=request_options,
        )
        if 200 <= http_response.status_code < 300:
            return self._parse_success(http_response, ApiTokenResponse)
        self._raise_for_error(http_response)

    def refresh(self, *, refresh: str, request_options: typing.Optional[RequestOptions] = None) -> AccessTokenResponse:
        """
        Exchange a refresh token for a fresh access token.

        Parameters
        ----------
        refresh : str
            JWT refresh token

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AccessTokenResponse
            New access token created successfully

        Raises
        ------
        UnauthorizedError
            If the server answers HTTP 401 with a JSON body.
        ApiError
            For any other non-2xx status, or when the body is not valid JSON.

        Examples
        --------
        from label_studio_sdk import LabelStudio

        client = LabelStudio(
            api_key="YOUR_API_KEY",
        )
        client.tokens.refresh(
            refresh="refresh",
        )
        """
        # NOTE(review): TokensClientExt.refresh posts to "/api/token/refresh/"
        # (trailing slash) while this path has none; confirm which form the
        # server routes.
        http_response = self._client_wrapper.httpx_client.request(
            "api/token/refresh",
            method="POST",
            json={"refresh": refresh},
            headers={"content-type": "application/json"},
            request_options=request_options,
            omit=OMIT,
        )
        if 200 <= http_response.status_code < 300:
            return self._parse_success(http_response, AccessTokenResponse)
        self._raise_for_error(http_response, {401: UnauthorizedError})
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class AsyncTokensClient:
    # Asynchronous client for the ``api/token`` endpoints: listing and creating
    # API tokens, exchanging a refresh token for an access token, and
    # blacklisting a refresh token. Every request is awaited on
    # ``client_wrapper.httpx_client``.

    def __init__(self, *, client_wrapper: AsyncClientWrapper):
        self._client_wrapper = client_wrapper

    @staticmethod
    def _json_or_api_error(http_response: typing.Any) -> typing.Any:
        # Decode the (already received) response body as JSON. A body that is
        # not valid JSON is reported as an ApiError carrying the raw text.
        try:
            return http_response.json()
        except JSONDecodeError:
            raise ApiError(status_code=http_response.status_code, body=http_response.text)

    @classmethod
    def _parse_success(cls, http_response: typing.Any, result_type: typing.Any) -> typing.Any:
        # Validate a 2xx JSON body against ``result_type``.
        return parse_obj_as(type_=result_type, object_=cls._json_or_api_error(http_response))  # type: ignore

    @classmethod
    def _raise_for_error(
        cls,
        http_response: typing.Any,
        typed_errors: typing.Optional[typing.Dict[int, typing.Type[Exception]]] = None,
    ) -> typing.NoReturn:
        # Raise the exception for a non-2xx response. If ``typed_errors`` maps
        # the status code to an exception class, that class is raised with the
        # parsed body. Otherwise a plain ApiError carrying the decoded JSON
        # body is raised.
        payload = cls._json_or_api_error(http_response)
        error_cls = (typed_errors or {}).get(http_response.status_code)
        if error_cls is not None:
            raise error_cls(
                typing.cast(
                    typing.Optional[typing.Any],
                    parse_obj_as(type_=typing.Optional[typing.Any], object_=payload),  # type: ignore
                )
            )
        raise ApiError(status_code=http_response.status_code, body=payload)

    async def blacklist(self, *, refresh: str, request_options: typing.Optional[RequestOptions] = None) -> None:
        """
        Blacklist a refresh token so that it can no longer be used.

        Parameters
        ----------
        refresh : str
            JWT refresh token

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        None

        Raises
        ------
        NotFoundError
            If the server answers HTTP 404 with a JSON body.
        ApiError
            For any other non-2xx status, or when an error body is not valid JSON.

        Examples
        --------
        import asyncio

        from label_studio_sdk import AsyncLabelStudio

        client = AsyncLabelStudio(
            api_key="YOUR_API_KEY",
        )


        async def main() -> None:
            await client.tokens.blacklist(
                refresh="refresh",
            )


        asyncio.run(main())
        """
        http_response = await self._client_wrapper.httpx_client.request(
            "api/token/blacklist",
            method="POST",
            json={"refresh": refresh},
            headers={"content-type": "application/json"},
            request_options=request_options,
            omit=OMIT,
        )
        if not 200 <= http_response.status_code < 300:
            self._raise_for_error(http_response, {404: NotFoundError})

    async def get(self, *, request_options: typing.Optional[RequestOptions] = None) -> typing.List[ApiTokenResponse]:
        """
        Fetch every API token that belongs to the current user.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        typing.List[ApiTokenResponse]
            List of API tokens retrieved successfully

        Raises
        ------
        ApiError
            For any non-2xx status, or when the body is not valid JSON.

        Examples
        --------
        import asyncio

        from label_studio_sdk import AsyncLabelStudio

        client = AsyncLabelStudio(
            api_key="YOUR_API_KEY",
        )


        async def main() -> None:
            await client.tokens.get()


        asyncio.run(main())
        """
        http_response = await self._client_wrapper.httpx_client.request(
            "api/token",
            method="GET",
            request_options=request_options,
        )
        if 200 <= http_response.status_code < 300:
            return self._parse_success(http_response, typing.List[ApiTokenResponse])
        self._raise_for_error(http_response)

    async def create(self, *, request_options: typing.Optional[RequestOptions] = None) -> ApiTokenResponse:
        """
        Issue a new API token for the current user.

        Parameters
        ----------
        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        ApiTokenResponse
            Token created successfully

        Raises
        ------
        ApiError
            For any non-2xx status, or when the body is not valid JSON.

        Examples
        --------
        import asyncio

        from label_studio_sdk import AsyncLabelStudio

        client = AsyncLabelStudio(
            api_key="YOUR_API_KEY",
        )


        async def main() -> None:
            await client.tokens.create()


        asyncio.run(main())
        """
        http_response = await self._client_wrapper.httpx_client.request(
            "api/token",
            method="POST",
            request_options=request_options,
        )
        if 200 <= http_response.status_code < 300:
            return self._parse_success(http_response, ApiTokenResponse)
        self._raise_for_error(http_response)

    async def refresh(
        self, *, refresh: str, request_options: typing.Optional[RequestOptions] = None
    ) -> AccessTokenResponse:
        """
        Exchange a refresh token for a fresh access token.

        Parameters
        ----------
        refresh : str
            JWT refresh token

        request_options : typing.Optional[RequestOptions]
            Request-specific configuration.

        Returns
        -------
        AccessTokenResponse
            New access token created successfully

        Raises
        ------
        UnauthorizedError
            If the server answers HTTP 401 with a JSON body.
        ApiError
            For any other non-2xx status, or when the body is not valid JSON.

        Examples
        --------
        import asyncio

        from label_studio_sdk import AsyncLabelStudio

        client = AsyncLabelStudio(
            api_key="YOUR_API_KEY",
        )


        async def main() -> None:
            await client.tokens.refresh(
                refresh="refresh",
            )


        asyncio.run(main())
        """
        http_response = await self._client_wrapper.httpx_client.request(
            "api/token/refresh",
            method="POST",
            json={"refresh": refresh},
            headers={"content-type": "application/json"},
            request_options=request_options,
            omit=OMIT,
        )
        if 200 <= http_response.status_code < 300:
            return self._parse_success(http_response, AccessTokenResponse)
        self._raise_for_error(http_response, {401: UnauthorizedError})
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
import typing
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
import jwt
|
|
7
|
+
|
|
8
|
+
from ..core.api_error import ApiError
|
|
9
|
+
from ..types.access_token_response import AccessTokenResponse
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TokensClientExt:
    """Client for managing authentication tokens.

    Wraps a Label Studio API key and exposes, via :attr:`api_key`, the
    credential that should be sent with requests:

    * a key that does not decode as a JWT (a legacy token) is returned as-is;
    * a JWT key is used as a refresh token. It is exchanged at
      ``/api/token/refresh/`` for an access token, which is cached and
      re-fetched when it is missing or close to expiry.
    """

    # Refresh the cached access token when it has fewer than this many seconds
    # of lifetime left. Otherwise a token that passes the local check could
    # expire before the server validates it.
    _ACCESS_TOKEN_REFRESH_LEEWAY_SECONDS: float = 10.0

    def __init__(self, base_url: str, api_key: str):
        self._base_url = base_url
        self._api_key = api_key
        # Raises ApiError immediately when api_key is a JWT that is expired or
        # has no "exp" claim; a non-JWT key selects the legacy path.
        self._use_legacy_token = not self._is_valid_jwt_token(api_key, raise_if_expired=True)

        # cache state for access token when using jwt-based api_key
        self._access_token: typing.Optional[str] = None
        self._access_token_expiration: typing.Optional[datetime] = None
        # Used to keep simultaneous refresh requests from spamming refresh endpoint
        self._token_refresh_lock = threading.Lock()

    def _is_valid_jwt_token(
        self, token: str, raise_if_expired: bool = False, leeway_seconds: float = 0.0
    ) -> bool:
        """Return True if ``token`` is a decodable JWT that is not (nearly) expired.

        The signature is not verified; only the ``exp`` claim is inspected.

        Args:
            token: Token to inspect.
            raise_if_expired: Raise instead of returning False when expired.
            leeway_seconds: Treat the token as expired when fewer than this
                many seconds remain before ``exp``. With the default of 0, only
                tokens whose ``exp`` is already in the past count as expired.

        Returns:
            False if the token does not decode as a JWT (presumably a legacy
            token), or if it is expired and ``raise_if_expired`` is False.
            True otherwise.

        Raises:
            ApiError: (401) if the token has no ``exp`` claim, or if it is
                expired and ``raise_if_expired`` is True.
        """
        try:
            decoded = jwt.decode(token, options={"verify_signature": False})
        except jwt.InvalidTokenError:
            # presumably a legacy token
            return False
        expiration = decoded.get("exp")
        if expiration is None:
            raise ApiError(
                status_code=401,
                body={"detail": "API key does not have an expiration set, and is not valid. Please obtain a new refresh token."}
            )
        expiration_time = datetime.fromtimestamp(expiration, timezone.utc)
        remaining_seconds = (expiration_time - datetime.now(timezone.utc)).total_seconds()
        if remaining_seconds < leeway_seconds:
            if raise_if_expired:
                raise ApiError(
                    status_code=401,
                    body={"detail": "API key has expired. Please obtain a new refresh token."}
                )
            return False
        return True

    def _set_access_token(self, token: str) -> None:
        """Cache ``token`` as the current access token together with its ``exp`` time.

        If the token cannot be decoded or has no ``exp`` claim, the cached
        expiration is reset to None. It is not left holding the previous
        token's value.
        """
        expiration_time: typing.Optional[datetime] = None
        try:
            decoded = jwt.decode(token, options={"verify_signature": False})
        except jwt.InvalidTokenError:
            decoded = {}
        expiration = decoded.get("exp")
        if expiration is not None:
            expiration_time = datetime.fromtimestamp(expiration, timezone.utc)
        self._access_token_expiration = expiration_time
        self._access_token = token

    def _access_token_needs_refresh(self) -> bool:
        """Return True when no access token is cached or it is expired or about to expire."""
        return (not self._access_token) or (
            not self._is_valid_jwt_token(
                self._access_token, leeway_seconds=self._ACCESS_TOKEN_REFRESH_LEEWAY_SECONDS
            )
        )

    @property
    def api_key(self) -> str:
        """Get the current access token, refreshing if necessary.

        Legacy keys are returned unchanged. For JWT keys, returns the cached
        access token, first refreshing it when none is cached or it has fewer
        than ``_ACCESS_TOKEN_REFRESH_LEEWAY_SECONDS`` left.
        """
        # Legacy tokens: just return the API key directly
        if self._use_legacy_token:
            return self._api_key

        # JWT tokens: handle refresh if needed
        if self._access_token_needs_refresh():
            with self._token_refresh_lock:
                # Check again after acquiring lock, in case another invocation already refreshed
                if self._access_token_needs_refresh():
                    token_response = self.refresh()
                    self._set_access_token(token_response.access)

        return self._access_token

    def refresh(self) -> AccessTokenResponse:
        """Exchange the refresh token (the JWT ``api_key``) for a new access token.

        Returns:
            The parsed AccessTokenResponse for any 2xx response.

        Raises:
            ApiError: for a non-2xx response (``body`` is the decoded JSON), or
                for any response whose body is not valid JSON (``body`` is the
                raw response text).
        """
        # We don't do this often, just use a separate httpx client for simplicity here
        # (avoids complicated state management and sync vs async handling)
        with httpx.Client() as sync_client:
            response = sync_client.request(
                method="POST",
                # rstrip avoids a "//" when base_url ends with a slash
                url=f"{self._base_url.rstrip('/')}/api/token/refresh/",
                json={"refresh": self._api_key},
                headers={"Content-Type": "application/json"},
            )

        try:
            body = response.json()
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            # Non-JSON body (e.g. an HTML error page from a proxy): surface the raw text.
            raise ApiError(status_code=response.status_code, body=response.text)

        if not 200 <= response.status_code < 300:
            raise ApiError(status_code=response.status_code, body=body)

        # Prefer pydantic v2's model_validate; parse_obj is the v1 API (deprecated in v2).
        validate = getattr(AccessTokenResponse, "model_validate", None) or AccessTokenResponse.parse_obj
        return validate(body)
|
|
@@ -1,10 +1,12 @@
|
|
|
1
1
|
# This file was auto-generated by Fern from our API Definition.
|
|
2
2
|
|
|
3
|
+
from .access_token_response import AccessTokenResponse
|
|
3
4
|
from .annotation import Annotation
|
|
4
5
|
from .annotation_filter_options import AnnotationFilterOptions
|
|
5
6
|
from .annotation_last_action import AnnotationLastAction
|
|
6
7
|
from .annotations_dm_field import AnnotationsDmField
|
|
7
8
|
from .annotations_dm_field_last_action import AnnotationsDmFieldLastAction
|
|
9
|
+
from .api_token_response import ApiTokenResponse
|
|
8
10
|
from .azure_blob_export_storage import AzureBlobExportStorage
|
|
9
11
|
from .azure_blob_export_storage_status import AzureBlobExportStorageStatus
|
|
10
12
|
from .azure_blob_import_storage import AzureBlobImportStorage
|
|
@@ -39,6 +41,7 @@ from .inference_run_created_by import InferenceRunCreatedBy
|
|
|
39
41
|
from .inference_run_organization import InferenceRunOrganization
|
|
40
42
|
from .inference_run_project_subset import InferenceRunProjectSubset
|
|
41
43
|
from .inference_run_status import InferenceRunStatus
|
|
44
|
+
from .jwt_settings_response import JwtSettingsResponse
|
|
42
45
|
from .key_indicator_value import KeyIndicatorValue
|
|
43
46
|
from .key_indicators import KeyIndicators
|
|
44
47
|
from .key_indicators_item import KeyIndicatorsItem
|
|
@@ -57,6 +60,8 @@ from .model_provider_connection_created_by import ModelProviderConnectionCreated
|
|
|
57
60
|
from .model_provider_connection_organization import ModelProviderConnectionOrganization
|
|
58
61
|
from .model_provider_connection_provider import ModelProviderConnectionProvider
|
|
59
62
|
from .model_provider_connection_scope import ModelProviderConnectionScope
|
|
63
|
+
from .pause import Pause
|
|
64
|
+
from .pause_paused_by import PausePausedBy
|
|
60
65
|
from .prediction import Prediction
|
|
61
66
|
from .project import Project
|
|
62
67
|
from .project_import import ProjectImport
|
|
@@ -101,11 +106,13 @@ from .webhook_serializer_for_update_actions_item import WebhookSerializerForUpda
|
|
|
101
106
|
from .workspace import Workspace
|
|
102
107
|
|
|
103
108
|
__all__ = [
|
|
109
|
+
"AccessTokenResponse",
|
|
104
110
|
"Annotation",
|
|
105
111
|
"AnnotationFilterOptions",
|
|
106
112
|
"AnnotationLastAction",
|
|
107
113
|
"AnnotationsDmField",
|
|
108
114
|
"AnnotationsDmFieldLastAction",
|
|
115
|
+
"ApiTokenResponse",
|
|
109
116
|
"AzureBlobExportStorage",
|
|
110
117
|
"AzureBlobExportStorageStatus",
|
|
111
118
|
"AzureBlobImportStorage",
|
|
@@ -140,6 +147,7 @@ __all__ = [
|
|
|
140
147
|
"InferenceRunOrganization",
|
|
141
148
|
"InferenceRunProjectSubset",
|
|
142
149
|
"InferenceRunStatus",
|
|
150
|
+
"JwtSettingsResponse",
|
|
143
151
|
"KeyIndicatorValue",
|
|
144
152
|
"KeyIndicators",
|
|
145
153
|
"KeyIndicatorsItem",
|
|
@@ -158,6 +166,8 @@ __all__ = [
|
|
|
158
166
|
"ModelProviderConnectionOrganization",
|
|
159
167
|
"ModelProviderConnectionProvider",
|
|
160
168
|
"ModelProviderConnectionScope",
|
|
169
|
+
"Pause",
|
|
170
|
+
"PausePausedBy",
|
|
161
171
|
"Prediction",
|
|
162
172
|
"Project",
|
|
163
173
|
"ProjectImport",
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
|
2
|
+
|
|
3
|
+
from ..core.pydantic_utilities import UniversalBaseModel
|
|
4
|
+
import pydantic
|
|
5
|
+
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
|
6
|
+
import typing
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AccessTokenResponse(UniversalBaseModel):
    # Body returned by the token-refresh endpoint. Instances are immutable
    # (frozen), and keys the server adds beyond ``access`` are kept
    # (extra="allow"). Comments are used here instead of a class docstring
    # because pydantic would copy a docstring into the JSON-schema description.

    access: str = pydantic.Field()
    """
    New JWT access token
    """

    if not IS_PYDANTIC_V2:

        class Config:
            # Pydantic v1 configuration
            extra = pydantic.Extra.allow
            smart_union = True
            frozen = True

    else:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# This file was auto-generated by Fern from our API Definition.
|
|
2
|
+
|
|
3
|
+
from ..core.pydantic_utilities import UniversalBaseModel
|
|
4
|
+
import pydantic
|
|
5
|
+
from ..core.pydantic_utilities import IS_PYDANTIC_V2
|
|
6
|
+
import typing
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ApiTokenResponse(UniversalBaseModel):
    # An API token record as returned by the token list/create endpoints.
    # Instances are immutable (frozen), and any additional server fields are
    # kept (extra="allow"). Timestamps are kept as the raw strings the server
    # sends. Comments are used instead of a class docstring because pydantic
    # would copy a docstring into the JSON-schema description.

    token: str = pydantic.Field()
    """
    JWT token
    """

    created_at: str = pydantic.Field()
    """
    Token creation timestamp
    """

    expires_at: str = pydantic.Field()
    """
    Token expiration timestamp
    """

    if not IS_PYDANTIC_V2:

        class Config:
            # Pydantic v1 configuration
            extra = pydantic.Extra.allow
            smart_union = True
            frozen = True

    else:
        model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
|