runnable 0.34.0a2__py3-none-any.whl → 0.34.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- runnable/sdk.py +12 -10
- runnable/tasks.py +17 -21
- {runnable-0.34.0a2.dist-info → runnable-0.34.0a3.dist-info}/METADATA +3 -2
- {runnable-0.34.0a2.dist-info → runnable-0.34.0a3.dist-info}/RECORD +7 -7
- {runnable-0.34.0a2.dist-info → runnable-0.34.0a3.dist-info}/WHEEL +0 -0
- {runnable-0.34.0a2.dist-info → runnable-0.34.0a3.dist-info}/entry_points.txt +0 -0
- {runnable-0.34.0a2.dist-info → runnable-0.34.0a3.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py
CHANGED
@@ -283,14 +283,15 @@ class PythonTask(BaseTask):
|
|
283
283
|
|
284
284
|
|
285
285
|
class TorchTask(BaseTask):
|
286
|
-
entrypoint: str = Field(
|
287
|
-
|
288
|
-
)
|
289
|
-
args_to_torchrun: Dict[str, Any] = Field(
|
290
|
-
|
291
|
-
)
|
286
|
+
# entrypoint: str = Field(
|
287
|
+
# alias="entrypoint", default="torch.distributed.run", frozen=True
|
288
|
+
# )
|
289
|
+
# args_to_torchrun: Dict[str, Any] = Field(
|
290
|
+
# default_factory=dict, alias="args_to_torchrun"
|
291
|
+
# )
|
292
292
|
|
293
293
|
script_to_call: str
|
294
|
+
accelerate_config_file: str
|
294
295
|
|
295
296
|
@computed_field
|
296
297
|
def command_type(self) -> str:
|
@@ -984,13 +985,14 @@ class PythonJob(BaseJob):
|
|
984
985
|
|
985
986
|
|
986
987
|
class TorchJob(BaseJob):
|
987
|
-
entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
988
|
-
args_to_torchrun: dict[str, str | bool | int | float] = Field(
|
989
|
-
|
990
|
-
) # For example
|
988
|
+
# entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
989
|
+
# args_to_torchrun: dict[str, str | bool | int | float] = Field(
|
990
|
+
# default_factory=dict
|
991
|
+
# ) # For example
|
991
992
|
# {"nproc_per_node": 2, "nnodes": 1,}
|
992
993
|
|
993
994
|
script_to_call: str # For example train/script.py
|
995
|
+
accelerate_config_file: str
|
994
996
|
|
995
997
|
def get_task(self) -> RunnableTask:
|
996
998
|
# Piggy bank on existing tasks as a hack
|
runnable/tasks.py
CHANGED
@@ -5,7 +5,6 @@ import io
|
|
5
5
|
import json
|
6
6
|
import logging
|
7
7
|
import os
|
8
|
-
import runpy
|
9
8
|
import subprocess
|
10
9
|
import sys
|
11
10
|
from datetime import datetime
|
@@ -357,16 +356,15 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
|
|
357
356
|
|
358
357
|
class TorchTaskType(BaseTaskType):
|
359
358
|
task_type: str = Field(default="torch", serialization_alias="command_type")
|
360
|
-
|
361
|
-
entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
362
|
-
args_to_torchrun: dict[str, str | bool] = Field(default_factory=dict) # For example
|
363
|
-
# {"nproc_per_node": 2, "nnodes": 1,}
|
359
|
+
accelerate_config_file: str
|
364
360
|
|
365
361
|
script_to_call: str # For example train/script.py
|
366
362
|
|
367
363
|
def execute_command(
|
368
364
|
self, map_variable: Dict[str, str | int | float] | None = None
|
369
365
|
) -> StepAttempt:
|
366
|
+
from accelerate.commands import launch
|
367
|
+
|
370
368
|
attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
|
371
369
|
|
372
370
|
with (
|
@@ -376,39 +374,37 @@ class TorchTaskType(BaseTaskType):
|
|
376
374
|
self.expose_secrets() as _,
|
377
375
|
):
|
378
376
|
try:
|
379
|
-
|
380
|
-
|
381
|
-
for key, value in self.args_to_torchrun.items():
|
382
|
-
entry_point_args.append(f"--{key}")
|
383
|
-
if type(value) is not bool:
|
384
|
-
entry_point_args.append(str(value))
|
385
|
-
|
386
|
-
entry_point_args.append(self.script_to_call)
|
377
|
+
script_args = []
|
387
378
|
for key, value in params.items():
|
388
|
-
|
389
|
-
if type(value.value) is not bool:
|
390
|
-
|
379
|
+
script_args.append(f"--{key}")
|
380
|
+
if type(value.value) is not bool:
|
381
|
+
script_args.append(str(value.value))
|
391
382
|
|
392
383
|
# TODO: Check the typing here
|
393
384
|
|
394
385
|
logger.info("Calling the user script with the following parameters:")
|
395
|
-
logger.info(
|
386
|
+
logger.info(script_args)
|
396
387
|
out_file = TeeIO()
|
397
388
|
try:
|
398
389
|
with contextlib.redirect_stdout(out_file):
|
399
|
-
|
400
|
-
|
390
|
+
parser = launch.launch_command_parser()
|
391
|
+
args = parser.parse_args(self.script_to_call)
|
392
|
+
args.training_script = self.script_to_call
|
393
|
+
args.config_file = self.accelerate_config_file
|
394
|
+
args.training_script_args = script_args
|
395
|
+
|
396
|
+
launch.launch_command(args)
|
401
397
|
task_console.print(out_file.getvalue())
|
402
398
|
except Exception as e:
|
403
399
|
raise exceptions.CommandCallError(
|
404
|
-
f"Call to
|
400
|
+
f"Call to script{self.script_to_call} did not succeed."
|
405
401
|
) from e
|
406
402
|
finally:
|
407
403
|
sys.argv = sys.argv[:1]
|
408
404
|
|
409
405
|
attempt_log.status = defaults.SUCCESS
|
410
406
|
except Exception as _e:
|
411
|
-
msg = f"Call to
|
407
|
+
msg = f"Call to script: {self.script_to_call} did not succeed."
|
412
408
|
attempt_log.message = msg
|
413
409
|
task_console.print_exception(show_locals=False)
|
414
410
|
task_console.log(_e, style=defaults.error_style)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: runnable
|
3
|
-
Version: 0.34.0a2
|
3
|
+
Version: 0.34.0a3
|
4
4
|
Summary: Add your description here
|
5
5
|
Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
|
6
6
|
License-File: LICENSE
|
@@ -15,6 +15,7 @@ Requires-Dist: rich>=13.9.4
|
|
15
15
|
Requires-Dist: ruamel-yaml>=0.18.6
|
16
16
|
Requires-Dist: setuptools>=75.6.0
|
17
17
|
Requires-Dist: stevedore>=5.4.0
|
18
|
+
Requires-Dist: torchvision>=0.21.0
|
18
19
|
Requires-Dist: typer>=0.15.1
|
19
20
|
Provides-Extra: docker
|
20
21
|
Requires-Dist: docker>=7.1.0; extra == 'docker'
|
@@ -27,8 +28,8 @@ Requires-Dist: ploomber-engine>=0.0.33; extra == 'notebook'
|
|
27
28
|
Provides-Extra: s3
|
28
29
|
Requires-Dist: cloudpathlib[s3]; extra == 's3'
|
29
30
|
Provides-Extra: torch
|
31
|
+
Requires-Dist: accelerate>=1.5.2; extra == 'torch'
|
30
32
|
Requires-Dist: torch>=2.6.0; extra == 'torch'
|
31
|
-
Requires-Dist: torchvision>=0.21.0; extra == 'torch'
|
32
33
|
Description-Content-Type: text/markdown
|
33
34
|
|
34
35
|
|
@@ -56,12 +56,12 @@ runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
|
|
56
56
|
runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
|
57
57
|
runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
|
58
58
|
runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
|
59
|
-
runnable/sdk.py,sha256
|
59
|
+
runnable/sdk.py,sha256=-hsoZctbGKsrfOQW3Z7RqWVGJI4GhbsOjqjMRb2OAUo,35181
|
60
60
|
runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
|
61
|
-
runnable/tasks.py,sha256=
|
61
|
+
runnable/tasks.py,sha256=lOtCninvosGI2bNIzblrzNa-lN7TMwel1KQ1g23M85A,32088
|
62
62
|
runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
|
63
|
-
runnable-0.34.
|
64
|
-
runnable-0.34.
|
65
|
-
runnable-0.34.
|
66
|
-
runnable-0.34.
|
67
|
-
runnable-0.34.
|
63
|
+
runnable-0.34.0a3.dist-info/METADATA,sha256=AYMw1jtTzhBN_Y2dMJiguAnYwc82LLxa-WHYApUYpCs,10203
|
64
|
+
runnable-0.34.0a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
65
|
+
runnable-0.34.0a3.dist-info/entry_points.txt,sha256=wKfW6aIWMQFlwrwpPBVWlMQDcxQmOupDKNkKyXoPFV4,1917
|
66
|
+
runnable-0.34.0a3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
67
|
+
runnable-0.34.0a3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|