adaptive-sdk 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,6 +12,7 @@ from adaptive_sdk.graphql_client import (
     AttachModel,
     UpdateModelService,
     ModelData,
+    JobData,
     ModelServiceData,
     ListModelsModels,
     AddHFModelInput,
@@ -29,7 +30,9 @@ if TYPE_CHECKING:
 provider_config = {
     "open_ai": {
         "provider_data": lambda api_key, model_id: ModelProviderDataInput(
-            openAI=OpenAIProviderDataInput(apiKey=api_key, externalModelId=OpenAIModel(model_id))
+            openAI=OpenAIProviderDataInput(
+                apiKey=api_key, externalModelId=OpenAIModel(model_id)
+            )
         ),
     },
     "google": {
@@ -39,40 +42,53 @@ provider_config = {
     },
     "azure": {
         "provider_data": lambda api_key, model_id, endpoint: ModelProviderDataInput(
-            azure=AzureProviderDataInput(apiKey=api_key, externalModelId=model_id, endpoint=endpoint)
+            azure=AzureProviderDataInput(
+                apiKey=api_key, externalModelId=model_id, endpoint=endpoint
+            )
         )
     },
 }

 SupportedHFModels = Literal[
-    "google/gemma-3-4b-it",
-    "google/gemma-3-12b-it",
-    "google/gemma-3-27b-it",
+    "deepseek-ai/deepseek-coder-1.3b-base",
+    "deepseek-ai/deepseek-coder-6.7b-base",
+    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+    "tiiuae/falcon-7b",
+    "tiiuae/falcon-7b-instruct",
+    "tiiuae/falcon-40b",
+    "tiiuae/falcon-180B",
+    "BAAI/bge-multilingual-gemma2",
+    "Locutusque/TinyMistral-248M",
+    "mistralai/Mistral-Small-24B-Instruct-2501",
+    "baffo32/decapoda-research-llama-7B-hf",
+    "princeton-nlp/Sheared-LLaMA-1.3B",
+    "meta-llama/Llama-3.1-8B",
     "meta-llama/Llama-3.1-8B-Instruct",
     "meta-llama/Llama-3.1-70B-Instruct",
-    "meta-llama/Llama-3.2-1B-Instruct",
-    "meta-llama/Llama-3.2-3B-Instruct",
     "meta-llama/Llama-3.3-70B-Instruct",
-    "mistralai/Mistral-Small-24B-Instruct-2501",
+    "nvidia/Llama3-ChatQA-1.5-70B",
+    "Qwen/Qwen2.5-0.5B",
     "Qwen/Qwen2.5-0.5B-Instruct",
-    "Qwen/Qwen2.5-1.5B-Instruct",
-    "Qwen/Qwen2.5-3B-Instruct",
-    "Qwen/Qwen2.5-7B-Instruct",
-    "Qwen/Qwen2.5-14B-Instruct",
-    "Qwen/Qwen2.5-32B-Instruct",
-    "Qwen/Qwen2.5-72B-Instruct",
-    "Qwen/Qwen2.5-Coder-0.5B-Instruct",
-    "Qwen/Qwen2.5-Coder-1.5B-Instruct",
-    "Qwen/Qwen2.5-Coder-3B-Instruct",
+    "Qwen/Qwen2.5-Coder-7B",
     "Qwen/Qwen2.5-Coder-7B-Instruct",
+    "Qwen/Qwen2.5-Math-7B",
+    "Qwen/Qwen2.5-Math-7B-Instruct",
     "Qwen/Qwen2.5-Coder-14B-Instruct",
     "Qwen/Qwen2.5-Coder-32B-Instruct",
+    "Qwen/QwQ-32B",
+    "google/gemma-3-1b-it",
+    "google/gemma-3-4b-it",
+    "google/gemma-3-12b-it",
+    "google/gemma-3-27b-it",
     "Qwen/Qwen3-0.6B",
     "Qwen/Qwen3-1.7B",
     "Qwen/Qwen3-4B",
     "Qwen/Qwen3-8B",
     "Qwen/Qwen3-14B",
     "Qwen/Qwen3-32B",
+    "01-ai/Yi-34B",
+    "HuggingFaceH4/zephyr-7b-beta",
 ]

@@ -80,7 +96,9 @@ def is_supported_model(model_id: str):
     supported_models = get_args(SupportedHFModels)
     if model_id not in supported_models:
         supported_models_str = "\n".join(supported_models)
-        raise ValueError(f"Model {model_id} is not supported.\n\nChoose from:\n{supported_models_str}")
+        raise ValueError(
+            f"Model {model_id} is not supported.\n\nChoose from:\n{supported_models_str}"
+        )


 class Models(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
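The reflowed `raise` above changes nothing behaviorally: `is_supported_model` derives its allow-list from the `SupportedHFModels` `Literal` via `typing.get_args`, so the runtime check and the type annotation cannot drift apart. A minimal self-contained sketch of that pattern, with the model list trimmed for brevity:

```python
from typing import Literal, get_args

# Trimmed stand-in for the SDK's SupportedHFModels Literal.
SupportedHFModels = Literal[
    "meta-llama/Llama-3.1-8B-Instruct",
    "Qwen/Qwen3-8B",
    "google/gemma-3-4b-it",
]

def is_supported_model(model_id: str) -> None:
    # get_args() flattens the Literal into a tuple of plain strings,
    # so the runtime allow-list is always the annotation itself.
    supported_models = get_args(SupportedHFModels)
    if model_id not in supported_models:
        supported_models_str = "\n".join(supported_models)
        raise ValueError(
            f"Model {model_id} is not supported.\n\nChoose from:\n{supported_models_str}"
        )

is_supported_model("Qwen/Qwen3-8B")  # passes silently
# is_supported_model("foo/bar")      # would raise ValueError with the full list
```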
@@ -99,7 +117,7 @@ class Models(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
         output_model_key: str,
         hf_token: str,
         compute_pool: str | None = None,
-    ) -> str:
+    ) -> JobData:
         """
         Add model from the HuggingFace Model hub to Adaptive model registry.
         It will take several minutes for the model to be downloaded and converted to Adaptive format.
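This return-type change pairs with the new `JobData` import in the first hunk: the HF-model import method (and its async twin below) now hands back a job descriptor for the long-running download/conversion instead of a bare string. A hedged usage sketch; the method name, its leading parameter, the `client.models` attribute, and the client constructor are all assumptions, since the diff shows only the trailing parameters and the docstring:

```python
from adaptive_sdk.client import Adaptive  # import path shown in this diff's TYPE_CHECKING block

client = Adaptive(...)  # constructor arguments not shown in this diff

# Hypothetical method/parameter names; the diff confirms only output_model_key,
# hf_token, compute_pool and the JobData return type.
job = client.models.add_hf_model(
    "Qwen/Qwen3-8B",
    output_model_key="qwen3-8b",
    hf_token="hf_...",
)
print(job)  # a JobData descriptor for the multi-minute download/conversion job
```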
@@ -145,16 +163,22 @@ class Models(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
                 provider_data = provider_data_fn(api_key, external_model_id)
             case "azure":
                 if not endpoint:
-                    raise ValueError("`endpoint` is required to connect Azure external model.")
+                    raise ValueError(
+                        "`endpoint` is required to connect Azure external model."
+                    )
                 provider_data = provider_data_fn(api_key, external_model_id, endpoint)
             case _:
                 raise ValueError(f"Provider {provider} is not supported")

         provider_enum = ExternalModelProviderName(provider.upper())
-        input = AddExternalModelInput(name=name, provider=provider_enum, providerData=provider_data)
+        input = AddExternalModelInput(
+            name=name, provider=provider_enum, providerData=provider_data
+        )
         return self._gql_client.add_external_model(input).add_external_model

-    def list(self, filter: input_types.ModelFilter | None = None) -> Sequence[ListModelsModels]:
+    def list(
+        self, filter: input_types.ModelFilter | None = None
+    ) -> Sequence[ListModelsModels]:
         """
         List all models in Adaptive model registry.
         """
@@ -192,7 +216,9 @@ class Models(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
             useCase=self.use_case_key(use_case),
             attached=True,
             wait=wait,
-            placement=(ModelPlacementInput.model_validate(placement) if placement else None),
+            placement=(
+                ModelPlacementInput.model_validate(placement) if placement else None
+            ),
         )
         result = self._gql_client.attach_model_to_use_case(input).attach_model
         if make_default:
@@ -253,7 +279,9 @@ class Models(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
             isDefault=is_default,
             attached=attached,
             desiredOnline=desired_online,
-            placement=(ModelPlacementInput.model_validate(placement) if placement else None),
+            placement=(
+                ModelPlacementInput.model_validate(placement) if placement else None
+            ),
         )
         return self._gql_client.update_model(input).update_model_service

@@ -276,7 +304,9 @@ class Models(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
             force: If model is attached to several use cases, `force` must equal `True` in order
                 for the model to be terminated.
         """
-        return self._gql_client.terminate_model(id_or_key=model, force=force).terminate_model
+        return self._gql_client.terminate_model(
+            id_or_key=model, force=force
+        ).terminate_model


 class AsyncModels(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
@@ -295,7 +325,7 @@ class AsyncModels(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
         output_model_key: str,
         hf_token: str,
         compute_pool: str | None = None,
-    ):
+    ) -> JobData:
         """
         Add model from the HuggingFace Model hub to Adaptive model registry.
         It will take several minutes for the model to be downloaded and converted to Adaptive format.
@@ -341,17 +371,23 @@ class AsyncModels(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
                 provider_data = provider_data_fn(api_key, external_model_id)
             case "azure":
                 if not endpoint:
-                    raise ValueError("`endpoint` is required to connect Azure external model.")
+                    raise ValueError(
+                        "`endpoint` is required to connect Azure external model."
+                    )
                 provider_data = provider_data_fn(api_key, external_model_id, endpoint)
             case _:
                 raise ValueError(f"Provider {provider} is not supported")

         provider_enum = ExternalModelProviderName(provider.upper())
-        input = AddExternalModelInput(name=name, provider=provider_enum, providerData=provider_data)
+        input = AddExternalModelInput(
+            name=name, provider=provider_enum, providerData=provider_data
+        )
         result = await self._gql_client.add_external_model(input)
         return result.add_external_model

-    async def list(self, filter: input_types.ModelFilter | None = None) -> Sequence[ListModelsModels]:
+    async def list(
+        self, filter: input_types.ModelFilter | None = None
+    ) -> Sequence[ListModelsModels]:
         """
         List all models in Adaptive model registry.
         """
@@ -389,7 +425,9 @@ class AsyncModels(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
             useCase=self.use_case_key(use_case),
             attached=True,
             wait=wait,
-            placement=(ModelPlacementInput.model_validate(placement) if placement else None),
+            placement=(
+                ModelPlacementInput.model_validate(placement) if placement else None
+            ),
         )
         result = await self._gql_client.attach_model_to_use_case(input)
         result = result.attach_model
@@ -453,7 +491,9 @@ class AsyncModels(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
             isDefault=is_default,
             attached=attached,
             desiredOnline=desired_online,
-            placement=(ModelPlacementInput.model_validate(placement) if placement else None),
+            placement=(
+                ModelPlacementInput.model_validate(placement) if placement else None
+            ),
         )
         result = await self._gql_client.update_model(input)
         return result.update_model_service
@@ -466,7 +506,9 @@ class AsyncModels(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
             model: Model key.
             wait: If `True`, call block until model is in `Online` state.
         """
-        return (await self._gql_client.deploy_model(id_or_key=model, wait=wait)).deploy_model
+        return (
+            await self._gql_client.deploy_model(id_or_key=model, wait=wait)
+        ).deploy_model

     async def terminate(self, model: str, force: bool = False) -> str:
         """
@@ -477,4 +519,6 @@ class AsyncModels(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
             force: If model is attached to several use cases, `force` must equal `True` in order
                 for the model to be terminated.
         """
-        return (await self._gql_client.terminate_model(id_or_key=model, force=force)).terminate_model
+        return (
+            await self._gql_client.terminate_model(id_or_key=model, force=force)
+        ).terminate_model
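The deploy/terminate edits in this file are purely cosmetic line-wrapping; the semantics described in the docstrings are unchanged. A hedged sketch of the async pair, where the `AsyncAdaptive` constructor arguments and the `client.models` attribute are assumptions not shown in this diff, while the `deploy(model, wait)` and `terminate(model, force)` shapes come from the signatures above:

```python
import asyncio

from adaptive_sdk.client import AsyncAdaptive

async def main() -> None:
    client = AsyncAdaptive(...)  # constructor arguments not shown in this diff
    # wait=True blocks until the model reaches the Online state.
    await client.models.deploy(model="llama-3-1-8b-instruct", wait=True)
    # force=True is required when the model is attached to several use cases.
    await client.models.terminate(model="llama-3-1-8b-instruct", force=True)

asyncio.run(main())
```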
@@ -1,4 +1,10 @@
 from __future__ import annotations
+import os
+import io
+import zipfile
+import mimetypes
+from contextlib import contextmanager
+from loguru import logger
 from hypothesis_jsonschema import from_schema
 from typing import TYPE_CHECKING, Sequence, Any
 from pathlib import Path
@@ -18,20 +24,6 @@ from adaptive_sdk.graphql_client import (

 if TYPE_CHECKING:
     from adaptive_sdk.client import Adaptive, AsyncAdaptive
-import mimetypes
-
-
-def _count_keys_recursively(data: Any) -> int:
-    """Recursively counts the total number of keys in dictionaries within the data."""
-    count = 0
-    if isinstance(data, dict):
-        count += len(data)
-        for value in data.values():
-            count += _count_keys_recursively(value)
-    elif isinstance(data, list):
-        for item in data:
-            count += _count_keys_recursively(item)
-    return count


 class Recipes(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
@@ -49,52 +41,35 @@ class Recipes(SyncAPIResource, UseCaseResource):  # type: ignore[misc]

     def upload(
         self,
-        file_path: str,
+        path: str,
         recipe_key: str,
         name: str | None = None,
         description: str | None = None,
         labels: dict[str, str] | None = None,
         use_case: str | None = None,
     ) -> CustomRecipeData:
-        filename = Path(file_path).stem
+        """
+        Upload a recipe from either a single Python file or a directory (path).
+        If a directory is provided, it must contain a 'main.py' and will be zipped in-memory before upload.
+        """
+        inferred_name = name or recipe_key
         label_inputs = [LabelInput(key=k, value=v) for k, v in labels.items()] if labels else None
         input = CreateRecipeInput(
             key=recipe_key,
-            name=name or filename,
+            name=inferred_name,
             description=description,
             labels=label_inputs,
         )
-        content_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream"
-        with open(file_path, "rb") as f:
-            file_upload = Upload(filename=filename, content=f, content_type=content_type)
+        with _upload_from_path(path) as file_upload:
             return self._gql_client.create_custom_recipe(
                 use_case=self.use_case_key(use_case), input=input, file=file_upload
             ).create_custom_recipe

-    def run(
-        self,
-        recipe_key: str,
-        num_gpus: int,
-        input_args: dict | None = None,
-        name: str | None = None,
-        use_case: str | None = None,
-        compute_pool: str | None = None,
-    ) -> JobData:
-        input = JobInput(
-            recipe=recipe_key,
-            useCase=self.use_case_key(use_case),
-            args=input_args or {},
-            name=name,
-            computePool=compute_pool,
-            numGpus=num_gpus,
-        )
-        return self._gql_client.create_job(input).create_job
-
     def get(
         self,
         recipe_key: str,
         use_case: str | None = None,
-    ) -> CustomRecipeData:
+    ) -> CustomRecipeData | None:
         return self._gql_client.get_custom_recipe(
             id_or_key=recipe_key, use_case=self.use_case_key(use_case)
         ).custom_recipe
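In practice, the reworked `upload` means one call now covers both input forms. A hedged usage sketch; the `client.recipes` attribute and client construction are assumptions, while the parameter names come straight from the signature above:

```python
from adaptive_sdk.client import Adaptive

client = Adaptive(...)  # constructor arguments not shown in this diff

# Single-file recipe: path must point at a .py file.
client.recipes.upload(path="my_recipe.py", recipe_key="my-recipe")

# Directory recipe: must contain main.py; it is zipped in-memory before upload.
# 0.1.3 also changes the default display name from the file stem to recipe_key.
client.recipes.upload(
    path="recipes/my_recipe_dir",  # hypothetical directory layout
    recipe_key="my-recipe-dir",
    name="My recipe",              # optional override
)
```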
@@ -102,7 +77,7 @@ class Recipes(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
     def update(
         self,
         recipe_key: str,
-        file_path: str | None = None,
+        path: str | None = None,
         name: str | None = None,
         description: str | None = None,
         labels: Sequence[tuple[str, str]] | None = None,
@@ -115,19 +90,21 @@ class Recipes(SyncAPIResource, UseCaseResource):  # type: ignore[misc]
             labels=label_inputs,
         )

-        file_upload = None
-        if file_path:
-            filename = Path(file_path).stem
-            content_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream"
-            with open(file_path, "rb") as f:
-                file_upload = Upload(filename=filename, content=f, content_type=content_type)
-
-        return self._gql_client.update_custom_recipe(
-            use_case=self.use_case_key(use_case),
-            id=recipe_key,
-            input=input,
-            file=file_upload,
-        ).update_custom_recipe
+        if path:
+            with _upload_from_path(path) as file_upload:
+                return self._gql_client.update_custom_recipe(
+                    use_case=self.use_case_key(use_case),
+                    id=recipe_key,
+                    input=input,
+                    file=file_upload,
+                ).update_custom_recipe
+        else:
+            return self._gql_client.update_custom_recipe(
+                use_case=self.use_case_key(use_case),
+                id=recipe_key,
+                input=input,
+                file=None,
+            ).update_custom_recipe

     def delete(
         self,
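Beyond renaming `file_path` to `path`, this rewrite fixes a subtle lifetime bug: in 0.1.2 the `Upload` was built inside `with open(...)` but the request was sent after the block, when the handle was already closed. A minimal stdlib reproduction of that pitfall:

```python
import os
import tempfile

# Create a throwaway "recipe" file.
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmp:
    tmp.write(b"print('recipe')\n")

with open(tmp.name, "rb") as f:
    upload = {"content": f}  # stand-in for Upload(filename=..., content=f, ...)

try:
    upload["content"].read()  # what the transport does when the request is sent
except ValueError as err:
    print(err)  # "I/O operation on closed file." -- the handle died with the with-block
finally:
    os.unlink(tmp.name)
```

Keeping the GraphQL call inside the `_upload_from_path` context manager guarantees the handle (or in-memory buffer) is still open while the request is in flight.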
@@ -140,6 +117,8 @@ class Recipes(SyncAPIResource, UseCaseResource):  # type: ignore[misc]

     def generate_sample_input(self, recipe_key: str, use_case: str | None = None) -> dict:
         recipe_details = self.get(recipe_key=recipe_key, use_case=self.use_case_key(use_case))
+        if recipe_details is None:
+            raise ValueError(f"Recipe {recipe_key} was not found")
         strategy = from_schema(recipe_details.json_schema)

         best_example = None
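The new `None` guard matters because everything after it assumes a recipe was found: `from_schema` (from `hypothesis_jsonschema`) compiles `recipe_details.json_schema` into a Hypothesis strategy that concrete samples are then drawn from. A minimal sketch of that mechanism, using an illustrative schema rather than a real recipe's:

```python
from hypothesis_jsonschema import from_schema

# Illustrative schema; a real recipe's json_schema comes from the server.
schema = {
    "type": "object",
    "properties": {
        "learning_rate": {"type": "number", "minimum": 0},
        "num_epochs": {"type": "integer", "minimum": 1},
    },
    "required": ["learning_rate", "num_epochs"],
    "additionalProperties": False,
}

strategy = from_schema(schema)
print(strategy.example())  # e.g. {'learning_rate': 0.5, 'num_epochs': 3}
```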
@@ -180,54 +159,33 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]

     async def upload(
         self,
-        file_path: str,
+        path: str,
         recipe_key: str,
         name: str | None = None,
         description: str | None = None,
         labels: Sequence[tuple[str, str]] | None = None,
         use_case: str | None = None,
     ) -> CustomRecipeData:
-        filename = Path(file_path).stem
+        inferred_name = name or recipe_key
         label_inputs = [LabelInput(key=k, value=v) for k, v in labels] if labels else None
         input = CreateRecipeInput(
             key=recipe_key,
-            name=name or filename,
+            name=inferred_name,
             description=description,
             labels=label_inputs,
         )
-        content_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream"
-        with open(file_path, "rb") as f:
-            file_upload = Upload(filename=filename, content=f, content_type=content_type)
+        with _upload_from_path(path) as file_upload:
             return (
                 await self._gql_client.create_custom_recipe(
                     use_case=self.use_case_key(use_case), input=input, file=file_upload
                 )
             ).create_custom_recipe

-    async def run(
-        self,
-        recipe_key: str,
-        num_gpus: int,
-        input_args: dict | None = None,
-        name: str | None = None,
-        use_case: str | None = None,
-        compute_pool: str | None = None,
-    ) -> JobData:
-        input = JobInput(
-            recipe=recipe_key,
-            useCase=self.use_case_key(use_case),
-            args=input_args,
-            name=name,
-            computePool=compute_pool,
-            numGpus=num_gpus,
-        )
-        return (await self._gql_client.create_job(input)).create_job
-
     async def get(
         self,
         recipe_key: str,
         use_case: str | None = None,
-    ) -> CustomRecipeData:
+    ) -> CustomRecipeData | None:
         return (
             await self._gql_client.get_custom_recipe(id_or_key=recipe_key, use_case=self.use_case_key(use_case))
         ).custom_recipe
@@ -235,7 +193,7 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
     async def update(
         self,
         recipe_key: str,
-        file_path: str | None = None,
+        path: str | None = None,
         name: str | None = None,
         description: str | None = None,
         labels: Sequence[tuple[str, str]] | None = None,
@@ -248,21 +206,25 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
             labels=label_inputs,
         )

-        file_upload = None
-        if file_path:
-            filename = Path(file_path).stem
-            content_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream"
-            with open(file_path, "rb") as f:
-                file_upload = Upload(filename=filename, content=f, content_type=content_type)
-
-        return (
-            await self._gql_client.update_custom_recipe(
-                use_case=self.use_case_key(use_case),
-                id=recipe_key,
-                input=input,
-                file=file_upload,
-            )
-        ).update_custom_recipe
+        if path:
+            with _upload_from_path(path) as file_upload:
+                return (
+                    await self._gql_client.update_custom_recipe(
+                        use_case=self.use_case_key(use_case),
+                        id=recipe_key,
+                        input=input,
+                        file=file_upload,
+                    )
+                ).update_custom_recipe
+        else:
+            return (
+                await self._gql_client.update_custom_recipe(
+                    use_case=self.use_case_key(use_case),
+                    id=recipe_key,
+                    input=input,
+                    file=None,
+                )
+            ).update_custom_recipe

     async def delete(
         self,
@@ -275,6 +237,8 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]

     async def generate_sample_input(self, recipe_key: str, use_case: str | None = None) -> dict:
         recipe_details = await self.get(recipe_key=recipe_key, use_case=self.use_case_key(use_case))
+        if recipe_details is None:
+            raise ValueError(f"Recipe {recipe_key} was not found")
         strategy = from_schema(recipe_details.json_schema)

         best_example = None
@@ -296,3 +260,102 @@ class AsyncRecipes(AsyncAPIResource, UseCaseResource):  # type: ignore[misc]
             print("A valid sample could not be generated. Returning an empty dict.")
             best_example = {}
         return dict(best_example)  # type: ignore
+
+
+def _count_keys_recursively(data: Any) -> int:
+    """Recursively counts the total number of keys in dictionaries within the data."""
+    count = 0
+    if isinstance(data, dict):
+        count += len(data)
+        for value in data.values():
+            count += _count_keys_recursively(value)
+    elif isinstance(data, list):
+        for item in data:
+            count += _count_keys_recursively(item)
+    return count
+
+
+def _validate_python_file(path: Path) -> None:
+    """Validate that the path exists, is a file and has a .py extension."""
+    if not path.exists():
+        raise FileNotFoundError(f"Python file not found: {path}")
+    if not path.is_file():
+        raise ValueError(f"Expected a file path, got a directory or non-file: {path}")
+    if path.suffix.lower() != ".py":
+        raise ValueError(f"Expected a Python file with .py extension, got: {path}")
+
+
+def _validate_recipe_directory(dir_path: Path) -> None:
+    """Validate that the directory exists and contains a main.py file."""
+    if not dir_path.exists():
+        raise FileNotFoundError(f"Directory not found: {dir_path}")
+    if not dir_path.is_dir():
+        raise ValueError(f"Expected a directory path, got a file: {dir_path}")
+    main_py = dir_path / "main.py"
+    if not main_py.exists() or not main_py.is_file():
+        raise FileNotFoundError(f"Directory must contain a 'main.py' file: {dir_path}")
+
+
+def _zip_directory_to_bytes_io(dir_path: Path) -> io.BytesIO:
+    """Zip the contents of a directory into an in-memory BytesIO buffer."""
+    buffer = io.BytesIO()
+    with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as zf:
+        for root, _, files in os.walk(dir_path):
+            for file_name in files:
+                file_path = Path(root) / file_name
+                arcname = file_path.relative_to(dir_path)
+                zf.write(file_path, arcname.as_posix())
+    buffer.seek(0)
+    return buffer
+
+
+@contextmanager
+def _upload_from_path(path: str):
+    """
+    Context manager yielding an Upload object for a Python file or a directory.
+
+    - If path is a .py file, validates and opens it for upload.
+    - If path is a directory, validates it contains main.py, zips contents in-memory.
+    """
+    p = Path(path)
+    if p.is_file():
+        _validate_python_file(p)
+        filename = p.name
+        content_type = mimetypes.guess_type(str(p))[0] or "application/octet-stream"
+        f = open(p, "rb")
+        try:
+            yield Upload(filename=filename, content=f, content_type=content_type)
+        finally:
+            f.close()
+    elif p.is_dir():
+        _validate_recipe_directory(p)
+        # Ensure __init__.py exists at the root of the directory before zipping
+        created_init = False
+        root_init = p / "__init__.py"
+        zip_buffer = None
+        try:
+            if not root_init.exists():
+                root_init.touch()
+                created_init = True
+                logger.info(f"Added __init__.py to your directory, as it is required for proper execution of recipe")
+            zip_buffer = _zip_directory_to_bytes_io(p)
+        finally:
+            if created_init:
+                try:
+                    root_init.unlink()
+                    logger.info(f"Cleaned up __init__.py from your directory")
+                except Exception:
+                    logger.error(f"Failed to remove __init__.py from your directory")
+                    pass
+        if zip_buffer is None:
+            raise RuntimeError("Failed to create in-memory zip for directory upload")
+
+        filename = f"{p.name}.zip"
+        try:
+            yield Upload(filename=filename, content=zip_buffer, content_type="application/zip")
+        finally:
+            zip_buffer.close()
+    else:
+        if not p.exists():
+            raise FileNotFoundError(f"Path not found: {path}")
+        raise ValueError(f"Path must be a Python file or a directory: {path}")