dataspace-sdk 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataspace_sdk/__init__.py +18 -0
- dataspace_sdk/__version__.py +3 -0
- dataspace_sdk/auth.py +470 -0
- dataspace_sdk/base.py +160 -0
- dataspace_sdk/client.py +206 -0
- dataspace_sdk/exceptions.py +36 -0
- dataspace_sdk/resources/__init__.py +8 -0
- dataspace_sdk/resources/aimodels.py +989 -0
- dataspace_sdk/resources/datasets.py +233 -0
- dataspace_sdk/resources/sectors.py +128 -0
- dataspace_sdk/resources/usecases.py +248 -0
- dataspace_sdk-0.4.2.dist-info/METADATA +551 -0
- dataspace_sdk-0.4.2.dist-info/RECORD +15 -0
- dataspace_sdk-0.4.2.dist-info/WHEEL +5 -0
- dataspace_sdk-0.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,989 @@
|
|
|
1
|
+
"""AI Model resource client for DataSpace SDK."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
from dataspace_sdk.base import BaseAPIClient
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class AIModelClient(BaseAPIClient):
|
|
9
|
+
"""Client for interacting with AI Model resources."""
|
|
10
|
+
|
|
11
|
+
def search(
|
|
12
|
+
self,
|
|
13
|
+
query: Optional[str] = None,
|
|
14
|
+
tags: Optional[List[str]] = None,
|
|
15
|
+
sectors: Optional[List[str]] = None,
|
|
16
|
+
geographies: Optional[List[str]] = None,
|
|
17
|
+
status: Optional[str] = None,
|
|
18
|
+
model_type: Optional[str] = None,
|
|
19
|
+
provider: Optional[str] = None,
|
|
20
|
+
sort: Optional[str] = None,
|
|
21
|
+
page: int = 1,
|
|
22
|
+
page_size: int = 10,
|
|
23
|
+
) -> Dict[str, Any]:
|
|
24
|
+
"""
|
|
25
|
+
Search for AI models using Elasticsearch.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
query: Search query string
|
|
29
|
+
tags: Filter by tags
|
|
30
|
+
sectors: Filter by sectors
|
|
31
|
+
geographies: Filter by geographies
|
|
32
|
+
status: Filter by status (ACTIVE, INACTIVE, etc.)
|
|
33
|
+
model_type: Filter by model type (LLM, VISION, etc.)
|
|
34
|
+
provider: Filter by provider (OPENAI, ANTHROPIC, etc.)
|
|
35
|
+
sort: Sort order (recent, alphabetical)
|
|
36
|
+
page: Page number (1-indexed)
|
|
37
|
+
page_size: Number of results per page
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
Dictionary containing search results and metadata
|
|
41
|
+
"""
|
|
42
|
+
params: Dict[str, Any] = {
|
|
43
|
+
"page": page,
|
|
44
|
+
"page_size": page_size,
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
if query:
|
|
48
|
+
params["q"] = query
|
|
49
|
+
if tags:
|
|
50
|
+
params["tags"] = ",".join(tags)
|
|
51
|
+
if sectors:
|
|
52
|
+
params["sectors"] = ",".join(sectors)
|
|
53
|
+
if geographies:
|
|
54
|
+
params["geographies"] = ",".join(geographies)
|
|
55
|
+
if status:
|
|
56
|
+
params["status"] = status
|
|
57
|
+
if model_type:
|
|
58
|
+
params["model_type"] = model_type
|
|
59
|
+
if provider:
|
|
60
|
+
params["provider"] = provider
|
|
61
|
+
if sort:
|
|
62
|
+
params["sort"] = sort
|
|
63
|
+
|
|
64
|
+
return super().get("/api/search/aimodel/", params=params)
|
|
65
|
+
|
|
66
|
+
def get_by_id(self, model_id: str) -> Dict[str, Any]:
|
|
67
|
+
"""
|
|
68
|
+
Get an AI model by ID.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
model_id: UUID of the AI model
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
Dictionary containing AI model information
|
|
75
|
+
"""
|
|
76
|
+
# Use parent class get method with full endpoint path
|
|
77
|
+
return super().get(f"/api/aimodels/{model_id}/")
|
|
78
|
+
|
|
79
|
+
    def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]:
        """
        Get an AI model by ID using GraphQL.

        Fetches the model plus its nested organization, tags, sectors,
        geographies, versions, and each version's provider configurations
        in a single request.

        Args:
            model_id: UUID of the AI model

        Returns:
            Dictionary containing AI model information (empty dict when the
            response carries no ``aiModel`` payload)

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors.
        """
        # Note: the "#" lines inside this string are GraphQL comments, not Python.
        query = """
        query GetAIModel($id: UUID!) {
            aiModel(id: $id) {
                id
                name
                displayName
                description
                modelType
                status
                isPublic
                createdAt
                updatedAt
                organization {
                    id
                    name
                }
                tags {
                    id
                    value
                }
                sectors {
                    id
                    name
                }
                geographies {
                    id
                    name
                }
                versions {
                    id
                    version
                    versionNotes
                    lifecycleStage
                    isLatest
                    supportsStreaming
                    maxTokens
                    supportedLanguages
                    inputSchema
                    outputSchema
                    status
                    createdAt
                    updatedAt
                    publishedAt
                    providers {
                        id
                        provider
                        providerModelId
                        isPrimary
                        isActive
                        # API Configuration
                        apiEndpointUrl
                        apiHttpMethod
                        apiTimeoutSeconds
                        apiAuthType
                        apiAuthHeaderName
                        apiKey
                        apiKeyPrefix
                        apiHeaders
                        apiRequestTemplate
                        apiResponsePath
                        # HuggingFace Configuration
                        hfUsePipeline
                        hfAuthToken
                        hfModelClass
                        hfAttnImplementation
                        hfTrustRemoteCode
                        hfTorchDtype
                        hfDeviceMap
                        framework
                        config
                    }
                }
            }
        }
        """

        response = self.post(
            "/api/graphql",
            json_data={
                "query": query,
                "variables": {"id": model_id},
            },
        )

        if "errors" in response:
            # Lazy import — presumably avoids a circular import between the
            # resource and exceptions modules; confirm.
            from dataspace_sdk.exceptions import DataSpaceAPIError

            raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")

        result: Dict[str, Any] = response.get("data", {}).get("aiModel", {})
        return result
|
|
180
|
+
|
|
181
|
+
    def list_all(
        self,
        status: Optional[str] = None,
        organization_id: Optional[str] = None,
        model_type: Optional[str] = None,
        limit: int = 10,
        offset: int = 0,
    ) -> Any:
        """
        List all AI models with pagination using GraphQL.

        Args:
            status: Filter by status
            organization_id: Filter by organization
            model_type: Filter by model type
            limit: Number of results to return
            offset: Number of results to skip

        Returns:
            List of AI model dictionaries (empty list when the response's
            ``data`` payload is missing or not a dict)

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors.
        """
        query = """
        query ListAIModels($filters: AIModelFilter, $pagination: OffsetPaginationInput) {
            aiModels(filters: $filters, pagination: $pagination) {
                id
                name
                displayName
                description
                modelType
                status
                isPublic
                createdAt
                updatedAt
                organization {
                    id
                    name
                }
                tags {
                    id
                    value
                }
                versions {
                    id
                    version
                    lifecycleStage
                    isLatest
                    status
                    providers {
                        id
                        provider
                        providerModelId
                        isPrimary
                    }
                }
            }
        }
        """

        # Build the GraphQL filter object only from supplied arguments.
        filters: Dict[str, Any] = {}
        if status:
            filters["status"] = status
        if organization_id:
            # Nested "exact" lookup — presumably the shape required by the
            # server's relation filter schema; confirm against the schema.
            filters["organization"] = {"id": {"exact": organization_id}}
        if model_type:
            filters["modelType"] = model_type

        variables: Dict[str, Any] = {
            "pagination": {"limit": limit, "offset": offset},
        }
        if filters:
            variables["filters"] = filters

        response = self.post(
            "/api/graphql",
            json_data={
                "query": query,
                "variables": variables,
            },
        )

        if "errors" in response:
            # Lazy import — presumably avoids a circular import; confirm.
            from dataspace_sdk.exceptions import DataSpaceAPIError

            raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")

        data = response.get("data", {})
        # Guard against a non-dict "data" payload before extracting the list.
        models_result: Any = data.get("aiModels", []) if isinstance(data, dict) else []
        return models_result
|
|
269
|
+
|
|
270
|
+
def get_organization_models(
|
|
271
|
+
self,
|
|
272
|
+
organization_id: str,
|
|
273
|
+
limit: int = 10,
|
|
274
|
+
offset: int = 0,
|
|
275
|
+
) -> Any:
|
|
276
|
+
"""
|
|
277
|
+
Get AI models for a specific organization.
|
|
278
|
+
|
|
279
|
+
Args:
|
|
280
|
+
organization_id: UUID of the organization
|
|
281
|
+
limit: Number of results to return
|
|
282
|
+
offset: Number of results to skip
|
|
283
|
+
|
|
284
|
+
Returns:
|
|
285
|
+
Dictionary containing organization's AI models
|
|
286
|
+
"""
|
|
287
|
+
return self.list_all(
|
|
288
|
+
organization_id=organization_id,
|
|
289
|
+
limit=limit,
|
|
290
|
+
offset=offset,
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
def create(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
294
|
+
"""
|
|
295
|
+
Create a new AI model.
|
|
296
|
+
|
|
297
|
+
Args:
|
|
298
|
+
data: Dictionary containing AI model data
|
|
299
|
+
|
|
300
|
+
Returns:
|
|
301
|
+
Dictionary containing created AI model information
|
|
302
|
+
"""
|
|
303
|
+
return self.post("/api/aimodels/", json_data=data)
|
|
304
|
+
|
|
305
|
+
def update(self, model_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
306
|
+
"""
|
|
307
|
+
Update an existing AI model.
|
|
308
|
+
|
|
309
|
+
Args:
|
|
310
|
+
model_id: UUID of the AI model
|
|
311
|
+
data: Dictionary containing updated AI model data
|
|
312
|
+
|
|
313
|
+
Returns:
|
|
314
|
+
Dictionary containing updated AI model information
|
|
315
|
+
"""
|
|
316
|
+
return self.patch(f"/api/aimodels/{model_id}/", json_data=data)
|
|
317
|
+
|
|
318
|
+
def delete_model(self, model_id: str) -> Dict[str, Any]:
|
|
319
|
+
"""
|
|
320
|
+
Delete an AI model.
|
|
321
|
+
|
|
322
|
+
Args:
|
|
323
|
+
model_id: UUID of the AI model
|
|
324
|
+
|
|
325
|
+
Returns:
|
|
326
|
+
Dictionary containing deletion response
|
|
327
|
+
"""
|
|
328
|
+
return self.delete(f"/api/aimodels/{model_id}/")
|
|
329
|
+
|
|
330
|
+
def call_model(
|
|
331
|
+
self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None
|
|
332
|
+
) -> Dict[str, Any]:
|
|
333
|
+
"""
|
|
334
|
+
Call an AI model with input text using the appropriate client (API or HuggingFace).
|
|
335
|
+
|
|
336
|
+
Args:
|
|
337
|
+
model_id: UUID of the AI model
|
|
338
|
+
input_text: Input text to process
|
|
339
|
+
parameters: Optional parameters for the model call (temperature, max_tokens, etc.)
|
|
340
|
+
|
|
341
|
+
Returns:
|
|
342
|
+
Dictionary containing model response:
|
|
343
|
+
{
|
|
344
|
+
"success": bool,
|
|
345
|
+
"output": str (if successful),
|
|
346
|
+
"error": str (if failed),
|
|
347
|
+
"latency_ms": float,
|
|
348
|
+
"provider": str,
|
|
349
|
+
...
|
|
350
|
+
}
|
|
351
|
+
"""
|
|
352
|
+
return self.post(
|
|
353
|
+
f"/api/aimodels/{model_id}/call/",
|
|
354
|
+
json_data={"input_text": input_text, "parameters": parameters or {}},
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
def call_model_async(
|
|
358
|
+
self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None
|
|
359
|
+
) -> Dict[str, Any]:
|
|
360
|
+
"""
|
|
361
|
+
Call an AI model asynchronously (returns task ID for long-running operations).
|
|
362
|
+
|
|
363
|
+
Args:
|
|
364
|
+
model_id: UUID of the AI model
|
|
365
|
+
input_text: Input text to process
|
|
366
|
+
parameters: Optional parameters for the model call
|
|
367
|
+
|
|
368
|
+
Returns:
|
|
369
|
+
Dictionary containing task information:
|
|
370
|
+
{
|
|
371
|
+
"task_id": str,
|
|
372
|
+
"status": str,
|
|
373
|
+
"model_id": str
|
|
374
|
+
}
|
|
375
|
+
"""
|
|
376
|
+
return self.post(
|
|
377
|
+
f"/api/aimodels/{model_id}/call-async/",
|
|
378
|
+
json_data={"input_text": input_text, "parameters": parameters or {}},
|
|
379
|
+
)
|
|
380
|
+
|
|
381
|
+
# ==================== Version Management ====================
|
|
382
|
+
|
|
383
|
+
    def get_versions(self, model_id: int) -> List[Dict[str, Any]]:
        """
        Get all versions for an AI model.

        Args:
            model_id: ID of the AI model

        Returns:
            List of version dictionaries (empty when no model matches)

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors.
        """
        query = """
        query GetModelVersions($filters: AIModelFilter) {
            aiModels(filters: $filters) {
                versions {
                    id
                    version
                    versionNotes
                    lifecycleStage
                    isLatest
                    supportsStreaming
                    maxTokens
                    supportedLanguages
                    status
                    createdAt
                    updatedAt
                    publishedAt
                    providers {
                        id
                        provider
                        providerModelId
                        isPrimary
                        isActive
                    }
                }
            }
        }
        """

        response = self.post(
            "/api/graphql",
            json_data={
                "query": query,
                # NOTE(review): bare {"id": model_id} here, while list_all uses
                # the nested {"id": {"exact": ...}} shape for the organization
                # filter — confirm the server accepts both forms.
                "variables": {"filters": {"id": model_id}},
            },
        )

        if "errors" in response:
            # Lazy import — presumably avoids a circular import; confirm.
            from dataspace_sdk.exceptions import DataSpaceAPIError

            raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")

        models = response.get("data", {}).get("aiModels", [])
        if models:
            # The id filter should match at most one model; use its versions.
            result: List[Dict[str, Any]] = models[0].get("versions", [])
            return result
        return []
|
|
439
|
+
|
|
440
|
+
def create_version(
|
|
441
|
+
self,
|
|
442
|
+
model_id: int,
|
|
443
|
+
version: str,
|
|
444
|
+
lifecycle_stage: str = "DEVELOPMENT",
|
|
445
|
+
is_latest: bool = False,
|
|
446
|
+
copy_from_version_id: Optional[int] = None,
|
|
447
|
+
version_notes: Optional[str] = None,
|
|
448
|
+
supports_streaming: bool = False,
|
|
449
|
+
max_tokens: Optional[int] = None,
|
|
450
|
+
supported_languages: Optional[List[str]] = None,
|
|
451
|
+
) -> Dict[str, Any]:
|
|
452
|
+
"""
|
|
453
|
+
Create a new version for an AI model.
|
|
454
|
+
|
|
455
|
+
Args:
|
|
456
|
+
model_id: ID of the AI model
|
|
457
|
+
version: Version string (e.g., "1.0", "2.1")
|
|
458
|
+
lifecycle_stage: One of DEVELOPMENT, TESTING, BETA, STAGING, PRODUCTION, DEPRECATED, RETIRED
|
|
459
|
+
is_latest: Whether this should be the primary version
|
|
460
|
+
copy_from_version_id: Optional version ID to copy providers from
|
|
461
|
+
version_notes: Optional notes about this version
|
|
462
|
+
supports_streaming: Whether this version supports streaming
|
|
463
|
+
max_tokens: Maximum tokens supported
|
|
464
|
+
supported_languages: List of supported language codes
|
|
465
|
+
|
|
466
|
+
Returns:
|
|
467
|
+
Dictionary containing created version information
|
|
468
|
+
"""
|
|
469
|
+
mutation = """
|
|
470
|
+
mutation CreateAIModelVersion($input: CreateAIModelVersionInput!) {
|
|
471
|
+
createAiModelVersion(input: $input) {
|
|
472
|
+
success
|
|
473
|
+
data {
|
|
474
|
+
id
|
|
475
|
+
version
|
|
476
|
+
lifecycleStage
|
|
477
|
+
isLatest
|
|
478
|
+
status
|
|
479
|
+
}
|
|
480
|
+
errors
|
|
481
|
+
}
|
|
482
|
+
}
|
|
483
|
+
"""
|
|
484
|
+
|
|
485
|
+
input_data: Dict[str, Any] = {
|
|
486
|
+
"modelId": model_id,
|
|
487
|
+
"version": version,
|
|
488
|
+
"lifecycleStage": lifecycle_stage,
|
|
489
|
+
"isLatest": is_latest,
|
|
490
|
+
"supportsStreaming": supports_streaming,
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
if copy_from_version_id:
|
|
494
|
+
input_data["copyFromVersionId"] = copy_from_version_id
|
|
495
|
+
if version_notes:
|
|
496
|
+
input_data["versionNotes"] = version_notes
|
|
497
|
+
if max_tokens:
|
|
498
|
+
input_data["maxTokens"] = max_tokens
|
|
499
|
+
if supported_languages:
|
|
500
|
+
input_data["supportedLanguages"] = supported_languages
|
|
501
|
+
|
|
502
|
+
response = self.post(
|
|
503
|
+
"/api/graphql",
|
|
504
|
+
json_data={
|
|
505
|
+
"query": mutation,
|
|
506
|
+
"variables": {"input": input_data},
|
|
507
|
+
},
|
|
508
|
+
)
|
|
509
|
+
|
|
510
|
+
if "errors" in response:
|
|
511
|
+
from dataspace_sdk.exceptions import DataSpaceAPIError
|
|
512
|
+
|
|
513
|
+
raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")
|
|
514
|
+
|
|
515
|
+
result: Dict[str, Any] = response.get("data", {}).get("createAiModelVersion", {})
|
|
516
|
+
return result
|
|
517
|
+
|
|
518
|
+
    def update_version(
        self,
        version_id: int,
        version: Optional[str] = None,
        lifecycle_stage: Optional[str] = None,
        is_latest: Optional[bool] = None,
        version_notes: Optional[str] = None,
        status: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Update an AI model version.

        Only the arguments passed as non-None are included in the mutation
        input, so unspecified fields are left unchanged on the server.

        Args:
            version_id: ID of the version to update
            version: New version string
            lifecycle_stage: New lifecycle stage
            is_latest: Whether this should be the primary version
            version_notes: New version notes
            status: New status

        Returns:
            Dictionary containing updated version information

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors.
        """
        mutation = """
        mutation UpdateAIModelVersion($input: UpdateAIModelVersionInput!) {
            updateAiModelVersion(input: $input) {
                success
                data {
                    id
                    version
                    lifecycleStage
                    isLatest
                    status
                }
                errors
            }
        }
        """

        input_data: Dict[str, Any] = {"id": version_id}

        # `is not None` (not truthiness) so explicit falsy values such as
        # is_latest=False or an empty notes string are still sent.
        if version is not None:
            input_data["version"] = version
        if lifecycle_stage is not None:
            input_data["lifecycleStage"] = lifecycle_stage
        if is_latest is not None:
            input_data["isLatest"] = is_latest
        if version_notes is not None:
            input_data["versionNotes"] = version_notes
        if status is not None:
            input_data["status"] = status

        response = self.post(
            "/api/graphql",
            json_data={
                "query": mutation,
                "variables": {"input": input_data},
            },
        )

        if "errors" in response:
            # Lazy import — presumably avoids a circular import; confirm.
            from dataspace_sdk.exceptions import DataSpaceAPIError

            raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")

        result: Dict[str, Any] = response.get("data", {}).get("updateAiModelVersion", {})
        return result
|
|
585
|
+
|
|
586
|
+
# ==================== Provider Management ====================
|
|
587
|
+
|
|
588
|
+
    def get_version_providers(self, version_id: int) -> List[Dict[str, Any]]:
        """
        Get all providers for a specific version.

        Args:
            version_id: ID of the version

        Returns:
            List of provider dictionaries (empty when the version is not found)

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors.
        """
        # Note: the "#" lines inside this string are GraphQL comments, not Python.
        query = """
        query GetVersionProviders($versionId: Int!) {
            aiModelVersion(id: $versionId) {
                providers {
                    id
                    provider
                    providerModelId
                    isPrimary
                    isActive
                    # API Configuration
                    apiEndpointUrl
                    apiHttpMethod
                    apiTimeoutSeconds
                    apiAuthType
                    apiAuthHeaderName
                    apiKey
                    apiKeyPrefix
                    apiHeaders
                    apiRequestTemplate
                    apiResponsePath
                    # HuggingFace Configuration
                    hfUsePipeline
                    hfAuthToken
                    hfModelClass
                    hfAttnImplementation
                    hfTrustRemoteCode
                    hfTorchDtype
                    hfDeviceMap
                    framework
                    config
                }
            }
        }
        """

        response = self.post(
            "/api/graphql",
            json_data={
                "query": query,
                "variables": {"versionId": version_id},
            },
        )

        if "errors" in response:
            # Lazy import — presumably avoids a circular import; confirm.
            from dataspace_sdk.exceptions import DataSpaceAPIError

            raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")

        version_data = response.get("data", {}).get("aiModelVersion", {})
        # aiModelVersion may come back as None/empty when the id does not exist.
        result: List[Dict[str, Any]] = version_data.get("providers", []) if version_data else []
        return result
|
|
649
|
+
|
|
650
|
+
    def create_provider(
        self,
        version_id: int,
        provider: str,
        provider_model_id: str,
        is_primary: bool = False,
        # API Configuration
        api_endpoint_url: Optional[str] = None,
        api_http_method: str = "POST",
        api_timeout_seconds: int = 60,
        api_auth_type: str = "BEARER",
        api_auth_header_name: str = "Authorization",
        api_key: Optional[str] = None,
        api_key_prefix: str = "Bearer",
        api_headers: Optional[Dict[str, str]] = None,
        api_request_template: Optional[Dict[str, Any]] = None,
        api_response_path: Optional[str] = None,
        # HuggingFace Configuration
        hf_use_pipeline: bool = False,
        hf_model_class: Optional[str] = None,
        hf_auth_token: Optional[str] = None,
        hf_attn_implementation: Optional[str] = None,
        hf_trust_remote_code: bool = True,
        hf_torch_dtype: Optional[str] = "auto",
        hf_device_map: Optional[str] = "auto",
        framework: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Create a new provider for a version.

        Args:
            version_id: ID of the version
            provider: Provider type (OPENAI, LLAMA_OLLAMA, LLAMA_TOGETHER, LLAMA_REPLICATE,
                LLAMA_CUSTOM, CUSTOM, HUGGINGFACE)
            provider_model_id: Model ID at the provider (e.g., "gpt-4", "meta-llama/Llama-2-7b")
            is_primary: Whether this is the primary provider
            api_endpoint_url: Full URL for the API endpoint
            api_http_method: HTTP method (POST, GET)
            api_timeout_seconds: Request timeout in seconds
            api_auth_type: Authentication type (BEARER, API_KEY, BASIC, OAUTH2, CUSTOM, NONE)
            api_auth_header_name: Header name for authentication
            api_key: API key or token
            api_key_prefix: Prefix for the API key (e.g., "Bearer")
            api_headers: Additional headers as dict
            api_request_template: Request body template as dict
            api_response_path: JSON path to extract response text
            hf_use_pipeline: For HuggingFace - whether to use pipeline API
            hf_model_class: For HuggingFace - model class (e.g., "AutoModelForCausalLM")
            hf_auth_token: For HuggingFace - auth token for gated models
            hf_attn_implementation: For HuggingFace - attention implementation
            hf_trust_remote_code: For HuggingFace - trust remote code
            hf_torch_dtype: For HuggingFace - torch dtype (auto, float16, bfloat16)
            hf_device_map: For HuggingFace - device map (auto, cuda, cpu)
            framework: Framework (pt, tf)
            config: Additional configuration

        Returns:
            Dictionary containing created provider information

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors.
        """
        mutation = """
        mutation CreateVersionProvider($input: CreateVersionProviderInput!) {
            createVersionProvider(input: $input) {
                success
                data {
                    id
                    provider
                    providerModelId
                    isPrimary
                    isActive
                }
                errors
            }
        }
        """

        # Required fields plus defaulted settings are always sent.
        input_data: Dict[str, Any] = {
            "versionId": version_id,
            "provider": provider,
            "providerModelId": provider_model_id,
            "isPrimary": is_primary,
            # API Configuration
            "apiHttpMethod": api_http_method,
            "apiTimeoutSeconds": api_timeout_seconds,
            "apiAuthType": api_auth_type,
            "apiAuthHeaderName": api_auth_header_name,
            "apiKeyPrefix": api_key_prefix,
            # HuggingFace Configuration
            "hfUsePipeline": hf_use_pipeline,
            "hfTrustRemoteCode": hf_trust_remote_code,
        }

        # Optional API fields
        # NOTE(review): truthiness checks mean empty-string/empty-dict values
        # are omitted from the payload — confirm that is intended.
        if api_endpoint_url:
            input_data["apiEndpointUrl"] = api_endpoint_url
        if api_key:
            input_data["apiKey"] = api_key
        if api_headers:
            input_data["apiHeaders"] = api_headers
        if api_request_template:
            input_data["apiRequestTemplate"] = api_request_template
        if api_response_path:
            input_data["apiResponsePath"] = api_response_path

        # Optional HuggingFace fields
        if hf_model_class:
            input_data["hfModelClass"] = hf_model_class
        if hf_auth_token:
            input_data["hfAuthToken"] = hf_auth_token
        if hf_attn_implementation:
            input_data["hfAttnImplementation"] = hf_attn_implementation
        if hf_torch_dtype:
            input_data["hfTorchDtype"] = hf_torch_dtype
        if hf_device_map:
            input_data["hfDeviceMap"] = hf_device_map
        if framework:
            input_data["framework"] = framework
        if config:
            input_data["config"] = config

        response = self.post(
            "/api/graphql",
            json_data={
                "query": mutation,
                "variables": {"input": input_data},
            },
        )

        if "errors" in response:
            # Lazy import — presumably avoids a circular import; confirm.
            from dataspace_sdk.exceptions import DataSpaceAPIError

            raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")

        result: Dict[str, Any] = response.get("data", {}).get("createVersionProvider", {})
        return result
|
|
785
|
+
|
|
786
|
+
    def update_provider(
        self,
        provider_id: int,
        provider_model_id: Optional[str] = None,
        is_primary: Optional[bool] = None,
        # API Configuration
        api_endpoint_url: Optional[str] = None,
        api_http_method: Optional[str] = None,
        api_timeout_seconds: Optional[int] = None,
        api_auth_type: Optional[str] = None,
        api_auth_header_name: Optional[str] = None,
        api_key: Optional[str] = None,
        api_key_prefix: Optional[str] = None,
        api_headers: Optional[Dict[str, str]] = None,
        api_request_template: Optional[Dict[str, Any]] = None,
        api_response_path: Optional[str] = None,
        # HuggingFace Configuration
        hf_use_pipeline: Optional[bool] = None,
        hf_model_class: Optional[str] = None,
        hf_auth_token: Optional[str] = None,
        hf_attn_implementation: Optional[str] = None,
        hf_trust_remote_code: Optional[bool] = None,
        hf_torch_dtype: Optional[str] = None,
        hf_device_map: Optional[str] = None,
        framework: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Update a provider.

        Only arguments passed as non-None are included in the mutation input,
        so unspecified fields are left unchanged on the server.

        Args:
            provider_id: ID of the provider to update
            provider_model_id: New model ID at the provider
            is_primary: Whether this is the primary provider
            api_endpoint_url: Full URL for the API endpoint
            api_http_method: HTTP method (POST, GET)
            api_timeout_seconds: Request timeout in seconds
            api_auth_type: Authentication type (BEARER, API_KEY, BASIC, OAUTH2, CUSTOM, NONE)
            api_auth_header_name: Header name for authentication
            api_key: API key or token
            api_key_prefix: Prefix for the API key (e.g., "Bearer")
            api_headers: Additional headers as dict
            api_request_template: Request body template as dict
            api_response_path: JSON path to extract response text
            hf_use_pipeline: For HuggingFace - whether to use pipeline API
            hf_model_class: For HuggingFace - model class
            hf_auth_token: For HuggingFace - auth token
            hf_attn_implementation: For HuggingFace - attention implementation
            hf_trust_remote_code: For HuggingFace - trust remote code
            hf_torch_dtype: For HuggingFace - torch dtype
            hf_device_map: For HuggingFace - device map
            framework: Framework (pt, tf)
            config: Additional configuration

        Returns:
            Dictionary containing updated provider information

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors.
        """
        mutation = """
        mutation UpdateVersionProvider($input: UpdateVersionProviderInput!) {
            updateVersionProvider(input: $input) {
                success
                data {
                    id
                    provider
                    providerModelId
                    isPrimary
                    isActive
                }
                errors
            }
        }
        """

        input_data: Dict[str, Any] = {"id": provider_id}

        # `is not None` throughout so explicit falsy values (False, "", {})
        # can still be written to the server.
        if provider_model_id is not None:
            input_data["providerModelId"] = provider_model_id
        if is_primary is not None:
            input_data["isPrimary"] = is_primary
        # API Configuration
        if api_endpoint_url is not None:
            input_data["apiEndpointUrl"] = api_endpoint_url
        if api_http_method is not None:
            input_data["apiHttpMethod"] = api_http_method
        if api_timeout_seconds is not None:
            input_data["apiTimeoutSeconds"] = api_timeout_seconds
        if api_auth_type is not None:
            input_data["apiAuthType"] = api_auth_type
        if api_auth_header_name is not None:
            input_data["apiAuthHeaderName"] = api_auth_header_name
        if api_key is not None:
            input_data["apiKey"] = api_key
        if api_key_prefix is not None:
            input_data["apiKeyPrefix"] = api_key_prefix
        if api_headers is not None:
            input_data["apiHeaders"] = api_headers
        if api_request_template is not None:
            input_data["apiRequestTemplate"] = api_request_template
        if api_response_path is not None:
            input_data["apiResponsePath"] = api_response_path
        # HuggingFace Configuration
        if hf_use_pipeline is not None:
            input_data["hfUsePipeline"] = hf_use_pipeline
        if hf_model_class is not None:
            input_data["hfModelClass"] = hf_model_class
        if hf_auth_token is not None:
            input_data["hfAuthToken"] = hf_auth_token
        if hf_attn_implementation is not None:
            input_data["hfAttnImplementation"] = hf_attn_implementation
        if hf_trust_remote_code is not None:
            input_data["hfTrustRemoteCode"] = hf_trust_remote_code
        if hf_torch_dtype is not None:
            input_data["hfTorchDtype"] = hf_torch_dtype
        if hf_device_map is not None:
            input_data["hfDeviceMap"] = hf_device_map
        if framework is not None:
            input_data["framework"] = framework
        if config is not None:
            input_data["config"] = config

        response = self.post(
            "/api/graphql",
            json_data={
                "query": mutation,
                "variables": {"input": input_data},
            },
        )

        if "errors" in response:
            # Lazy import — presumably avoids a circular import; confirm.
            from dataspace_sdk.exceptions import DataSpaceAPIError

            raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")

        result: Dict[str, Any] = response.get("data", {}).get("updateVersionProvider", {})
        return result
|
|
921
|
+
|
|
922
|
+
def delete_provider(self, provider_id: int) -> Dict[str, Any]:
|
|
923
|
+
"""
|
|
924
|
+
Delete a provider.
|
|
925
|
+
|
|
926
|
+
Args:
|
|
927
|
+
provider_id: ID of the provider to delete
|
|
928
|
+
|
|
929
|
+
Returns:
|
|
930
|
+
Dictionary containing deletion response
|
|
931
|
+
"""
|
|
932
|
+
mutation = """
|
|
933
|
+
mutation DeleteVersionProvider($providerId: Int!) {
|
|
934
|
+
deleteVersionProvider(providerId: $providerId) {
|
|
935
|
+
success
|
|
936
|
+
errors
|
|
937
|
+
}
|
|
938
|
+
}
|
|
939
|
+
"""
|
|
940
|
+
|
|
941
|
+
response = self.post(
|
|
942
|
+
"/api/graphql",
|
|
943
|
+
json_data={
|
|
944
|
+
"query": mutation,
|
|
945
|
+
"variables": {"providerId": provider_id},
|
|
946
|
+
},
|
|
947
|
+
)
|
|
948
|
+
|
|
949
|
+
if "errors" in response:
|
|
950
|
+
from dataspace_sdk.exceptions import DataSpaceAPIError
|
|
951
|
+
|
|
952
|
+
raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")
|
|
953
|
+
|
|
954
|
+
result: Dict[str, Any] = response.get("data", {}).get("deleteVersionProvider", {})
|
|
955
|
+
return result
|
|
956
|
+
|
|
957
|
+
# ==================== Helper Methods ====================
|
|
958
|
+
|
|
959
|
+
def get_primary_version(self, model_id: int) -> Optional[Dict[str, Any]]:
|
|
960
|
+
"""
|
|
961
|
+
Get the primary (latest) version for an AI model.
|
|
962
|
+
|
|
963
|
+
Args:
|
|
964
|
+
model_id: ID of the AI model
|
|
965
|
+
|
|
966
|
+
Returns:
|
|
967
|
+
Dictionary containing the primary version, or None if no versions exist
|
|
968
|
+
"""
|
|
969
|
+
versions = self.get_versions(model_id)
|
|
970
|
+
for version in versions:
|
|
971
|
+
if version.get("isLatest"):
|
|
972
|
+
return version
|
|
973
|
+
return versions[0] if versions else None
|
|
974
|
+
|
|
975
|
+
def get_primary_provider(self, version_id: int) -> Optional[Dict[str, Any]]:
|
|
976
|
+
"""
|
|
977
|
+
Get the primary provider for a version.
|
|
978
|
+
|
|
979
|
+
Args:
|
|
980
|
+
version_id: ID of the version
|
|
981
|
+
|
|
982
|
+
Returns:
|
|
983
|
+
Dictionary containing the primary provider, or None if no providers exist
|
|
984
|
+
"""
|
|
985
|
+
providers = self.get_version_providers(version_id)
|
|
986
|
+
for provider in providers:
|
|
987
|
+
if provider.get("isPrimary"):
|
|
988
|
+
return provider
|
|
989
|
+
return providers[0] if providers else None
|