adaptive-sdk 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptive_sdk/graphql_client/__init__.py +8 -7
- adaptive_sdk/graphql_client/add_hf_model.py +7 -1
- adaptive_sdk/graphql_client/async_client.py +15 -7
- adaptive_sdk/graphql_client/client.py +15 -7
- adaptive_sdk/graphql_client/custom_fields.py +154 -39
- adaptive_sdk/graphql_client/custom_mutations.py +2 -2
- adaptive_sdk/graphql_client/custom_queries.py +16 -6
- adaptive_sdk/graphql_client/custom_typing_fields.py +42 -7
- adaptive_sdk/graphql_client/describe_job.py +12 -0
- adaptive_sdk/graphql_client/enums.py +11 -5
- adaptive_sdk/graphql_client/fragments.py +23 -23
- adaptive_sdk/graphql_client/get_custom_recipe.py +2 -1
- adaptive_sdk/graphql_client/input_types.py +10 -2
- adaptive_sdk/graphql_client/list_compute_pools.py +3 -3
- adaptive_sdk/graphql_client/list_harmony_groups.py +13 -0
- adaptive_sdk/resources/jobs.py +16 -14
- adaptive_sdk/resources/models.py +77 -33
- adaptive_sdk/resources/recipes.py +159 -96
- adaptive_sdk/rest/rest_types.py +61 -37
- {adaptive_sdk-0.1.2.dist-info → adaptive_sdk-0.1.3.dist-info}/METADATA +1 -1
- {adaptive_sdk-0.1.2.dist-info → adaptive_sdk-0.1.3.dist-info}/RECORD +22 -21
- adaptive_sdk/graphql_client/list_partitions.py +0 -12
- {adaptive_sdk-0.1.2.dist-info → adaptive_sdk-0.1.3.dist-info}/WHEEL +0 -0
adaptive_sdk/resources/models.py
CHANGED
|
@@ -12,6 +12,7 @@ from adaptive_sdk.graphql_client import (
|
|
|
12
12
|
AttachModel,
|
|
13
13
|
UpdateModelService,
|
|
14
14
|
ModelData,
|
|
15
|
+
JobData,
|
|
15
16
|
ModelServiceData,
|
|
16
17
|
ListModelsModels,
|
|
17
18
|
AddHFModelInput,
|
|
@@ -29,7 +30,9 @@ if TYPE_CHECKING:
|
|
|
29
30
|
provider_config = {
|
|
30
31
|
"open_ai": {
|
|
31
32
|
"provider_data": lambda api_key, model_id: ModelProviderDataInput(
|
|
32
|
-
openAI=OpenAIProviderDataInput(
|
|
33
|
+
openAI=OpenAIProviderDataInput(
|
|
34
|
+
apiKey=api_key, externalModelId=OpenAIModel(model_id)
|
|
35
|
+
)
|
|
33
36
|
),
|
|
34
37
|
},
|
|
35
38
|
"google": {
|
|
@@ -39,40 +42,53 @@ provider_config = {
|
|
|
39
42
|
},
|
|
40
43
|
"azure": {
|
|
41
44
|
"provider_data": lambda api_key, model_id, endpoint: ModelProviderDataInput(
|
|
42
|
-
azure=AzureProviderDataInput(
|
|
45
|
+
azure=AzureProviderDataInput(
|
|
46
|
+
apiKey=api_key, externalModelId=model_id, endpoint=endpoint
|
|
47
|
+
)
|
|
43
48
|
)
|
|
44
49
|
},
|
|
45
50
|
}
|
|
46
51
|
|
|
47
52
|
SupportedHFModels = Literal[
|
|
48
|
-
"
|
|
49
|
-
"
|
|
50
|
-
"
|
|
53
|
+
"deepseek-ai/deepseek-coder-1.3b-base",
|
|
54
|
+
"deepseek-ai/deepseek-coder-6.7b-base",
|
|
55
|
+
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
|
|
56
|
+
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
|
|
57
|
+
"tiiuae/falcon-7b",
|
|
58
|
+
"tiiuae/falcon-7b-instruct",
|
|
59
|
+
"tiiuae/falcon-40b",
|
|
60
|
+
"tiiuae/falcon-180B",
|
|
61
|
+
"BAAI/bge-multilingual-gemma2",
|
|
62
|
+
"Locutusque/TinyMistral-248M",
|
|
63
|
+
"mistralai/Mistral-Small-24B-Instruct-2501",
|
|
64
|
+
"baffo32/decapoda-research-llama-7B-hf",
|
|
65
|
+
"princeton-nlp/Sheared-LLaMA-1.3B",
|
|
66
|
+
"meta-llama/Llama-3.1-8B",
|
|
51
67
|
"meta-llama/Llama-3.1-8B-Instruct",
|
|
52
68
|
"meta-llama/Llama-3.1-70B-Instruct",
|
|
53
|
-
"meta-llama/Llama-3.2-1B-Instruct",
|
|
54
|
-
"meta-llama/Llama-3.2-3B-Instruct",
|
|
55
69
|
"meta-llama/Llama-3.3-70B-Instruct",
|
|
56
|
-
"
|
|
70
|
+
"nvidia/Llama3-ChatQA-1.5-70B",
|
|
71
|
+
"Qwen/Qwen2.5-0.5B",
|
|
57
72
|
"Qwen/Qwen2.5-0.5B-Instruct",
|
|
58
|
-
"Qwen/Qwen2.5-
|
|
59
|
-
"Qwen/Qwen2.5-3B-Instruct",
|
|
60
|
-
"Qwen/Qwen2.5-7B-Instruct",
|
|
61
|
-
"Qwen/Qwen2.5-14B-Instruct",
|
|
62
|
-
"Qwen/Qwen2.5-32B-Instruct",
|
|
63
|
-
"Qwen/Qwen2.5-72B-Instruct",
|
|
64
|
-
"Qwen/Qwen2.5-Coder-0.5B-Instruct",
|
|
65
|
-
"Qwen/Qwen2.5-Coder-1.5B-Instruct",
|
|
66
|
-
"Qwen/Qwen2.5-Coder-3B-Instruct",
|
|
73
|
+
"Qwen/Qwen2.5-Coder-7B",
|
|
67
74
|
"Qwen/Qwen2.5-Coder-7B-Instruct",
|
|
75
|
+
"Qwen/Qwen2.5-Math-7B",
|
|
76
|
+
"Qwen/Qwen2.5-Math-7B-Instruct",
|
|
68
77
|
"Qwen/Qwen2.5-Coder-14B-Instruct",
|
|
69
78
|
"Qwen/Qwen2.5-Coder-32B-Instruct",
|
|
79
|
+
"Qwen/QwQ-32B",
|
|
80
|
+
"google/gemma-3-1b-it",
|
|
81
|
+
"google/gemma-3-4b-it",
|
|
82
|
+
"google/gemma-3-12b-it",
|
|
83
|
+
"google/gemma-3-27b-it",
|
|
70
84
|
"Qwen/Qwen3-0.6B",
|
|
71
85
|
"Qwen/Qwen3-1.7B",
|
|
72
86
|
"Qwen/Qwen3-4B",
|
|
73
87
|
"Qwen/Qwen3-8B",
|
|
74
88
|
"Qwen/Qwen3-14B",
|
|
75
89
|
"Qwen/Qwen3-32B",
|
|
90
|
+
"01-ai/Yi-34B",
|
|
91
|
+
"HuggingFaceH4/zephyr-7b-beta",
|
|
76
92
|
]
|
|
77
93
|
|
|
78
94
|
|
|
@@ -80,7 +96,9 @@ def is_supported_model(model_id: str):
|
|
|
80
96
|
supported_models = get_args(SupportedHFModels)
|
|
81
97
|
if model_id not in supported_models:
|
|
82
98
|
supported_models_str = "\n".join(supported_models)
|
|
83
|
-
raise ValueError(
|
|
99
|
+
raise ValueError(
|
|
100
|
+
f"Model {model_id} is not supported.\n\nChoose from:\n{supported_models_str}"
|
|
101
|
+
)
|
|
84
102
|
|
|
85
103
|
|
|
86
104
|
class Models(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
@@ -99,7 +117,7 @@ class Models(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
99
117
|
output_model_key: str,
|
|
100
118
|
hf_token: str,
|
|
101
119
|
compute_pool: str | None = None,
|
|
102
|
-
) ->
|
|
120
|
+
) -> JobData:
|
|
103
121
|
"""
|
|
104
122
|
Add model from the HuggingFace Model hub to Adaptive model registry.
|
|
105
123
|
It will take several minutes for the model to be downloaded and converted to Adaptive format.
|
|
@@ -145,16 +163,22 @@ class Models(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
145
163
|
provider_data = provider_data_fn(api_key, external_model_id)
|
|
146
164
|
case "azure":
|
|
147
165
|
if not endpoint:
|
|
148
|
-
raise ValueError(
|
|
166
|
+
raise ValueError(
|
|
167
|
+
"`endpoint` is required to connect Azure external model."
|
|
168
|
+
)
|
|
149
169
|
provider_data = provider_data_fn(api_key, external_model_id, endpoint)
|
|
150
170
|
case _:
|
|
151
171
|
raise ValueError(f"Provider {provider} is not supported")
|
|
152
172
|
|
|
153
173
|
provider_enum = ExternalModelProviderName(provider.upper())
|
|
154
|
-
input = AddExternalModelInput(
|
|
174
|
+
input = AddExternalModelInput(
|
|
175
|
+
name=name, provider=provider_enum, providerData=provider_data
|
|
176
|
+
)
|
|
155
177
|
return self._gql_client.add_external_model(input).add_external_model
|
|
156
178
|
|
|
157
|
-
def list(
|
|
179
|
+
def list(
|
|
180
|
+
self, filter: input_types.ModelFilter | None = None
|
|
181
|
+
) -> Sequence[ListModelsModels]:
|
|
158
182
|
"""
|
|
159
183
|
List all models in Adaptive model registry.
|
|
160
184
|
"""
|
|
@@ -192,7 +216,9 @@ class Models(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
192
216
|
useCase=self.use_case_key(use_case),
|
|
193
217
|
attached=True,
|
|
194
218
|
wait=wait,
|
|
195
|
-
placement=(
|
|
219
|
+
placement=(
|
|
220
|
+
ModelPlacementInput.model_validate(placement) if placement else None
|
|
221
|
+
),
|
|
196
222
|
)
|
|
197
223
|
result = self._gql_client.attach_model_to_use_case(input).attach_model
|
|
198
224
|
if make_default:
|
|
@@ -253,7 +279,9 @@ class Models(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
253
279
|
isDefault=is_default,
|
|
254
280
|
attached=attached,
|
|
255
281
|
desiredOnline=desired_online,
|
|
256
|
-
placement=(
|
|
282
|
+
placement=(
|
|
283
|
+
ModelPlacementInput.model_validate(placement) if placement else None
|
|
284
|
+
),
|
|
257
285
|
)
|
|
258
286
|
return self._gql_client.update_model(input).update_model_service
|
|
259
287
|
|
|
@@ -276,7 +304,9 @@ class Models(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
276
304
|
force: If model is attached to several use cases, `force` must equal `True` in order
|
|
277
305
|
for the model to be terminated.
|
|
278
306
|
"""
|
|
279
|
-
return self._gql_client.terminate_model(
|
|
307
|
+
return self._gql_client.terminate_model(
|
|
308
|
+
id_or_key=model, force=force
|
|
309
|
+
).terminate_model
|
|
280
310
|
|
|
281
311
|
|
|
282
312
|
class AsyncModels(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
@@ -295,7 +325,7 @@ class AsyncModels(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
295
325
|
output_model_key: str,
|
|
296
326
|
hf_token: str,
|
|
297
327
|
compute_pool: str | None = None,
|
|
298
|
-
):
|
|
328
|
+
) -> JobData:
|
|
299
329
|
"""
|
|
300
330
|
Add model from the HuggingFace Model hub to Adaptive model registry.
|
|
301
331
|
It will take several minutes for the model to be downloaded and converted to Adaptive format.
|
|
@@ -341,17 +371,23 @@ class AsyncModels(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
341
371
|
provider_data = provider_data_fn(api_key, external_model_id)
|
|
342
372
|
case "azure":
|
|
343
373
|
if not endpoint:
|
|
344
|
-
raise ValueError(
|
|
374
|
+
raise ValueError(
|
|
375
|
+
"`endpoint` is required to connect Azure external model."
|
|
376
|
+
)
|
|
345
377
|
provider_data = provider_data_fn(api_key, external_model_id, endpoint)
|
|
346
378
|
case _:
|
|
347
379
|
raise ValueError(f"Provider {provider} is not supported")
|
|
348
380
|
|
|
349
381
|
provider_enum = ExternalModelProviderName(provider.upper())
|
|
350
|
-
input = AddExternalModelInput(
|
|
382
|
+
input = AddExternalModelInput(
|
|
383
|
+
name=name, provider=provider_enum, providerData=provider_data
|
|
384
|
+
)
|
|
351
385
|
result = await self._gql_client.add_external_model(input)
|
|
352
386
|
return result.add_external_model
|
|
353
387
|
|
|
354
|
-
async def list(
|
|
388
|
+
async def list(
|
|
389
|
+
self, filter: input_types.ModelFilter | None = None
|
|
390
|
+
) -> Sequence[ListModelsModels]:
|
|
355
391
|
"""
|
|
356
392
|
List all models in Adaptive model registry.
|
|
357
393
|
"""
|
|
@@ -389,7 +425,9 @@ class AsyncModels(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
389
425
|
useCase=self.use_case_key(use_case),
|
|
390
426
|
attached=True,
|
|
391
427
|
wait=wait,
|
|
392
|
-
placement=(
|
|
428
|
+
placement=(
|
|
429
|
+
ModelPlacementInput.model_validate(placement) if placement else None
|
|
430
|
+
),
|
|
393
431
|
)
|
|
394
432
|
result = await self._gql_client.attach_model_to_use_case(input)
|
|
395
433
|
result = result.attach_model
|
|
@@ -453,7 +491,9 @@ class AsyncModels(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
453
491
|
isDefault=is_default,
|
|
454
492
|
attached=attached,
|
|
455
493
|
desiredOnline=desired_online,
|
|
456
|
-
placement=(
|
|
494
|
+
placement=(
|
|
495
|
+
ModelPlacementInput.model_validate(placement) if placement else None
|
|
496
|
+
),
|
|
457
497
|
)
|
|
458
498
|
result = await self._gql_client.update_model(input)
|
|
459
499
|
return result.update_model_service
|
|
@@ -466,7 +506,9 @@ class AsyncModels(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
466
506
|
model: Model key.
|
|
467
507
|
wait: If `True`, call block until model is in `Online` state.
|
|
468
508
|
"""
|
|
469
|
-
return (
|
|
509
|
+
return (
|
|
510
|
+
await self._gql_client.deploy_model(id_or_key=model, wait=wait)
|
|
511
|
+
).deploy_model
|
|
470
512
|
|
|
471
513
|
async def terminate(self, model: str, force: bool = False) -> str:
|
|
472
514
|
"""
|
|
@@ -477,4 +519,6 @@ class AsyncModels(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
477
519
|
force: If model is attached to several use cases, `force` must equal `True` in order
|
|
478
520
|
for the model to be terminated.
|
|
479
521
|
"""
|
|
480
|
-
return (
|
|
522
|
+
return (
|
|
523
|
+
await self._gql_client.terminate_model(id_or_key=model, force=force)
|
|
524
|
+
).terminate_model
|
|
@@ -1,4 +1,10 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
import os
|
|
3
|
+
import io
|
|
4
|
+
import zipfile
|
|
5
|
+
import mimetypes
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from loguru import logger
|
|
2
8
|
from hypothesis_jsonschema import from_schema
|
|
3
9
|
from typing import TYPE_CHECKING, Sequence, Any
|
|
4
10
|
from pathlib import Path
|
|
@@ -18,20 +24,6 @@ from adaptive_sdk.graphql_client import (
|
|
|
18
24
|
|
|
19
25
|
if TYPE_CHECKING:
|
|
20
26
|
from adaptive_sdk.client import Adaptive, AsyncAdaptive
|
|
21
|
-
import mimetypes
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def _count_keys_recursively(data: Any) -> int:
|
|
25
|
-
"""Recursively counts the total number of keys in dictionaries within the data."""
|
|
26
|
-
count = 0
|
|
27
|
-
if isinstance(data, dict):
|
|
28
|
-
count += len(data)
|
|
29
|
-
for value in data.values():
|
|
30
|
-
count += _count_keys_recursively(value)
|
|
31
|
-
elif isinstance(data, list):
|
|
32
|
-
for item in data:
|
|
33
|
-
count += _count_keys_recursively(item)
|
|
34
|
-
return count
|
|
35
27
|
|
|
36
28
|
|
|
37
29
|
class Recipes(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
@@ -49,52 +41,35 @@ class Recipes(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
49
41
|
|
|
50
42
|
def upload(
|
|
51
43
|
self,
|
|
52
|
-
|
|
44
|
+
path: str,
|
|
53
45
|
recipe_key: str,
|
|
54
46
|
name: str | None = None,
|
|
55
47
|
description: str | None = None,
|
|
56
48
|
labels: dict[str, str] | None = None,
|
|
57
49
|
use_case: str | None = None,
|
|
58
50
|
) -> CustomRecipeData:
|
|
59
|
-
|
|
51
|
+
"""
|
|
52
|
+
Upload a recipe from either a single Python file or a directory (path).
|
|
53
|
+
If a directory is provided, it must contain a 'main.py' and will be zipped in-memory before upload.
|
|
54
|
+
"""
|
|
55
|
+
inferred_name = name or recipe_key
|
|
60
56
|
label_inputs = [LabelInput(key=k, value=v) for k, v in labels.items()] if labels else None
|
|
61
57
|
input = CreateRecipeInput(
|
|
62
58
|
key=recipe_key,
|
|
63
|
-
name=
|
|
59
|
+
name=inferred_name,
|
|
64
60
|
description=description,
|
|
65
61
|
labels=label_inputs,
|
|
66
62
|
)
|
|
67
|
-
|
|
68
|
-
with open(file_path, "rb") as f:
|
|
69
|
-
file_upload = Upload(filename=filename, content=f, content_type=content_type)
|
|
63
|
+
with _upload_from_path(path) as file_upload:
|
|
70
64
|
return self._gql_client.create_custom_recipe(
|
|
71
65
|
use_case=self.use_case_key(use_case), input=input, file=file_upload
|
|
72
66
|
).create_custom_recipe
|
|
73
67
|
|
|
74
|
-
def run(
|
|
75
|
-
self,
|
|
76
|
-
recipe_key: str,
|
|
77
|
-
num_gpus: int,
|
|
78
|
-
input_args: dict | None = None,
|
|
79
|
-
name: str | None = None,
|
|
80
|
-
use_case: str | None = None,
|
|
81
|
-
compute_pool: str | None = None,
|
|
82
|
-
) -> JobData:
|
|
83
|
-
input = JobInput(
|
|
84
|
-
recipe=recipe_key,
|
|
85
|
-
useCase=self.use_case_key(use_case),
|
|
86
|
-
args=input_args or {},
|
|
87
|
-
name=name,
|
|
88
|
-
computePool=compute_pool,
|
|
89
|
-
numGpus=num_gpus,
|
|
90
|
-
)
|
|
91
|
-
return self._gql_client.create_job(input).create_job
|
|
92
|
-
|
|
93
68
|
def get(
|
|
94
69
|
self,
|
|
95
70
|
recipe_key: str,
|
|
96
71
|
use_case: str | None = None,
|
|
97
|
-
) -> CustomRecipeData:
|
|
72
|
+
) -> CustomRecipeData | None:
|
|
98
73
|
return self._gql_client.get_custom_recipe(
|
|
99
74
|
id_or_key=recipe_key, use_case=self.use_case_key(use_case)
|
|
100
75
|
).custom_recipe
|
|
@@ -102,7 +77,7 @@ class Recipes(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
102
77
|
def update(
|
|
103
78
|
self,
|
|
104
79
|
recipe_key: str,
|
|
105
|
-
|
|
80
|
+
path: str | None = None,
|
|
106
81
|
name: str | None = None,
|
|
107
82
|
description: str | None = None,
|
|
108
83
|
labels: Sequence[tuple[str, str]] | None = None,
|
|
@@ -115,19 +90,21 @@ class Recipes(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
115
90
|
labels=label_inputs,
|
|
116
91
|
)
|
|
117
92
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
93
|
+
if path:
|
|
94
|
+
with _upload_from_path(path) as file_upload:
|
|
95
|
+
return self._gql_client.update_custom_recipe(
|
|
96
|
+
use_case=self.use_case_key(use_case),
|
|
97
|
+
id=recipe_key,
|
|
98
|
+
input=input,
|
|
99
|
+
file=file_upload,
|
|
100
|
+
).update_custom_recipe
|
|
101
|
+
else:
|
|
102
|
+
return self._gql_client.update_custom_recipe(
|
|
103
|
+
use_case=self.use_case_key(use_case),
|
|
104
|
+
id=recipe_key,
|
|
105
|
+
input=input,
|
|
106
|
+
file=None,
|
|
107
|
+
).update_custom_recipe
|
|
131
108
|
|
|
132
109
|
def delete(
|
|
133
110
|
self,
|
|
@@ -140,6 +117,8 @@ class Recipes(SyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
140
117
|
|
|
141
118
|
def generate_sample_input(self, recipe_key: str, use_case: str | None = None) -> dict:
|
|
142
119
|
recipe_details = self.get(recipe_key=recipe_key, use_case=self.use_case_key(use_case))
|
|
120
|
+
if recipe_details is None:
|
|
121
|
+
raise ValueError(f"Recipe {recipe_key} was not found")
|
|
143
122
|
strategy = from_schema(recipe_details.json_schema)
|
|
144
123
|
|
|
145
124
|
best_example = None
|
|
@@ -180,54 +159,33 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
180
159
|
|
|
181
160
|
async def upload(
|
|
182
161
|
self,
|
|
183
|
-
|
|
162
|
+
path: str,
|
|
184
163
|
recipe_key: str,
|
|
185
164
|
name: str | None = None,
|
|
186
165
|
description: str | None = None,
|
|
187
166
|
labels: Sequence[tuple[str, str]] | None = None,
|
|
188
167
|
use_case: str | None = None,
|
|
189
168
|
) -> CustomRecipeData:
|
|
190
|
-
|
|
169
|
+
inferred_name = name or recipe_key
|
|
191
170
|
label_inputs = [LabelInput(key=k, value=v) for k, v in labels] if labels else None
|
|
192
171
|
input = CreateRecipeInput(
|
|
193
172
|
key=recipe_key,
|
|
194
|
-
name=
|
|
173
|
+
name=inferred_name,
|
|
195
174
|
description=description,
|
|
196
175
|
labels=label_inputs,
|
|
197
176
|
)
|
|
198
|
-
|
|
199
|
-
with open(file_path, "rb") as f:
|
|
200
|
-
file_upload = Upload(filename=filename, content=f, content_type=content_type)
|
|
177
|
+
with _upload_from_path(path) as file_upload:
|
|
201
178
|
return (
|
|
202
179
|
await self._gql_client.create_custom_recipe(
|
|
203
180
|
use_case=self.use_case_key(use_case), input=input, file=file_upload
|
|
204
181
|
)
|
|
205
182
|
).create_custom_recipe
|
|
206
183
|
|
|
207
|
-
async def run(
|
|
208
|
-
self,
|
|
209
|
-
recipe_key: str,
|
|
210
|
-
num_gpus: int,
|
|
211
|
-
input_args: dict | None = None,
|
|
212
|
-
name: str | None = None,
|
|
213
|
-
use_case: str | None = None,
|
|
214
|
-
compute_pool: str | None = None,
|
|
215
|
-
) -> JobData:
|
|
216
|
-
input = JobInput(
|
|
217
|
-
recipe=recipe_key,
|
|
218
|
-
useCase=self.use_case_key(use_case),
|
|
219
|
-
args=input_args,
|
|
220
|
-
name=name,
|
|
221
|
-
computePool=compute_pool,
|
|
222
|
-
numGpus=num_gpus,
|
|
223
|
-
)
|
|
224
|
-
return (await self._gql_client.create_job(input)).create_job
|
|
225
|
-
|
|
226
184
|
async def get(
|
|
227
185
|
self,
|
|
228
186
|
recipe_key: str,
|
|
229
187
|
use_case: str | None = None,
|
|
230
|
-
) -> CustomRecipeData:
|
|
188
|
+
) -> CustomRecipeData | None:
|
|
231
189
|
return (
|
|
232
190
|
await self._gql_client.get_custom_recipe(id_or_key=recipe_key, use_case=self.use_case_key(use_case))
|
|
233
191
|
).custom_recipe
|
|
@@ -235,7 +193,7 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
235
193
|
async def update(
|
|
236
194
|
self,
|
|
237
195
|
recipe_key: str,
|
|
238
|
-
|
|
196
|
+
path: str | None = None,
|
|
239
197
|
name: str | None = None,
|
|
240
198
|
description: str | None = None,
|
|
241
199
|
labels: Sequence[tuple[str, str]] | None = None,
|
|
@@ -248,21 +206,25 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
248
206
|
labels=label_inputs,
|
|
249
207
|
)
|
|
250
208
|
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
209
|
+
if path:
|
|
210
|
+
with _upload_from_path(path) as file_upload:
|
|
211
|
+
return (
|
|
212
|
+
await self._gql_client.update_custom_recipe(
|
|
213
|
+
use_case=self.use_case_key(use_case),
|
|
214
|
+
id=recipe_key,
|
|
215
|
+
input=input,
|
|
216
|
+
file=file_upload,
|
|
217
|
+
)
|
|
218
|
+
).update_custom_recipe
|
|
219
|
+
else:
|
|
220
|
+
return (
|
|
221
|
+
await self._gql_client.update_custom_recipe(
|
|
222
|
+
use_case=self.use_case_key(use_case),
|
|
223
|
+
id=recipe_key,
|
|
224
|
+
input=input,
|
|
225
|
+
file=None,
|
|
226
|
+
)
|
|
227
|
+
).update_custom_recipe
|
|
266
228
|
|
|
267
229
|
async def delete(
|
|
268
230
|
self,
|
|
@@ -275,6 +237,8 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
275
237
|
|
|
276
238
|
async def generate_sample_input(self, recipe_key: str, use_case: str | None = None) -> dict:
|
|
277
239
|
recipe_details = await self.get(recipe_key=recipe_key, use_case=self.use_case_key(use_case))
|
|
240
|
+
if recipe_details is None:
|
|
241
|
+
raise ValueError(f"Recipe {recipe_key} was not found")
|
|
278
242
|
strategy = from_schema(recipe_details.json_schema)
|
|
279
243
|
|
|
280
244
|
best_example = None
|
|
@@ -296,3 +260,102 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource): # type: ignore[misc]
|
|
|
296
260
|
print("A valid sample could not be generated. Returning an empty dict.")
|
|
297
261
|
best_example = {}
|
|
298
262
|
return dict(best_example) # type: ignore
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _count_keys_recursively(data: Any) -> int:
|
|
266
|
+
"""Recursively counts the total number of keys in dictionaries within the data."""
|
|
267
|
+
count = 0
|
|
268
|
+
if isinstance(data, dict):
|
|
269
|
+
count += len(data)
|
|
270
|
+
for value in data.values():
|
|
271
|
+
count += _count_keys_recursively(value)
|
|
272
|
+
elif isinstance(data, list):
|
|
273
|
+
for item in data:
|
|
274
|
+
count += _count_keys_recursively(item)
|
|
275
|
+
return count
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _validate_python_file(path: Path) -> None:
|
|
279
|
+
"""Validate that the path exists, is a file and has a .py extension."""
|
|
280
|
+
if not path.exists():
|
|
281
|
+
raise FileNotFoundError(f"Python file not found: {path}")
|
|
282
|
+
if not path.is_file():
|
|
283
|
+
raise ValueError(f"Expected a file path, got a directory or non-file: {path}")
|
|
284
|
+
if path.suffix.lower() != ".py":
|
|
285
|
+
raise ValueError(f"Expected a Python file with .py extension, got: {path}")
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _validate_recipe_directory(dir_path: Path) -> None:
|
|
289
|
+
"""Validate that the directory exists and contains a main.py file."""
|
|
290
|
+
if not dir_path.exists():
|
|
291
|
+
raise FileNotFoundError(f"Directory not found: {dir_path}")
|
|
292
|
+
if not dir_path.is_dir():
|
|
293
|
+
raise ValueError(f"Expected a directory path, got a file: {dir_path}")
|
|
294
|
+
main_py = dir_path / "main.py"
|
|
295
|
+
if not main_py.exists() or not main_py.is_file():
|
|
296
|
+
raise FileNotFoundError(f"Directory must contain a 'main.py' file: {dir_path}")
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def _zip_directory_to_bytes_io(dir_path: Path) -> io.BytesIO:
|
|
300
|
+
"""Zip the contents of a directory into an in-memory BytesIO buffer."""
|
|
301
|
+
buffer = io.BytesIO()
|
|
302
|
+
with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
|
|
303
|
+
for root, _, files in os.walk(dir_path):
|
|
304
|
+
for file_name in files:
|
|
305
|
+
file_path = Path(root) / file_name
|
|
306
|
+
arcname = file_path.relative_to(dir_path)
|
|
307
|
+
zf.write(file_path, arcname.as_posix())
|
|
308
|
+
buffer.seek(0)
|
|
309
|
+
return buffer
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
@contextmanager
|
|
313
|
+
def _upload_from_path(path: str):
|
|
314
|
+
"""
|
|
315
|
+
Context manager yielding an Upload object for a Python file or a directory.
|
|
316
|
+
|
|
317
|
+
- If path is a .py file, validates and opens it for upload.
|
|
318
|
+
- If path is a directory, validates it contains main.py, zips contents in-memory.
|
|
319
|
+
"""
|
|
320
|
+
p = Path(path)
|
|
321
|
+
if p.is_file():
|
|
322
|
+
_validate_python_file(p)
|
|
323
|
+
filename = p.name
|
|
324
|
+
content_type = mimetypes.guess_type(str(p))[0] or "application/octet-stream"
|
|
325
|
+
f = open(p, "rb")
|
|
326
|
+
try:
|
|
327
|
+
yield Upload(filename=filename, content=f, content_type=content_type)
|
|
328
|
+
finally:
|
|
329
|
+
f.close()
|
|
330
|
+
elif p.is_dir():
|
|
331
|
+
_validate_recipe_directory(p)
|
|
332
|
+
# Ensure __init__.py exists at the root of the directory before zipping
|
|
333
|
+
created_init = False
|
|
334
|
+
root_init = p / "__init__.py"
|
|
335
|
+
zip_buffer = None
|
|
336
|
+
try:
|
|
337
|
+
if not root_init.exists():
|
|
338
|
+
root_init.touch()
|
|
339
|
+
created_init = True
|
|
340
|
+
logger.info(f"Added __init__.py to your directory, as it is required for proper execution of recipe")
|
|
341
|
+
zip_buffer = _zip_directory_to_bytes_io(p)
|
|
342
|
+
finally:
|
|
343
|
+
if created_init:
|
|
344
|
+
try:
|
|
345
|
+
root_init.unlink()
|
|
346
|
+
logger.info(f"Cleaned up __init__.py from your directory")
|
|
347
|
+
except Exception:
|
|
348
|
+
logger.error(f"Failed to remove __init__.py from your directory")
|
|
349
|
+
pass
|
|
350
|
+
if zip_buffer is None:
|
|
351
|
+
raise RuntimeError("Failed to create in-memory zip for directory upload")
|
|
352
|
+
|
|
353
|
+
filename = f"{p.name}.zip"
|
|
354
|
+
try:
|
|
355
|
+
yield Upload(filename=filename, content=zip_buffer, content_type="application/zip")
|
|
356
|
+
finally:
|
|
357
|
+
zip_buffer.close()
|
|
358
|
+
else:
|
|
359
|
+
if not p.exists():
|
|
360
|
+
raise FileNotFoundError(f"Path not found: {path}")
|
|
361
|
+
raise ValueError(f"Path must be a Python file or a directory: {path}")
|