dataspace-sdk 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,989 @@
1
+ """AI Model resource client for DataSpace SDK."""
2
+
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from dataspace_sdk.base import BaseAPIClient
6
+
7
+
8
class AIModelClient(BaseAPIClient):
    """Client for interacting with AI Model resources.

    Search, CRUD and model invocation go through REST endpoints under
    ``/api/search/aimodel/`` and ``/api/aimodels/``; listing, version
    management and provider management go through the GraphQL endpoint
    at ``/api/graphql``.
    """

    # ==================== Internal Helpers ====================

    def _execute_graphql(self, query: str, variables: Dict[str, Any]) -> Dict[str, Any]:
        """
        POST a GraphQL document and return the response's ``data`` object.

        Args:
            query: GraphQL query or mutation document
            variables: Variables referenced by the document

        Returns:
            The ``data`` object of the response, or an empty dict when the
            server sent ``null`` (or anything that is not an object) for it

        Raises:
            DataSpaceAPIError: If the response carries a non-empty top-level
                ``errors`` entry
        """
        response = self.post(
            "/api/graphql",
            json_data={
                "query": query,
                "variables": variables,
            },
        )

        # Only a non-empty "errors" entry is a failure; a present-but-null /
        # empty entry previously raised "GraphQL error: None".
        if response.get("errors"):
            # Imported at call time; the module only imports BaseAPIClient
            # at top level.
            from dataspace_sdk.exceptions import DataSpaceAPIError

            raise DataSpaceAPIError(f"GraphQL error: {response['errors']}")

        data = response.get("data")
        # GraphQL permits "data": null, so normalise before callers chain .get().
        return data if isinstance(data, dict) else {}

    # ==================== Model Management ====================

    def search(
        self,
        query: Optional[str] = None,
        tags: Optional[List[str]] = None,
        sectors: Optional[List[str]] = None,
        geographies: Optional[List[str]] = None,
        status: Optional[str] = None,
        model_type: Optional[str] = None,
        provider: Optional[str] = None,
        sort: Optional[str] = None,
        page: int = 1,
        page_size: int = 10,
    ) -> Dict[str, Any]:
        """
        Search for AI models using Elasticsearch.

        Args:
            query: Search query string (sent as ``q``)
            tags: Filter by tags (sent comma-joined)
            sectors: Filter by sectors (sent comma-joined)
            geographies: Filter by geographies (sent comma-joined)
            status: Filter by status (ACTIVE, INACTIVE, etc.)
            model_type: Filter by model type (LLM, VISION, etc.)
            provider: Filter by provider (OPENAI, ANTHROPIC, etc.)
            sort: Sort order (recent, alphabetical)
            page: Page number (1-indexed)
            page_size: Number of results per page

        Returns:
            Dictionary containing search results and metadata
        """
        params: Dict[str, Any] = {
            "page": page,
            "page_size": page_size,
        }

        # Falsy filters (None, "" or an empty list) are omitted entirely.
        if query:
            params["q"] = query
        if tags:
            params["tags"] = ",".join(tags)
        if sectors:
            params["sectors"] = ",".join(sectors)
        if geographies:
            params["geographies"] = ",".join(geographies)
        if status:
            params["status"] = status
        if model_type:
            params["model_type"] = model_type
        if provider:
            params["provider"] = provider
        if sort:
            params["sort"] = sort

        return super().get("/api/search/aimodel/", params=params)

    def get_by_id(self, model_id: str) -> Dict[str, Any]:
        """
        Get an AI model by ID via the REST endpoint.

        Args:
            model_id: UUID of the AI model

        Returns:
            Dictionary containing AI model information
        """
        # Use parent class get method with full endpoint path
        return super().get(f"/api/aimodels/{model_id}/")

    def get_by_id_graphql(self, model_id: str) -> Dict[str, Any]:
        """
        Get an AI model by ID using GraphQL, including versions and providers.

        Args:
            model_id: UUID of the AI model

        Returns:
            Dictionary containing AI model information, or an empty dict if
            the server returned no model

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors
        """
        query = """
            query GetAIModel($id: UUID!) {
                aiModel(id: $id) {
                    id
                    name
                    displayName
                    description
                    modelType
                    status
                    isPublic
                    createdAt
                    updatedAt
                    organization {
                        id
                        name
                    }
                    tags {
                        id
                        value
                    }
                    sectors {
                        id
                        name
                    }
                    geographies {
                        id
                        name
                    }
                    versions {
                        id
                        version
                        versionNotes
                        lifecycleStage
                        isLatest
                        supportsStreaming
                        maxTokens
                        supportedLanguages
                        inputSchema
                        outputSchema
                        status
                        createdAt
                        updatedAt
                        publishedAt
                        providers {
                            id
                            provider
                            providerModelId
                            isPrimary
                            isActive
                            # API Configuration
                            apiEndpointUrl
                            apiHttpMethod
                            apiTimeoutSeconds
                            apiAuthType
                            apiAuthHeaderName
                            apiKey
                            apiKeyPrefix
                            apiHeaders
                            apiRequestTemplate
                            apiResponsePath
                            # HuggingFace Configuration
                            hfUsePipeline
                            hfAuthToken
                            hfModelClass
                            hfAttnImplementation
                            hfTrustRemoteCode
                            hfTorchDtype
                            hfDeviceMap
                            framework
                            config
                        }
                    }
                }
            }
        """

        data = self._execute_graphql(query, {"id": model_id})
        # "aiModel" may be null; keep the documented Dict return type.
        result: Dict[str, Any] = data.get("aiModel") or {}
        return result

    def list_all(
        self,
        status: Optional[str] = None,
        organization_id: Optional[str] = None,
        model_type: Optional[str] = None,
        limit: int = 10,
        offset: int = 0,
    ) -> Any:
        """
        List AI models with offset pagination using GraphQL.

        Args:
            status: Filter by status
            organization_id: Filter by organization
            model_type: Filter by model type
            limit: Number of results to return
            offset: Number of results to skip

        Returns:
            List of AI model dictionaries (empty if none were returned)

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors
        """
        query = """
            query ListAIModels($filters: AIModelFilter, $pagination: OffsetPaginationInput) {
                aiModels(filters: $filters, pagination: $pagination) {
                    id
                    name
                    displayName
                    description
                    modelType
                    status
                    isPublic
                    createdAt
                    updatedAt
                    organization {
                        id
                        name
                    }
                    tags {
                        id
                        value
                    }
                    versions {
                        id
                        version
                        lifecycleStage
                        isLatest
                        status
                        providers {
                            id
                            provider
                            providerModelId
                            isPrimary
                        }
                    }
                }
            }
        """

        filters: Dict[str, Any] = {}
        if status:
            filters["status"] = status
        if organization_id:
            filters["organization"] = {"id": {"exact": organization_id}}
        if model_type:
            filters["modelType"] = model_type

        variables: Dict[str, Any] = {
            "pagination": {"limit": limit, "offset": offset},
        }
        # Omit "filters" entirely when no filter was requested.
        if filters:
            variables["filters"] = filters

        data = self._execute_graphql(query, variables)
        models_result: Any = data.get("aiModels") or []
        return models_result

    def get_organization_models(
        self,
        organization_id: str,
        limit: int = 10,
        offset: int = 0,
    ) -> Any:
        """
        Get AI models for a specific organization.

        Args:
            organization_id: UUID of the organization
            limit: Number of results to return
            offset: Number of results to skip

        Returns:
            List of the organization's AI model dictionaries (see list_all)
        """
        return self.list_all(
            organization_id=organization_id,
            limit=limit,
            offset=offset,
        )

    def create(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Create a new AI model.

        Args:
            data: Dictionary containing AI model data

        Returns:
            Dictionary containing created AI model information
        """
        return self.post("/api/aimodels/", json_data=data)

    def update(self, model_id: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Partially update an existing AI model (HTTP PATCH).

        Args:
            model_id: UUID of the AI model
            data: Dictionary containing updated AI model data

        Returns:
            Dictionary containing updated AI model information
        """
        return self.patch(f"/api/aimodels/{model_id}/", json_data=data)

    def delete_model(self, model_id: str) -> Dict[str, Any]:
        """
        Delete an AI model.

        Args:
            model_id: UUID of the AI model

        Returns:
            Dictionary containing deletion response
        """
        return self.delete(f"/api/aimodels/{model_id}/")

    def call_model(
        self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Call an AI model with input text using the appropriate client (API or HuggingFace).

        Args:
            model_id: UUID of the AI model
            input_text: Input text to process
            parameters: Optional parameters for the model call (temperature,
                max_tokens, etc.); sent as ``{}`` when omitted

        Returns:
            Dictionary containing model response:
            {
                "success": bool,
                "output": str (if successful),
                "error": str (if failed),
                "latency_ms": float,
                "provider": str,
                ...
            }
        """
        return self.post(
            f"/api/aimodels/{model_id}/call/",
            json_data={"input_text": input_text, "parameters": parameters or {}},
        )

    def call_model_async(
        self, model_id: str, input_text: str, parameters: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Call an AI model asynchronously (returns task ID for long-running operations).

        Args:
            model_id: UUID of the AI model
            input_text: Input text to process
            parameters: Optional parameters for the model call; sent as ``{}``
                when omitted

        Returns:
            Dictionary containing task information:
            {
                "task_id": str,
                "status": str,
                "model_id": str
            }
        """
        return self.post(
            f"/api/aimodels/{model_id}/call-async/",
            json_data={"input_text": input_text, "parameters": parameters or {}},
        )

    # ==================== Version Management ====================

    def get_versions(self, model_id: int) -> List[Dict[str, Any]]:
        """
        Get all versions for an AI model.

        Args:
            model_id: ID of the AI model

        Returns:
            List of version dictionaries (empty if the model or its versions
            were not returned)

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors
        """
        query = """
            query GetModelVersions($filters: AIModelFilter) {
                aiModels(filters: $filters) {
                    versions {
                        id
                        version
                        versionNotes
                        lifecycleStage
                        isLatest
                        supportsStreaming
                        maxTokens
                        supportedLanguages
                        status
                        createdAt
                        updatedAt
                        publishedAt
                        providers {
                            id
                            provider
                            providerModelId
                            isPrimary
                            isActive
                        }
                    }
                }
            }
        """

        data = self._execute_graphql(query, {"filters": {"id": model_id}})
        models = data.get("aiModels") or []
        if not models:
            return []
        # The id filter is expected to match at most one model.
        result: List[Dict[str, Any]] = models[0].get("versions") or []
        return result

    def create_version(
        self,
        model_id: int,
        version: str,
        lifecycle_stage: str = "DEVELOPMENT",
        is_latest: bool = False,
        copy_from_version_id: Optional[int] = None,
        version_notes: Optional[str] = None,
        supports_streaming: bool = False,
        max_tokens: Optional[int] = None,
        supported_languages: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Create a new version for an AI model.

        Args:
            model_id: ID of the AI model
            version: Version string (e.g., "1.0", "2.1")
            lifecycle_stage: One of DEVELOPMENT, TESTING, BETA, STAGING, PRODUCTION, DEPRECATED, RETIRED
            is_latest: Whether this should be the primary version
            copy_from_version_id: Optional version ID to copy providers from
            version_notes: Optional notes about this version
            supports_streaming: Whether this version supports streaming
            max_tokens: Maximum tokens supported
            supported_languages: List of supported language codes

        Returns:
            Dictionary with the mutation payload (``success``, ``data``,
            ``errors``); empty if the server returned none

        Raises:
            DataSpaceAPIError: If the GraphQL response contains top-level errors
        """
        mutation = """
            mutation CreateAIModelVersion($input: CreateAIModelVersionInput!) {
                createAiModelVersion(input: $input) {
                    success
                    data {
                        id
                        version
                        lifecycleStage
                        isLatest
                        status
                    }
                    errors
                }
            }
        """

        input_data: Dict[str, Any] = {
            "modelId": model_id,
            "version": version,
            "lifecycleStage": lifecycle_stage,
            "isLatest": is_latest,
            "supportsStreaming": supports_streaming,
        }

        # Optional fields are sent only when truthy.
        optional_fields: Dict[str, Any] = {
            "copyFromVersionId": copy_from_version_id,
            "versionNotes": version_notes,
            "maxTokens": max_tokens,
            "supportedLanguages": supported_languages,
        }
        input_data.update({key: value for key, value in optional_fields.items() if value})

        data = self._execute_graphql(mutation, {"input": input_data})
        result: Dict[str, Any] = data.get("createAiModelVersion") or {}
        return result

    def update_version(
        self,
        version_id: int,
        version: Optional[str] = None,
        lifecycle_stage: Optional[str] = None,
        is_latest: Optional[bool] = None,
        version_notes: Optional[str] = None,
        status: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Update an AI model version; only arguments that are not None are sent.

        Args:
            version_id: ID of the version to update
            version: New version string
            lifecycle_stage: New lifecycle stage
            is_latest: Whether this should be the primary version
            version_notes: New version notes
            status: New status

        Returns:
            Dictionary with the mutation payload (``success``, ``data``,
            ``errors``); empty if the server returned none

        Raises:
            DataSpaceAPIError: If the GraphQL response contains top-level errors
        """
        mutation = """
            mutation UpdateAIModelVersion($input: UpdateAIModelVersionInput!) {
                updateAiModelVersion(input: $input) {
                    success
                    data {
                        id
                        version
                        lifecycleStage
                        isLatest
                        status
                    }
                    errors
                }
            }
        """

        input_data: Dict[str, Any] = {"id": version_id}

        # `is not None` (not truthiness) so False / "" can be set explicitly.
        optional_fields: Dict[str, Any] = {
            "version": version,
            "lifecycleStage": lifecycle_stage,
            "isLatest": is_latest,
            "versionNotes": version_notes,
            "status": status,
        }
        input_data.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )

        data = self._execute_graphql(mutation, {"input": input_data})
        result: Dict[str, Any] = data.get("updateAiModelVersion") or {}
        return result

    # ==================== Provider Management ====================

    def get_version_providers(self, version_id: int) -> List[Dict[str, Any]]:
        """
        Get all providers (with full API / HuggingFace configuration) for a version.

        Args:
            version_id: ID of the version

        Returns:
            List of provider dictionaries (empty if the version or its
            providers were not returned)

        Raises:
            DataSpaceAPIError: If the GraphQL response contains errors
        """
        query = """
            query GetVersionProviders($versionId: Int!) {
                aiModelVersion(id: $versionId) {
                    providers {
                        id
                        provider
                        providerModelId
                        isPrimary
                        isActive
                        # API Configuration
                        apiEndpointUrl
                        apiHttpMethod
                        apiTimeoutSeconds
                        apiAuthType
                        apiAuthHeaderName
                        apiKey
                        apiKeyPrefix
                        apiHeaders
                        apiRequestTemplate
                        apiResponsePath
                        # HuggingFace Configuration
                        hfUsePipeline
                        hfAuthToken
                        hfModelClass
                        hfAttnImplementation
                        hfTrustRemoteCode
                        hfTorchDtype
                        hfDeviceMap
                        framework
                        config
                    }
                }
            }
        """

        data = self._execute_graphql(query, {"versionId": version_id})
        version_data = data.get("aiModelVersion") or {}
        result: List[Dict[str, Any]] = version_data.get("providers") or []
        return result

    def create_provider(
        self,
        version_id: int,
        provider: str,
        provider_model_id: str,
        is_primary: bool = False,
        # API Configuration
        api_endpoint_url: Optional[str] = None,
        api_http_method: str = "POST",
        api_timeout_seconds: int = 60,
        api_auth_type: str = "BEARER",
        api_auth_header_name: str = "Authorization",
        api_key: Optional[str] = None,
        api_key_prefix: str = "Bearer",
        api_headers: Optional[Dict[str, str]] = None,
        api_request_template: Optional[Dict[str, Any]] = None,
        api_response_path: Optional[str] = None,
        # HuggingFace Configuration
        hf_use_pipeline: bool = False,
        hf_model_class: Optional[str] = None,
        hf_auth_token: Optional[str] = None,
        hf_attn_implementation: Optional[str] = None,
        hf_trust_remote_code: bool = True,
        hf_torch_dtype: Optional[str] = "auto",
        hf_device_map: Optional[str] = "auto",
        framework: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Create a new provider for a version.

        Args:
            version_id: ID of the version
            provider: Provider type (OPENAI, LLAMA_OLLAMA, LLAMA_TOGETHER, LLAMA_REPLICATE,
                     LLAMA_CUSTOM, CUSTOM, HUGGINGFACE)
            provider_model_id: Model ID at the provider (e.g., "gpt-4", "meta-llama/Llama-2-7b")
            is_primary: Whether this is the primary provider
            api_endpoint_url: Full URL for the API endpoint
            api_http_method: HTTP method (POST, GET)
            api_timeout_seconds: Request timeout in seconds
            api_auth_type: Authentication type (BEARER, API_KEY, BASIC, OAUTH2, CUSTOM, NONE)
            api_auth_header_name: Header name for authentication
            api_key: API key or token
            api_key_prefix: Prefix for the API key (e.g., "Bearer")
            api_headers: Additional headers as dict
            api_request_template: Request body template as dict
            api_response_path: JSON path to extract response text
            hf_use_pipeline: For HuggingFace - whether to use pipeline API
            hf_model_class: For HuggingFace - model class (e.g., "AutoModelForCausalLM")
            hf_auth_token: For HuggingFace - auth token for gated models
            hf_attn_implementation: For HuggingFace - attention implementation
            hf_trust_remote_code: For HuggingFace - trust remote code
            hf_torch_dtype: For HuggingFace - torch dtype (auto, float16, bfloat16)
            hf_device_map: For HuggingFace - device map (auto, cuda, cpu)
            framework: Framework (pt, tf)
            config: Additional configuration

        Returns:
            Dictionary with the mutation payload (``success``, ``data``,
            ``errors``); empty if the server returned none

        Raises:
            DataSpaceAPIError: If the GraphQL response contains top-level errors
        """
        mutation = """
            mutation CreateVersionProvider($input: CreateVersionProviderInput!) {
                createVersionProvider(input: $input) {
                    success
                    data {
                        id
                        provider
                        providerModelId
                        isPrimary
                        isActive
                    }
                    errors
                }
            }
        """

        # Fields with non-Optional defaults are always sent.
        input_data: Dict[str, Any] = {
            "versionId": version_id,
            "provider": provider,
            "providerModelId": provider_model_id,
            "isPrimary": is_primary,
            # API Configuration
            "apiHttpMethod": api_http_method,
            "apiTimeoutSeconds": api_timeout_seconds,
            "apiAuthType": api_auth_type,
            "apiAuthHeaderName": api_auth_header_name,
            "apiKeyPrefix": api_key_prefix,
            # HuggingFace Configuration
            "hfUsePipeline": hf_use_pipeline,
            "hfTrustRemoteCode": hf_trust_remote_code,
        }

        # Optional fields are sent only when truthy (None, "" and {} are skipped).
        optional_fields: Dict[str, Any] = {
            # API Configuration
            "apiEndpointUrl": api_endpoint_url,
            "apiKey": api_key,
            "apiHeaders": api_headers,
            "apiRequestTemplate": api_request_template,
            "apiResponsePath": api_response_path,
            # HuggingFace Configuration
            "hfModelClass": hf_model_class,
            "hfAuthToken": hf_auth_token,
            "hfAttnImplementation": hf_attn_implementation,
            "hfTorchDtype": hf_torch_dtype,
            "hfDeviceMap": hf_device_map,
            "framework": framework,
            "config": config,
        }
        input_data.update({key: value for key, value in optional_fields.items() if value})

        data = self._execute_graphql(mutation, {"input": input_data})
        result: Dict[str, Any] = data.get("createVersionProvider") or {}
        return result

    def update_provider(
        self,
        provider_id: int,
        provider_model_id: Optional[str] = None,
        is_primary: Optional[bool] = None,
        # API Configuration
        api_endpoint_url: Optional[str] = None,
        api_http_method: Optional[str] = None,
        api_timeout_seconds: Optional[int] = None,
        api_auth_type: Optional[str] = None,
        api_auth_header_name: Optional[str] = None,
        api_key: Optional[str] = None,
        api_key_prefix: Optional[str] = None,
        api_headers: Optional[Dict[str, str]] = None,
        api_request_template: Optional[Dict[str, Any]] = None,
        api_response_path: Optional[str] = None,
        # HuggingFace Configuration
        hf_use_pipeline: Optional[bool] = None,
        hf_model_class: Optional[str] = None,
        hf_auth_token: Optional[str] = None,
        hf_attn_implementation: Optional[str] = None,
        hf_trust_remote_code: Optional[bool] = None,
        hf_torch_dtype: Optional[str] = None,
        hf_device_map: Optional[str] = None,
        framework: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Update a provider; only arguments that are not None are sent.

        Args:
            provider_id: ID of the provider to update
            provider_model_id: New model ID at the provider
            is_primary: Whether this is the primary provider
            api_endpoint_url: Full URL for the API endpoint
            api_http_method: HTTP method (POST, GET)
            api_timeout_seconds: Request timeout in seconds
            api_auth_type: Authentication type (BEARER, API_KEY, BASIC, OAUTH2, CUSTOM, NONE)
            api_auth_header_name: Header name for authentication
            api_key: API key or token
            api_key_prefix: Prefix for the API key (e.g., "Bearer")
            api_headers: Additional headers as dict
            api_request_template: Request body template as dict
            api_response_path: JSON path to extract response text
            hf_use_pipeline: For HuggingFace - whether to use pipeline API
            hf_model_class: For HuggingFace - model class
            hf_auth_token: For HuggingFace - auth token
            hf_attn_implementation: For HuggingFace - attention implementation
            hf_trust_remote_code: For HuggingFace - trust remote code
            hf_torch_dtype: For HuggingFace - torch dtype
            hf_device_map: For HuggingFace - device map
            framework: Framework (pt, tf)
            config: Additional configuration

        Returns:
            Dictionary with the mutation payload (``success``, ``data``,
            ``errors``); empty if the server returned none

        Raises:
            DataSpaceAPIError: If the GraphQL response contains top-level errors
        """
        mutation = """
            mutation UpdateVersionProvider($input: UpdateVersionProviderInput!) {
                updateVersionProvider(input: $input) {
                    success
                    data {
                        id
                        provider
                        providerModelId
                        isPrimary
                        isActive
                    }
                    errors
                }
            }
        """

        input_data: Dict[str, Any] = {"id": provider_id}

        # `is not None` (not truthiness) so False / "" / {} can be set explicitly.
        optional_fields: Dict[str, Any] = {
            "providerModelId": provider_model_id,
            "isPrimary": is_primary,
            # API Configuration
            "apiEndpointUrl": api_endpoint_url,
            "apiHttpMethod": api_http_method,
            "apiTimeoutSeconds": api_timeout_seconds,
            "apiAuthType": api_auth_type,
            "apiAuthHeaderName": api_auth_header_name,
            "apiKey": api_key,
            "apiKeyPrefix": api_key_prefix,
            "apiHeaders": api_headers,
            "apiRequestTemplate": api_request_template,
            "apiResponsePath": api_response_path,
            # HuggingFace Configuration
            "hfUsePipeline": hf_use_pipeline,
            "hfModelClass": hf_model_class,
            "hfAuthToken": hf_auth_token,
            "hfAttnImplementation": hf_attn_implementation,
            "hfTrustRemoteCode": hf_trust_remote_code,
            "hfTorchDtype": hf_torch_dtype,
            "hfDeviceMap": hf_device_map,
            "framework": framework,
            "config": config,
        }
        input_data.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )

        data = self._execute_graphql(mutation, {"input": input_data})
        result: Dict[str, Any] = data.get("updateVersionProvider") or {}
        return result

    def delete_provider(self, provider_id: int) -> Dict[str, Any]:
        """
        Delete a provider.

        Args:
            provider_id: ID of the provider to delete

        Returns:
            Dictionary with the mutation payload (``success``, ``errors``);
            empty if the server returned none

        Raises:
            DataSpaceAPIError: If the GraphQL response contains top-level errors
        """
        mutation = """
            mutation DeleteVersionProvider($providerId: Int!) {
                deleteVersionProvider(providerId: $providerId) {
                    success
                    errors
                }
            }
        """

        data = self._execute_graphql(mutation, {"providerId": provider_id})
        result: Dict[str, Any] = data.get("deleteVersionProvider") or {}
        return result

    # ==================== Helper Methods ====================

    def get_primary_version(self, model_id: int) -> Optional[Dict[str, Any]]:
        """
        Get the primary (latest) version for an AI model.

        Args:
            model_id: ID of the AI model

        Returns:
            The first version flagged ``isLatest``; otherwise the first
            version returned; or None if no versions exist
        """
        versions = self.get_versions(model_id)
        for version in versions:
            if version.get("isLatest"):
                return version
        return versions[0] if versions else None

    def get_primary_provider(self, version_id: int) -> Optional[Dict[str, Any]]:
        """
        Get the primary provider for a version.

        Args:
            version_id: ID of the version

        Returns:
            The first provider flagged ``isPrimary``; otherwise the first
            provider returned; or None if no providers exist
        """
        providers = self.get_version_providers(version_id)
        for provider in providers:
            if provider.get("isPrimary"):
                return provider
        return providers[0] if providers else None