runnable 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- runnable/sdk.py +37 -36
- runnable/tasks.py +87 -0
- {runnable-0.33.0.dist-info → runnable-0.34.0.dist-info}/METADATA +2 -2
- {runnable-0.33.0.dist-info → runnable-0.34.0.dist-info}/RECORD +7 -7
- {runnable-0.33.0.dist-info → runnable-0.34.0.dist-info}/entry_points.txt +1 -1
- {runnable-0.33.0.dist-info → runnable-0.34.0.dist-info}/WHEEL +0 -0
- {runnable-0.33.0.dist-info → runnable-0.34.0.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py
CHANGED
@@ -34,7 +34,6 @@ from extensions.nodes.nodes import (
|
|
34
34
|
SuccessNode,
|
35
35
|
TaskNode,
|
36
36
|
)
|
37
|
-
from extensions.tasks.torch_config import TorchConfig
|
38
37
|
from runnable import console, defaults, entrypoints, exceptions, graph, utils
|
39
38
|
from runnable.executor import BaseJobExecutor, BasePipelineExecutor
|
40
39
|
from runnable.nodes import TraversalNode
|
@@ -44,7 +43,13 @@ from runnable.tasks import TaskReturns
|
|
44
43
|
logger = logging.getLogger(defaults.LOGGER_NAME)
|
45
44
|
|
46
45
|
StepType = Union[
|
47
|
-
"Stub",
|
46
|
+
"Stub",
|
47
|
+
"PythonTask",
|
48
|
+
"NotebookTask",
|
49
|
+
"ShellTask",
|
50
|
+
"Parallel",
|
51
|
+
"Map",
|
52
|
+
"TorchTask",
|
48
53
|
]
|
49
54
|
|
50
55
|
|
@@ -189,36 +194,6 @@ class BaseTask(BaseTraversal):
|
|
189
194
|
)
|
190
195
|
|
191
196
|
|
192
|
-
class TorchTask(BaseTask, TorchConfig):
|
193
|
-
# The user will not know the rnnz variables for multi node
|
194
|
-
# They should be overridden in the environment
|
195
|
-
function: Callable = Field(exclude=True)
|
196
|
-
|
197
|
-
@field_validator("returns", mode="before")
|
198
|
-
@classmethod
|
199
|
-
def serialize_returns(
|
200
|
-
cls, returns: List[Union[str, TaskReturns]]
|
201
|
-
) -> List[TaskReturns]:
|
202
|
-
assert len(returns) == 0, "Torch tasks cannot return any variables"
|
203
|
-
return []
|
204
|
-
|
205
|
-
@computed_field
|
206
|
-
def command_type(self) -> str:
|
207
|
-
return "torch"
|
208
|
-
|
209
|
-
@computed_field
|
210
|
-
def command(self) -> str:
|
211
|
-
module = self.function.__module__
|
212
|
-
name = self.function.__name__
|
213
|
-
|
214
|
-
return f"{module}.{name}"
|
215
|
-
|
216
|
-
def create_job(self) -> RunnableTask:
|
217
|
-
self.terminate_with_success = True
|
218
|
-
node = self.create_node()
|
219
|
-
return node.executable
|
220
|
-
|
221
|
-
|
222
197
|
class PythonTask(BaseTask):
|
223
198
|
"""
|
224
199
|
An execution node of the pipeline of python functions.
|
@@ -307,6 +282,27 @@ class PythonTask(BaseTask):
|
|
307
282
|
return node.executable
|
308
283
|
|
309
284
|
|
285
|
+
class TorchTask(BaseTask):
|
286
|
+
# entrypoint: str = Field(
|
287
|
+
# alias="entrypoint", default="torch.distributed.run", frozen=True
|
288
|
+
# )
|
289
|
+
# args_to_torchrun: Dict[str, Any] = Field(
|
290
|
+
# default_factory=dict, alias="args_to_torchrun"
|
291
|
+
# )
|
292
|
+
|
293
|
+
script_to_call: str
|
294
|
+
accelerate_config_file: str
|
295
|
+
|
296
|
+
@computed_field
|
297
|
+
def command_type(self) -> str:
|
298
|
+
return "torch"
|
299
|
+
|
300
|
+
def create_job(self) -> RunnableTask:
|
301
|
+
self.terminate_with_success = True
|
302
|
+
node = self.create_node()
|
303
|
+
return node.executable
|
304
|
+
|
305
|
+
|
310
306
|
class NotebookTask(BaseTask):
|
311
307
|
"""
|
312
308
|
An execution node of the pipeline of notebook.
|
@@ -978,7 +974,6 @@ class PythonJob(BaseJob):
|
|
978
974
|
|
979
975
|
return f"{module}.{name}"
|
980
976
|
|
981
|
-
# TODO: can this be simplified to just self.model_dump(exclude_none=True)?
|
982
977
|
def get_task(self) -> RunnableTask:
|
983
978
|
# Piggy bank on existing tasks as a hack
|
984
979
|
task = PythonTask(
|
@@ -989,9 +984,15 @@ class PythonJob(BaseJob):
|
|
989
984
|
return task.create_node().executable
|
990
985
|
|
991
986
|
|
992
|
-
class TorchJob(BaseJob
|
993
|
-
|
994
|
-
#
|
987
|
+
class TorchJob(BaseJob):
|
988
|
+
# entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
989
|
+
# args_to_torchrun: dict[str, str | bool | int | float] = Field(
|
990
|
+
# default_factory=dict
|
991
|
+
# ) # For example
|
992
|
+
# {"nproc_per_node": 2, "nnodes": 1,}
|
993
|
+
|
994
|
+
script_to_call: str # For example train/script.py
|
995
|
+
accelerate_config_file: str
|
995
996
|
|
996
997
|
def get_task(self) -> RunnableTask:
|
997
998
|
# Piggy bank on existing tasks as a hack
|
runnable/tasks.py
CHANGED
@@ -354,6 +354,66 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
|
|
354
354
|
return attempt_log
|
355
355
|
|
356
356
|
|
357
|
+
class TorchTaskType(BaseTaskType):
|
358
|
+
task_type: str = Field(default="torch", serialization_alias="command_type")
|
359
|
+
accelerate_config_file: str
|
360
|
+
|
361
|
+
script_to_call: str # For example train/script.py
|
362
|
+
|
363
|
+
def execute_command(
|
364
|
+
self, map_variable: Dict[str, str | int | float] | None = None
|
365
|
+
) -> StepAttempt:
|
366
|
+
from accelerate.commands import launch
|
367
|
+
|
368
|
+
attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
|
369
|
+
|
370
|
+
with (
|
371
|
+
self.execution_context(
|
372
|
+
map_variable=map_variable, allow_complex=False
|
373
|
+
) as params,
|
374
|
+
self.expose_secrets() as _,
|
375
|
+
):
|
376
|
+
try:
|
377
|
+
script_args = []
|
378
|
+
for key, value in params.items():
|
379
|
+
script_args.append(f"--{key}")
|
380
|
+
if type(value.value) is not bool:
|
381
|
+
script_args.append(str(value.value))
|
382
|
+
|
383
|
+
# TODO: Check the typing here
|
384
|
+
|
385
|
+
logger.info("Calling the user script with the following parameters:")
|
386
|
+
logger.info(script_args)
|
387
|
+
out_file = TeeIO()
|
388
|
+
try:
|
389
|
+
with contextlib.redirect_stdout(out_file):
|
390
|
+
parser = launch.launch_command_parser()
|
391
|
+
args = parser.parse_args(self.script_to_call)
|
392
|
+
args.training_script = self.script_to_call
|
393
|
+
args.config_file = self.accelerate_config_file
|
394
|
+
args.training_script_args = script_args
|
395
|
+
|
396
|
+
launch.launch_command(args)
|
397
|
+
task_console.print(out_file.getvalue())
|
398
|
+
except Exception as e:
|
399
|
+
raise exceptions.CommandCallError(
|
400
|
+
f"Call to script{self.script_to_call} did not succeed."
|
401
|
+
) from e
|
402
|
+
finally:
|
403
|
+
sys.argv = sys.argv[:1]
|
404
|
+
|
405
|
+
attempt_log.status = defaults.SUCCESS
|
406
|
+
except Exception as _e:
|
407
|
+
msg = f"Call to script: {self.script_to_call} did not succeed."
|
408
|
+
attempt_log.message = msg
|
409
|
+
task_console.print_exception(show_locals=False)
|
410
|
+
task_console.log(_e, style=defaults.error_style)
|
411
|
+
|
412
|
+
attempt_log.end_time = str(datetime.now())
|
413
|
+
|
414
|
+
return attempt_log
|
415
|
+
|
416
|
+
|
357
417
|
class NotebookTaskType(BaseTaskType):
|
358
418
|
"""
|
359
419
|
--8<-- [start:notebook_reference]
|
@@ -747,6 +807,31 @@ class ShellTaskType(BaseTaskType):
|
|
747
807
|
return attempt_log
|
748
808
|
|
749
809
|
|
810
|
+
def convert_binary_to_string(data):
|
811
|
+
"""
|
812
|
+
Recursively converts 1 and 0 values in a nested dictionary to "1" and "0".
|
813
|
+
|
814
|
+
Args:
|
815
|
+
data (dict or any): The input data (dictionary, list, or other).
|
816
|
+
|
817
|
+
Returns:
|
818
|
+
dict or any: The modified data with binary values converted to strings.
|
819
|
+
"""
|
820
|
+
|
821
|
+
if isinstance(data, dict):
|
822
|
+
for key, value in data.items():
|
823
|
+
data[key] = convert_binary_to_string(value)
|
824
|
+
return data
|
825
|
+
elif isinstance(data, list):
|
826
|
+
return [convert_binary_to_string(item) for item in data]
|
827
|
+
elif data == 1:
|
828
|
+
return "1"
|
829
|
+
elif data == 0:
|
830
|
+
return "0"
|
831
|
+
else:
|
832
|
+
return data # Return other values unchanged
|
833
|
+
|
834
|
+
|
750
835
|
def create_task(kwargs_for_init) -> BaseTaskType:
|
751
836
|
"""
|
752
837
|
Creates a task object from the command configuration.
|
@@ -763,6 +848,8 @@ def create_task(kwargs_for_init) -> BaseTaskType:
|
|
763
848
|
kwargs = kwargs_for_init.copy()
|
764
849
|
command_type = kwargs.pop("command_type", defaults.COMMAND_TYPE)
|
765
850
|
|
851
|
+
kwargs = convert_binary_to_string(kwargs)
|
852
|
+
|
766
853
|
try:
|
767
854
|
task_mgr = driver.DriverManager(
|
768
855
|
namespace="tasks",
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: runnable
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.34.0
|
4
4
|
Summary: Add your description here
|
5
5
|
Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
|
6
6
|
License-File: LICENSE
|
@@ -27,8 +27,8 @@ Requires-Dist: ploomber-engine>=0.0.33; extra == 'notebook'
|
|
27
27
|
Provides-Extra: s3
|
28
28
|
Requires-Dist: cloudpathlib[s3]; extra == 's3'
|
29
29
|
Provides-Extra: torch
|
30
|
+
Requires-Dist: accelerate>=1.5.2; extra == 'torch'
|
30
31
|
Requires-Dist: torch>=2.6.0; extra == 'torch'
|
31
|
-
Requires-Dist: torchvision>=0.21.0; extra == 'torch'
|
32
32
|
Description-Content-Type: text/markdown
|
33
33
|
|
34
34
|
|
@@ -56,12 +56,12 @@ runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
|
|
56
56
|
runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
|
57
57
|
runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
|
58
58
|
runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
|
59
|
-
runnable/sdk.py,sha256
|
59
|
+
runnable/sdk.py,sha256=-hsoZctbGKsrfOQW3Z7RqWVGJI4GhbsOjqjMRb2OAUo,35181
|
60
60
|
runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
|
61
|
-
runnable/tasks.py,sha256=
|
61
|
+
runnable/tasks.py,sha256=lOtCninvosGI2bNIzblrzNa-lN7TMwel1KQ1g23M85A,32088
|
62
62
|
runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
|
63
|
-
runnable-0.
|
64
|
-
runnable-0.
|
65
|
-
runnable-0.
|
66
|
-
runnable-0.
|
67
|
-
runnable-0.
|
63
|
+
runnable-0.34.0.dist-info/METADATA,sha256=E_O8YUEotnppM6EG006CSRPs3XAuwspYAwrgNe5UgRc,10166
|
64
|
+
runnable-0.34.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
65
|
+
runnable-0.34.0.dist-info/entry_points.txt,sha256=wKfW6aIWMQFlwrwpPBVWlMQDcxQmOupDKNkKyXoPFV4,1917
|
66
|
+
runnable-0.34.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
67
|
+
runnable-0.34.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|