adaptive-sdk 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_sdk/graphql_client/__init__.py +6 -3
- adaptive_sdk/graphql_client/async_client.py +50 -26
- adaptive_sdk/graphql_client/client.py +50 -26
- adaptive_sdk/graphql_client/create_dataset_from_multipart_upload.py +18 -0
- adaptive_sdk/graphql_client/custom_fields.py +24 -2
- adaptive_sdk/graphql_client/custom_queries.py +14 -2
- adaptive_sdk/graphql_client/custom_typing_fields.py +7 -0
- adaptive_sdk/graphql_client/dataset_upload_processing_status.py +18 -0
- adaptive_sdk/graphql_client/enums.py +1 -4
- adaptive_sdk/graphql_client/fragments.py +12 -0
- adaptive_sdk/graphql_client/input_types.py +4 -13
- adaptive_sdk/graphql_client/resize_inference_partition.py +6 -0
- adaptive_sdk/resources/compute_pools.py +63 -2
- adaptive_sdk/resources/datasets.py +82 -122
- adaptive_sdk/resources/models.py +0 -8
- {adaptive_sdk-0.1.8.dist-info → adaptive_sdk-0.1.10.dist-info}/METADATA +2 -2
- {adaptive_sdk-0.1.8.dist-info → adaptive_sdk-0.1.10.dist-info}/RECORD +18 -15
- {adaptive_sdk-0.1.8.dist-info → adaptive_sdk-0.1.10.dist-info}/WHEEL +0 -0
adaptive_sdk/graphql_client/fragments.py
CHANGED

```diff
@@ -240,6 +240,7 @@ class ModelData(BaseModel):
     key: str
     name: str
     online: ModelOnline
+    error: Optional[str]
     is_external: bool = Field(alias='isExternal')
     provider_name: ProviderName = Field(alias='providerName')
     is_adapter: bool = Field(alias='isAdapter')
```
```diff
@@ -340,6 +341,7 @@ class HarmonyGroupData(BaseModel):
     gpu_types: str = Field(alias='gpuTypes')
     created_at: int = Field(alias='createdAt')
     online_models: List['HarmonyGroupDataOnlineModels'] = Field(alias='onlineModels')
+    gpu_allocations: Optional[List['HarmonyGroupDataGpuAllocations']] = Field(alias='gpuAllocations')

 class HarmonyGroupDataComputePool(BaseModel):
     """@public"""
```
```diff
@@ -350,6 +352,15 @@ class HarmonyGroupDataOnlineModels(ModelData):
     """@public"""
     pass

+class HarmonyGroupDataGpuAllocations(BaseModel):
+    """@public"""
+    name: str
+    num_gpus: int = Field(alias='numGpus')
+    ranks: List[int]
+    created_at: int = Field(alias='createdAt')
+    user_name: Optional[str] = Field(alias='userName')
+    job_id: str = Field(alias='jobId')
+
 class JobData(BaseModel):
     """@public"""
     id: Any
```
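The new `HarmonyGroupDataGpuAllocations` fragment means harmony groups returned on top of the `HarmonyGroupData` fragment can now carry per-job GPU allocation details. A minimal sketch of reading them, assuming the client object exposes the compute-pools resource as `compute_pools` and that the connection arguments are placeholders (neither appears in this diff; the list-return change is shown in compute_pools.py below):

```python
from adaptive_sdk.client import Adaptive  # Adaptive is referenced elsewhere in this diff

client = Adaptive("https://my-adaptive-instance.example")  # hypothetical connection arguments

# As of this release, compute_pools.list() returns the compute pools directly.
for pool in client.compute_pools.list():
    for hg in pool.harmony_groups:
        for alloc in hg.gpu_allocations or []:  # new optional field; may be None
            print(pool.key, hg.key, alloc.name, alloc.num_gpus, alloc.ranks, alloc.job_id)
```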
```diff
@@ -566,6 +577,7 @@ class ModelDataAdmin(BaseModel):
     key: str
     name: str
     online: ModelOnline
+    error: Optional[str]
     use_cases: List['ModelDataAdminUseCases'] = Field(alias='useCases')
     is_external: bool = Field(alias='isExternal')
     provider_name: ProviderName = Field(alias='providerName')
```
adaptive_sdk/graphql_client/input_types.py
CHANGED

```diff
@@ -65,12 +65,6 @@ class AttachModel(BaseModel):
     wait: bool = False
     'Wait for the model to be deployed or not'

-class AzureProviderDataInput(BaseModel):
-    """@private"""
-    api_key: str = Field(alias='apiKey')
-    external_model_id: str = Field(alias='externalModelId')
-    endpoint: str
-
 class CancelAllocationInput(BaseModel):
     """@private"""
     harmony_group: str = Field(alias='harmonyGroup')
```
```diff
@@ -381,11 +375,9 @@ class ModelPlacementInput(BaseModel):

 class ModelProviderDataInput(BaseModel):
     """@private"""
-    azure: Optional['AzureProviderDataInput'] = None
     open_ai: Optional['OpenAIProviderDataInput'] = Field(alias='openAI', default=None)
     google: Optional['GoogleProviderDataInput'] = None
     anthropic: Optional['AnthropicProviderDataInput'] = None
-    nvidia: Optional['NvidiaProviderDataInput'] = None

 class ModelServiceDisconnect(BaseModel):
     """@private"""
```
```diff
@@ -397,11 +389,6 @@ class ModelServiceFilter(BaseModel):
     model: Optional[str] = None
     capabilities: Optional['CapabilityFilter'] = None

-class NvidiaProviderDataInput(BaseModel):
-    """@private"""
-    external_model_id: str = Field(alias='externalModelId')
-    endpoint: str
-
 class OpenAIProviderDataInput(BaseModel):
     """@private"""
     api_key: str = Field(alias='apiKey')
```
```diff
@@ -451,6 +438,10 @@ class SampleConfig(BaseModel):
     selection_type: SelectionTypeInput = Field(alias='selectionType')
     sample_size: Optional[int] = Field(alias='sampleSize', default=None)

+class SearchInput(BaseModel):
+    """@private"""
+    query: str
+
 class SystemPromptTemplateCreate(BaseModel):
     """@private"""
     name: str
```
adaptive_sdk/resources/compute_pools.py
CHANGED

```diff
@@ -2,6 +2,9 @@

 from __future__ import annotations
 from typing import TYPE_CHECKING
+from pydantic import BaseModel
+
+from adaptive_sdk.graphql_client import ResizePartitionInput, HarmonyStatus

 from .base_resource import SyncAPIResource, AsyncAPIResource, UseCaseResource

```
```diff
@@ -9,6 +12,12 @@ if TYPE_CHECKING:
     from adaptive_sdk.client import Adaptive, AsyncAdaptive


+class ResizeResult(BaseModel):
+    harmony_group_key: str
+    success: bool
+    error: str | None = None
+
+
 class ComputePools(SyncAPIResource, UseCaseResource):
     """
     Resource to interact with compute pools.
```
```diff
@@ -19,7 +28,33 @@ class ComputePools(SyncAPIResource, UseCaseResource):
         UseCaseResource.__init__(self, client)

     def list(self):
-        return self._gql_client.list_compute_pools()
+        return self._gql_client.list_compute_pools().compute_pools
+
+    def resize_inference_partition(self, compute_pool_key: str, size: int) -> list[ResizeResult]:
+        """
+        Resize the inference partitions of all harmony groups in a compute pool.
+        """
+        cps = self.list()
+        found_cp = False
+        for cp in cps.compute_pools:
+            if cp.key == compute_pool_key:
+                selected_cp = cp
+                found_cp = True
+                break
+        if not found_cp:
+            raise ValueError(f"Compute pool with key {compute_pool_key} not found")
+
+        resize_results: list[ResizeResult] = []
+        for hg in selected_cp.harmony_groups:
+            if hg.status == HarmonyStatus.ONLINE:
+                input = ResizePartitionInput(harmonyGroup=hg.key, size=size)
+                try:
+                    _ = self._gql_client.resize_inference_partition(input)
+                    resize_results.append(ResizeResult(harmony_group_key=hg.key, success=True))
+                except Exception as e:
+                    resize_results.append(ResizeResult(harmony_group_key=hg.key, success=False, error_message=str(e)))
+
+        return resize_results


 class AsyncComputePools(AsyncAPIResource, UseCaseResource):
```
```diff
@@ -32,4 +67,30 @@ class AsyncComputePools(AsyncAPIResource, UseCaseResource):
         UseCaseResource.__init__(self, client)

     async def list(self):
-        return await self._gql_client.list_compute_pools()
+        return (await self._gql_client.list_compute_pools()).compute_pools
+
+    async def resize_inference_partition(self, compute_pool_key: str, size: int) -> list[ResizeResult]:
+        """
+        Resize the inference partitions of all harmony groups in a compute pool.
+        """
+        cps = await self.list()
+        found_cp = False
+        for cp in cps.compute_pools:
+            if cp.key == compute_pool_key:
+                selected_cp = cp
+                found_cp = True
+                break
+        if not found_cp:
+            raise ValueError(f"Compute pool with key {compute_pool_key} not found")
+
+        resize_results: list[ResizeResult] = []
+        for hg in selected_cp.harmony_groups:
+            if hg.status == HarmonyStatus.ONLINE:
+                input = ResizePartitionInput(harmonyGroup=hg.key, size=size)
+                try:
+                    _ = await self._gql_client.resize_inference_partition(input)
+                    resize_results.append(ResizeResult(harmony_group_key=hg.key, success=True))
+                except Exception as e:
+                    resize_results.append(ResizeResult(harmony_group_key=hg.key, success=False, error_message=str(e)))
+
+        return resize_results
```
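A sketch of how the new `resize_inference_partition` helper could be called from user code, assuming the sync client exposes this resource as `compute_pools` and that the connection arguments are placeholders (neither is shown in this diff). The per-group `ResizeResult` list lets callers spot partial failures:

```python
from adaptive_sdk.client import Adaptive  # referenced in the TYPE_CHECKING import above

client = Adaptive("https://my-adaptive-instance.example")  # hypothetical connection arguments

# Attempts to resize the inference partition of every ONLINE harmony group in the pool.
results = client.compute_pools.resize_inference_partition(compute_pool_key="default", size=4)
for result in results:
    outcome = "resized" if result.success else f"failed: {result.error}"
    print(f"{result.harmony_group_key}: {outcome}")
```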
adaptive_sdk/resources/datasets.py
CHANGED

```diff
@@ -2,22 +2,20 @@ from __future__ import annotations
 import json
 import math
 import os
+import time
 from pathlib import Path
-from typing import List,
+from typing import List, TYPE_CHECKING

 from adaptive_sdk.graphql_client import (
     DatasetCreate,
     Upload,
-    LoadDatasetCreateDataset,
     ListDatasetsDatasets,
     DatasetData,
     DatasetCreateFromMultipartUpload,
+    DatasetUploadProcessingStatusInput,
+    SessionStatus,
 )
-
-from adaptive_sdk.graphql_client.custom_fields import (
-    DatasetFields,
-    DatasetUploadProcessingStatusFields,
-)
+
 from adaptive_sdk.rest import rest_types
 from adaptive_sdk.error_handling import rest_error_handler

```
```diff
@@ -29,23 +27,12 @@ if TYPE_CHECKING:
 MIN_CHUNK_SIZE_BYTES = 5 * 1024 * 1024  # 5MB
 MAX_CHUNK_SIZE_BYTES = 100 * 1024 * 1024  # 100MB
 MAX_PARTS_COUNT = 10000
+INIT_CHUNKED_UPLOAD_ROUTE = "/upload/init"
+UPLOAD_PART_ROUTE = "/upload/part"
+ABORT_CHUNKED_UPLOAD_ROUTE = "/upload/abort"


 def _calculate_upload_parts(file_size: int) -> tuple[int, int]:
-    """
-    Calculate optimal number of parts and chunk size for multipart upload.
-
-    Strategy: Scale chunk size based on file size for optimal performance
-
-    Args:
-        file_size: Size of the file in bytes
-
-    Returns:
-        Tuple of (total_parts, chunk_size_bytes)
-
-    Raises:
-        ValueError: If file is too large to upload with the given constraints
-    """
     if file_size < MIN_CHUNK_SIZE_BYTES:
         raise ValueError(f"File size ({file_size:,} bytes) is too small for chunked upload")

```
|
|
|
97
84
|
dataset_key: str,
|
|
98
85
|
name: str | None = None,
|
|
99
86
|
use_case: str | None = None,
|
|
100
|
-
) ->
|
|
87
|
+
) -> DatasetData:
|
|
101
88
|
"""
|
|
102
|
-
|
|
89
|
+
Upload a dataset from a file. File must be jsonl, where each line should match supported structure.
|
|
103
90
|
|
|
104
91
|
Args:
|
|
105
92
|
file_path: Path to jsonl file.
|
|
106
93
|
dataset_key: New dataset key.
|
|
107
94
|
name: Optional name to render in UI; if `None`, defaults to same as `dataset_key`.
|
|
108
95
|
|
|
109
|
-
Example:
|
|
110
|
-
```
|
|
111
|
-
{"messages": [{"role": "system", "content": "<optional system prompt>"}, {"role": "user", "content": "<user content>"}, {"role": "assistant", "content": "<assistant answer>"}], "completion": "hey"}
|
|
112
|
-
```
|
|
113
96
|
"""
|
|
114
97
|
file_size = os.path.getsize(file_path)
|
|
115
98
|
|
|
@@ -124,12 +107,8 @@ class Datasets(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
124
107
|
)
|
|
125
108
|
filename = Path(file_path).stem
|
|
126
109
|
with open(file_path, "rb") as f:
|
|
127
|
-
file_upload = Upload(
|
|
128
|
-
|
|
129
|
-
)
|
|
130
|
-
return self._gql_client.load_dataset(
|
|
131
|
-
input=input, file=file_upload
|
|
132
|
-
).create_dataset
|
|
110
|
+
file_upload = Upload(filename=filename, content=f, content_type="application/jsonl")
|
|
111
|
+
return self._gql_client.load_dataset(input=input, file=file_upload).create_dataset
|
|
133
112
|
|
|
134
113
|
def _chunked_upload(
|
|
135
114
|
self,
|
|
@@ -137,78 +116,73 @@ class Datasets(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
137
116
|
dataset_key: str,
|
|
138
117
|
name: str | None = None,
|
|
139
118
|
use_case: str | None = None,
|
|
140
|
-
) ->
|
|
119
|
+
) -> DatasetData:
|
|
141
120
|
"""Upload large files using chunked upload via REST API."""
|
|
142
121
|
file_size = os.path.getsize(file_path)
|
|
143
122
|
total_parts, chunk_size = _calculate_upload_parts(file_size)
|
|
144
123
|
|
|
145
|
-
# Step 1: Initialize chunked upload session
|
|
146
124
|
init_request = rest_types.InitChunkedUploadRequest(
|
|
147
125
|
content_type="application/jsonl",
|
|
148
126
|
metadata=None,
|
|
149
127
|
total_parts_count=total_parts,
|
|
150
128
|
)
|
|
151
|
-
response = self._rest_client.post(
|
|
152
|
-
"/upload/init", json=init_request.model_dump()
|
|
153
|
-
)
|
|
129
|
+
response = self._rest_client.post(INIT_CHUNKED_UPLOAD_ROUTE, json=init_request.model_dump())
|
|
154
130
|
rest_error_handler(response)
|
|
155
|
-
init_response = rest_types.InitChunkedUploadResponse.model_validate(
|
|
156
|
-
response.json()
|
|
157
|
-
)
|
|
131
|
+
init_response = rest_types.InitChunkedUploadResponse.model_validate(response.json())
|
|
158
132
|
session_id = init_response.session_id
|
|
159
133
|
|
|
160
134
|
try:
|
|
161
|
-
# Step 2: Upload each part
|
|
162
135
|
with open(file_path, "rb") as f:
|
|
163
136
|
for part_number in range(1, total_parts + 1):
|
|
164
137
|
chunk_data = f.read(chunk_size)
|
|
165
138
|
|
|
166
139
|
response = self._rest_client.post(
|
|
167
|
-
|
|
140
|
+
UPLOAD_PART_ROUTE,
|
|
168
141
|
params={"session_id": session_id, "part_number": part_number},
|
|
169
142
|
content=chunk_data,
|
|
170
143
|
headers={"Content-Type": "application/octet-stream"},
|
|
171
144
|
)
|
|
172
145
|
rest_error_handler(response)
|
|
173
146
|
|
|
174
|
-
# Step 3: Finalize upload by creating dataset from multipart upload
|
|
175
147
|
input = DatasetCreateFromMultipartUpload(
|
|
176
148
|
useCase=self.use_case_key(use_case),
|
|
177
149
|
name=name if name else dataset_key,
|
|
178
150
|
key=dataset_key,
|
|
179
151
|
uploadSessionId=session_id,
|
|
180
152
|
)
|
|
153
|
+
create_dataset_result = self._gql_client.create_dataset_from_multipart_upload(
|
|
154
|
+
input=input
|
|
155
|
+
).create_dataset_from_multipart_upload
|
|
156
|
+
|
|
157
|
+
upload_done = False
|
|
158
|
+
while not upload_done:
|
|
159
|
+
check_progress_result = self._gql_client.dataset_upload_processing_status(
|
|
160
|
+
input=DatasetUploadProcessingStatusInput(
|
|
161
|
+
useCase=self.use_case_key(use_case), datasetId=create_dataset_result.dataset_id
|
|
162
|
+
)
|
|
163
|
+
).dataset_upload_processing_status
|
|
164
|
+
if check_progress_result.status == SessionStatus.DONE:
|
|
165
|
+
upload_done = True
|
|
166
|
+
elif check_progress_result.status == SessionStatus.ERROR:
|
|
167
|
+
raise Exception(f"Upload failed: {check_progress_result.error}")
|
|
168
|
+
else:
|
|
169
|
+
time.sleep(2)
|
|
181
170
|
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
DatasetUploadProcessingStatusFields.processed_parts,
|
|
187
|
-
DatasetUploadProcessingStatusFields.progress,
|
|
188
|
-
DatasetUploadProcessingStatusFields.error,
|
|
189
|
-
DatasetUploadProcessingStatusFields.status,
|
|
190
|
-
)
|
|
191
|
-
result = self._gql_client.mutation(
|
|
192
|
-
mutation_field, operation_name="CreateDatasetFromMultipartUpload"
|
|
193
|
-
)
|
|
194
|
-
return LoadDatasetCreateDataset.model_validate(
|
|
195
|
-
result["createDatasetFromMultipartUpload"]
|
|
196
|
-
)
|
|
171
|
+
dataset_data = self.get(create_dataset_result.dataset_id, use_case=self.use_case_key(use_case))
|
|
172
|
+
assert dataset_data is not None
|
|
173
|
+
|
|
174
|
+
return dataset_data
|
|
197
175
|
|
|
198
|
-
except Exception
|
|
199
|
-
# Abort the upload session on error
|
|
176
|
+
except Exception:
|
|
200
177
|
try:
|
|
201
|
-
abort_request = rest_types.AbortChunkedUploadRequest(
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
self._rest_client.delete( # type: ignore[call-arg]
|
|
206
|
-
"/upload/abort",
|
|
207
|
-
content=json.dumps(abort_request.model_dump()),
|
|
178
|
+
abort_request = rest_types.AbortChunkedUploadRequest(session_id=session_id)
|
|
179
|
+
self._rest_client.delete(
|
|
180
|
+
ABORT_CHUNKED_UPLOAD_ROUTE,
|
|
181
|
+
content=json.dumps(abort_request.model_dump()), # type: ignore[call-arg]
|
|
208
182
|
headers={"Content-Type": "application/json"},
|
|
209
183
|
)
|
|
210
184
|
except Exception:
|
|
211
|
-
pass
|
|
185
|
+
pass
|
|
212
186
|
raise
|
|
213
187
|
|
|
214
188
|
def list(
|
|
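With `create_dataset_from_multipart_upload` now returning immediately, both chunked-upload paths poll `dataset_upload_processing_status` until the session reports `DONE` or `ERROR`. A self-contained sketch of that polling step, with a timeout added for illustration (the helper and its timeout are not part of the SDK; only the GraphQL calls and types come from this diff):

```python
import time

from adaptive_sdk.graphql_client import DatasetUploadProcessingStatusInput, SessionStatus


def wait_for_dataset_processing(gql_client, use_case: str, dataset_id: str, timeout_s: float = 600.0):
    """Poll the upload processing status until DONE, raising on ERROR or timeout (illustrative helper)."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        status = gql_client.dataset_upload_processing_status(
            input=DatasetUploadProcessingStatusInput(useCase=use_case, datasetId=dataset_id)
        ).dataset_upload_processing_status
        if status.status == SessionStatus.DONE:
            return status
        if status.status == SessionStatus.ERROR:
            raise RuntimeError(f"Upload failed: {status.error}")
        time.sleep(2)
    raise TimeoutError(f"Dataset {dataset_id} was not processed within {timeout_s:.0f} seconds")
```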
```diff
@@ -227,15 +201,11 @@ class Datasets(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
         Args:
             key: Dataset key.
         """
-        return self._gql_client.describe_dataset(
-            key, self.use_case_key(use_case)
-        ).dataset
+        return self._gql_client.describe_dataset(key, self.use_case_key(use_case)).dataset

     def delete(self, key: str, use_case: str | None = None) -> bool:
         """Delete dataset."""
-        return self._gql_client.delete_dataset(
-            id_or_key=key, use_case=self.use_case_key(use_case)
-        ).delete_dataset
+        return self._gql_client.delete_dataset(id_or_key=key, use_case=self.use_case_key(use_case)).delete_dataset


 class AsyncDatasets(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
```
|
|
|
249
219
|
dataset_key: str,
|
|
250
220
|
name: str | None = None,
|
|
251
221
|
use_case: str | None = None,
|
|
252
|
-
) ->
|
|
222
|
+
) -> DatasetData:
|
|
253
223
|
"""
|
|
254
224
|
Upload a dataset from a file. File must be jsonl, where each line should match structure in example below.
|
|
255
225
|
|
|
@@ -277,12 +247,8 @@ class AsyncDatasets(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
277
247
|
)
|
|
278
248
|
filename = Path(file_path).stem
|
|
279
249
|
with open(file_path, "rb") as f:
|
|
280
|
-
file_upload = Upload(
|
|
281
|
-
|
|
282
|
-
)
|
|
283
|
-
upload_result = await self._gql_client.load_dataset(
|
|
284
|
-
input=input, file=file_upload
|
|
285
|
-
)
|
|
250
|
+
file_upload = Upload(filename=filename, content=f, content_type="application/jsonl")
|
|
251
|
+
upload_result = await self._gql_client.load_dataset(input=input, file=file_upload)
|
|
286
252
|
return upload_result.create_dataset
|
|
287
253
|
|
|
288
254
|
async def _chunked_upload(
|
|
```diff
@@ -291,34 +257,28 @@ class AsyncDatasets(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
         dataset_key: str,
         name: str | None = None,
         use_case: str | None = None,
-    ) ->
+    ) -> DatasetData:
         """Upload large files using chunked upload via REST API."""
         file_size = os.path.getsize(file_path)
         total_parts, chunk_size = _calculate_upload_parts(file_size)

-        # Step 1: Initialize chunked upload session
         init_request = rest_types.InitChunkedUploadRequest(
             content_type="application/jsonl",
             metadata=None,
             total_parts_count=total_parts,
         )
-        response = await self._rest_client.post(
-            "/upload/init", json=init_request.model_dump()
-        )
+        response = await self._rest_client.post(INIT_CHUNKED_UPLOAD_ROUTE, json=init_request.model_dump())
         rest_error_handler(response)
-        init_response = rest_types.InitChunkedUploadResponse.model_validate(
-            response.json()
-        )
+        init_response = rest_types.InitChunkedUploadResponse.model_validate(response.json())
         session_id = init_response.session_id

         try:
-            # Step 2: Upload each part
             with open(file_path, "rb") as f:
                 for part_number in range(1, total_parts + 1):
                     chunk_data = f.read(chunk_size)

                     response = await self._rest_client.post(
-
+                        UPLOAD_PART_ROUTE,
                         params={"session_id": session_id, "part_number": part_number},
                         content=chunk_data,
                         headers={"Content-Type": "application/octet-stream"},
```
```diff
@@ -331,33 +291,37 @@ class AsyncDatasets(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
                 key=dataset_key,
                 uploadSessionId=session_id,
             )
+            create_dataset_result = (
+                await self._gql_client.create_dataset_from_multipart_upload(input=input)
+            ).create_dataset_from_multipart_upload
+
+            upload_done = False
+            while not upload_done:
+                check_progress_result = (
+                    await self._gql_client.dataset_upload_processing_status(
+                        input=DatasetUploadProcessingStatusInput(
+                            useCase=self.use_case_key(use_case), datasetId=create_dataset_result.dataset_id
+                        )
+                    )
+                ).dataset_upload_processing_status
+                if check_progress_result.status == SessionStatus.DONE:
+                    upload_done = True
+                elif check_progress_result.status == SessionStatus.ERROR:
+                    raise Exception(f"Upload failed: {check_progress_result.error}")
+                else:
+                    time.sleep(2)

-
-
-                DatasetUploadProcessingStatusFields.dataset_id,
-                DatasetUploadProcessingStatusFields.total_parts,
-                DatasetUploadProcessingStatusFields.processed_parts,
-                DatasetUploadProcessingStatusFields.progress,
-                DatasetUploadProcessingStatusFields.error,
-                DatasetUploadProcessingStatusFields.status,
-            )
-            result = await self._gql_client.mutation(
-                mutation_field, operation_name="CreateDatasetFromMultipartUpload"
-            )
-            return LoadDatasetCreateDataset.model_validate(
-                result["createDatasetFromMultipartUpload"]
-            )
+            dataset_data = await self.get(create_dataset_result.dataset_id, use_case=self.use_case_key(use_case))
+            assert dataset_data is not None

-
-
+            return dataset_data
+
+        except Exception:
             try:
-                abort_request = rest_types.AbortChunkedUploadRequest(
-
-
-
-                await self._rest_client.delete(  # type: ignore[call-arg]
-                    "/upload/abort",
-                    content=json.dumps(abort_request.model_dump()),
+                abort_request = rest_types.AbortChunkedUploadRequest(session_id=session_id)
+                _ = await self._rest_client.delete(
+                    ABORT_CHUNKED_UPLOAD_ROUTE,
+                    content=json.dumps(abort_request.model_dump()),  # type: ignore[call-arg]
                     headers={"Content-Type": "application/json"},
                 )
             except Exception:
```
```diff
@@ -381,15 +345,11 @@ class AsyncDatasets(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
         Args:
             key: Dataset key.
         """
-        result = await self._gql_client.describe_dataset(
-            key, self.use_case_key(use_case)
-        )
+        result = await self._gql_client.describe_dataset(key, self.use_case_key(use_case))
         return result.dataset

     async def delete(self, key: str, use_case: str | None = None) -> bool:
         """Delete dataset."""
         return (
-            await self._gql_client.delete_dataset(
-                id_or_key=key, use_case=self.use_case_key(use_case)
-            )
+            await self._gql_client.delete_dataset(id_or_key=key, use_case=self.use_case_key(use_case))
         ).delete_dataset
```
adaptive_sdk/resources/models.py
CHANGED

```diff
@@ -5,7 +5,6 @@ from adaptive_sdk.graphql_client import (
     OpenAIModel,
     OpenAIProviderDataInput,
     GoogleProviderDataInput,
-    AzureProviderDataInput,
     ModelProviderDataInput,
     AddExternalModelInput,
     ExternalModelProviderName,
```
```diff
@@ -40,13 +39,6 @@ provider_config = {
             google=GoogleProviderDataInput(apiKey=api_key, externalModelId=model_id)
         ),
     },
-    "azure": {
-        "provider_data": lambda api_key, model_id, endpoint: ModelProviderDataInput(
-            azure=AzureProviderDataInput(
-                apiKey=api_key, externalModelId=model_id, endpoint=endpoint
-            )
-        )
-    },
 }

 SupportedHFModels = Literal[
```
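With the Azure and Nvidia entries gone from `provider_config` (and their input types removed from input_types.py above), external models in 0.1.10 go through the remaining providers such as OpenAI, Google, and Anthropic. A sketch of building provider data for Google, mirroring the lambda kept in this hunk; the API key and model ID are placeholders:

```python
from adaptive_sdk.graphql_client import GoogleProviderDataInput, ModelProviderDataInput

provider_data = ModelProviderDataInput(
    google=GoogleProviderDataInput(apiKey="my-google-api-key", externalModelId="my-external-model-id")
)
```

Code that previously passed `azure=` or `nvidia=` to `ModelProviderDataInput` will no longer validate against 0.1.10.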
{adaptive_sdk-0.1.8.dist-info → adaptive_sdk-0.1.10.dist-info}/METADATA
CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adaptive-sdk
-Version: 0.1.8
+Version: 0.1.10
 Summary: Python SDK for Adaptive Engine
 Author-email: Vincent Debergue <vincent@adaptive-ml.com>, Joao Moura <joao@adaptive-ml.com>, Yacine Bouraoui <yacine@adaptive-ml.com>
 Requires-Python: >=3.10
```
```diff
@@ -10,7 +10,7 @@ Requires-Dist: httpx-sse >= 0.4.0
 Requires-Dist: gql >= 3.5.0
 Requires-Dist: pydantic[email] >= 2.9.0
 Requires-Dist: pyhumps >= 3.8.0
-Requires-Dist: fastapi >= 0.
+Requires-Dist: fastapi >= 0.121.1
 Requires-Dist: uvicorn >= 0.34.0
 Requires-Dist: jsonschema==4.24.0
 Requires-Dist: websockets==15.0.1
```