runnable 0.33.0__py3-none-any.whl → 0.34.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- runnable/sdk.py +35 -36
- runnable/tasks.py +91 -0
- {runnable-0.33.0.dist-info → runnable-0.34.0a1.dist-info}/METADATA +1 -1
- {runnable-0.33.0.dist-info → runnable-0.34.0a1.dist-info}/RECORD +7 -7
- {runnable-0.33.0.dist-info → runnable-0.34.0a1.dist-info}/entry_points.txt +1 -1
- {runnable-0.33.0.dist-info → runnable-0.34.0a1.dist-info}/WHEEL +0 -0
- {runnable-0.33.0.dist-info → runnable-0.34.0a1.dist-info}/licenses/LICENSE +0 -0
runnable/sdk.py
CHANGED
@@ -34,7 +34,6 @@ from extensions.nodes.nodes import (
|
|
34
34
|
SuccessNode,
|
35
35
|
TaskNode,
|
36
36
|
)
|
37
|
-
from extensions.tasks.torch_config import TorchConfig
|
38
37
|
from runnable import console, defaults, entrypoints, exceptions, graph, utils
|
39
38
|
from runnable.executor import BaseJobExecutor, BasePipelineExecutor
|
40
39
|
from runnable.nodes import TraversalNode
|
@@ -44,7 +43,13 @@ from runnable.tasks import TaskReturns
|
|
44
43
|
logger = logging.getLogger(defaults.LOGGER_NAME)
|
45
44
|
|
46
45
|
StepType = Union[
|
47
|
-
"Stub",
|
46
|
+
"Stub",
|
47
|
+
"PythonTask",
|
48
|
+
"NotebookTask",
|
49
|
+
"ShellTask",
|
50
|
+
"Parallel",
|
51
|
+
"Map",
|
52
|
+
"TorchTask",
|
48
53
|
]
|
49
54
|
|
50
55
|
|
@@ -189,36 +194,6 @@ class BaseTask(BaseTraversal):
|
|
189
194
|
)
|
190
195
|
|
191
196
|
|
192
|
-
class TorchTask(BaseTask, TorchConfig):
|
193
|
-
# The user will not know the rnnz variables for multi node
|
194
|
-
# They should be overridden in the environment
|
195
|
-
function: Callable = Field(exclude=True)
|
196
|
-
|
197
|
-
@field_validator("returns", mode="before")
|
198
|
-
@classmethod
|
199
|
-
def serialize_returns(
|
200
|
-
cls, returns: List[Union[str, TaskReturns]]
|
201
|
-
) -> List[TaskReturns]:
|
202
|
-
assert len(returns) == 0, "Torch tasks cannot return any variables"
|
203
|
-
return []
|
204
|
-
|
205
|
-
@computed_field
|
206
|
-
def command_type(self) -> str:
|
207
|
-
return "torch"
|
208
|
-
|
209
|
-
@computed_field
|
210
|
-
def command(self) -> str:
|
211
|
-
module = self.function.__module__
|
212
|
-
name = self.function.__name__
|
213
|
-
|
214
|
-
return f"{module}.{name}"
|
215
|
-
|
216
|
-
def create_job(self) -> RunnableTask:
|
217
|
-
self.terminate_with_success = True
|
218
|
-
node = self.create_node()
|
219
|
-
return node.executable
|
220
|
-
|
221
|
-
|
222
197
|
class PythonTask(BaseTask):
|
223
198
|
"""
|
224
199
|
An execution node of the pipeline of python functions.
|
@@ -307,6 +282,26 @@ class PythonTask(BaseTask):
|
|
307
282
|
return node.executable
|
308
283
|
|
309
284
|
|
285
|
+
class TorchTask(BaseTask):
    """A pipeline step that launches ``script_to_call`` through a torch launcher.

    The step is executed by the ``"torch"`` task type, with ``entrypoint``
    (``torch.distributed.run`` by default) as the launcher module and
    ``args_to_torchrun`` passed to it as ``--key value`` options.

    Torch tasks cannot declare ``returns``: the torch task type only launches
    the script and never collects values back into the pipeline parameters.
    """

    # Launcher module run as ``__main__``; not meant to be overridden.
    entrypoint: str = Field(
        alias="entrypoint", default="torch.distributed.run", frozen=True
    )
    # Options for the launcher, e.g. {"nproc_per_node": 2, "nnodes": 1}.
    args_to_torchrun: Dict[str, Any] = Field(
        default_factory=dict, alias="args_to_torchrun"
    )

    # Path of the script handed to the launcher, e.g. "train/script.py".
    script_to_call: str

    @field_validator("returns", mode="before")
    @classmethod
    def serialize_returns(
        cls, returns: List[Union[str, TaskReturns]]
    ) -> List[TaskReturns]:
        """Reject any declared returns; a torch task never produces them.

        Raises:
            ValueError: if ``returns`` is non-empty (surfaced as a validation
                error). ``ValueError`` is used instead of ``assert`` so the
                check survives ``python -O``.
        """
        if returns:
            raise ValueError("Torch tasks cannot return any variables")
        return []

    @computed_field
    def command_type(self) -> str:
        """Task-type key used to select the torch executor."""
        return "torch"

    def create_job(self) -> RunnableTask:
        """Build the executable for running this task as a standalone job."""
        self.terminate_with_success = True
        node = self.create_node()
        return node.executable
|
303
|
+
|
304
|
+
|
310
305
|
class NotebookTask(BaseTask):
|
311
306
|
"""
|
312
307
|
An execution node of the pipeline of notebook.
|
@@ -978,7 +973,6 @@ class PythonJob(BaseJob):
|
|
978
973
|
|
979
974
|
return f"{module}.{name}"
|
980
975
|
|
981
|
-
# TODO: can this be simplified to just self.model_dump(exclude_none=True)?
|
982
976
|
def get_task(self) -> RunnableTask:
|
983
977
|
# Piggy bank on existing tasks as a hack
|
984
978
|
task = PythonTask(
|
@@ -989,9 +983,14 @@ class PythonJob(BaseJob):
|
|
989
983
|
return task.create_node().executable
|
990
984
|
|
991
985
|
|
992
|
-
class TorchJob(BaseJob
|
993
|
-
|
994
|
-
|
986
|
+
class TorchJob(BaseJob):
|
987
|
+
entrypoint: str = Field(default="torch.distributed.run", frozen=True)
|
988
|
+
args_to_torchrun: dict[str, str | bool | int | float] = Field(
|
989
|
+
default_factory=dict
|
990
|
+
) # For example
|
991
|
+
# {"nproc_per_node": 2, "nnodes": 1,}
|
992
|
+
|
993
|
+
script_to_call: str # For example train/script.py
|
995
994
|
|
996
995
|
def get_task(self) -> RunnableTask:
|
997
996
|
# Piggy bank on existing tasks as a hack
|
runnable/tasks.py
CHANGED
@@ -5,6 +5,7 @@ import io
|
|
5
5
|
import json
|
6
6
|
import logging
|
7
7
|
import os
|
8
|
+
import runpy
|
8
9
|
import subprocess
|
9
10
|
import sys
|
10
11
|
from datetime import datetime
|
@@ -354,6 +355,69 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
|
|
354
355
|
return attempt_log
|
355
356
|
|
356
357
|
|
358
|
+
def _to_cli_args(options) -> list[str]:
    """Render ``(name, value)`` pairs as ``--name value`` command-line tokens.

    Booleans are argparse-style switches: ``True`` emits only ``--name`` and
    ``False`` omits the option entirely (emitting a bare ``--name`` for
    ``False`` would switch the flag *on*). Other values go through ``str()``.
    """
    tokens: list[str] = []
    for name, value in options:
        if isinstance(value, bool):
            if value:
                tokens.append(f"--{name}")
            continue
        tokens.extend((f"--{name}", str(value)))
    return tokens


class TorchTaskType(BaseTaskType):
    """Run a user script through a torch launcher module, in-process.

    ``entrypoint`` is executed with ``runpy.run_module(..., run_name="__main__")``
    under a synthesised ``sys.argv`` of the form::

        <entrypoint> --<launcher option> [value] ... <script_to_call> --<param> [value] ...
    """

    task_type: str = Field(default="torch", serialization_alias="command_type")

    # Launcher module run as ``__main__``.
    entrypoint: str = Field(default="torch.distributed.run", frozen=True)
    # Launcher options, e.g. {"nproc_per_node": 2, "nnodes": 1}.
    # Same value types as TorchJob accepts: with only ``str | bool``, pydantic's
    # lax validation rejects ints such as 2 and coerces 1/0 into booleans.
    args_to_torchrun: dict[str, str | bool | int | float] = Field(
        default_factory=dict
    )

    # Path of the user script handed to the launcher, e.g. "train/script.py".
    script_to_call: str

    def execute_command(
        self, map_variable: Dict[str, str | int | float] | None = None
    ) -> StepAttempt:
        """Launch ``script_to_call`` via ``entrypoint`` and record the attempt.

        Args:
            map_variable: Forwarded to ``execution_context``.

        Returns:
            StepAttempt: ``status`` is ``defaults.SUCCESS`` when the entrypoint
            returns normally or exits with status 0/None; otherwise it stays
            ``defaults.FAIL`` with ``message`` set. Errors are logged to the task
            console rather than raised.
        """
        attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
        failure_message = f"Call to entrypoint {self.entrypoint} with {self.script_to_call} did not succeed."

        with (
            self.execution_context(
                map_variable=map_variable, allow_complex=False
            ) as params,
            self.expose_secrets() as _,
        ):
            try:
                entry_point_args = [self.entrypoint]
                entry_point_args += _to_cli_args(self.args_to_torchrun.items())
                entry_point_args.append(self.script_to_call)
                # NOTE(review): parameter objects are expected to expose ``.value``;
                # their declared type is not visible here.
                entry_point_args += _to_cli_args(
                    (key, param.value)  # type: ignore
                    for key, param in params.items()
                )

                logger.info("Calling the user script with the following parameters:")
                logger.info(entry_point_args)

                out_file = TeeIO()
                # Keep the caller's argv: re-slicing sys.argv afterwards would only
                # leave the entrypoint name injected below.
                original_argv = sys.argv
                try:
                    with contextlib.redirect_stdout(out_file):
                        sys.argv = entry_point_args
                        try:
                            runpy.run_module(self.entrypoint, run_name="__main__")
                        except SystemExit as exit_request:
                            # A module run as __main__ may call sys.exit(). SystemExit
                            # is not an Exception, so it would otherwise escape without
                            # an attempt log. Only a non-zero status is a failure.
                            if exit_request.code not in (None, 0):
                                raise RuntimeError(
                                    f"{self.entrypoint} exited with status {exit_request.code!r}"
                                ) from exit_request
                    task_console.print(out_file.getvalue())
                except Exception as e:
                    raise exceptions.CommandCallError(failure_message) from e
                finally:
                    sys.argv = original_argv

                attempt_log.status = defaults.SUCCESS
            except Exception as _e:
                attempt_log.message = failure_message
                task_console.print_exception(show_locals=False)
                task_console.log(_e, style=defaults.error_style)

        attempt_log.end_time = str(datetime.now())

        return attempt_log
|
419
|
+
|
420
|
+
|
357
421
|
class NotebookTaskType(BaseTaskType):
|
358
422
|
"""
|
359
423
|
--8<-- [start:notebook_reference]
|
@@ -747,6 +811,31 @@ class ShellTaskType(BaseTaskType):
|
|
747
811
|
return attempt_log
|
748
812
|
|
749
813
|
|
814
|
+
def convert_binary_to_string(data):
    """
    Recursively convert the numbers 1 and 0 in nested dicts/lists to "1" and "0".

    Booleans are left untouched. ``bool`` is a subclass of ``int`` (``True == 1``,
    ``False == 0``), so without an explicit check a ``True`` switch would become the
    string ``"1"``. The torch task type would then pass it as an option value
    instead of a bare flag. The input is not modified; new containers are returned,
    so the caller's (possibly shared) nested config dicts stay intact.

    NOTE(review): presumably this exists so 0/1 are not coerced to booleans when
    task models validate ``str | bool`` fields — confirm.

    Args:
        data (dict or any): The input data (dictionary, list, or other).

    Returns:
        dict or any: A copy of the data with numeric 1/0 values replaced by "1"/"0".
    """
    if isinstance(data, dict):
        return {key: convert_binary_to_string(value) for key, value in data.items()}
    if isinstance(data, list):
        return [convert_binary_to_string(item) for item in data]
    if isinstance(data, bool):
        return data
    if data == 1:
        return "1"
    if data == 0:
        return "0"
    return data  # Return other values unchanged
|
837
|
+
|
838
|
+
|
750
839
|
def create_task(kwargs_for_init) -> BaseTaskType:
|
751
840
|
"""
|
752
841
|
Creates a task object from the command configuration.
|
@@ -763,6 +852,8 @@ def create_task(kwargs_for_init) -> BaseTaskType:
|
|
763
852
|
kwargs = kwargs_for_init.copy()
|
764
853
|
command_type = kwargs.pop("command_type", defaults.COMMAND_TYPE)
|
765
854
|
|
855
|
+
kwargs = convert_binary_to_string(kwargs)
|
856
|
+
|
766
857
|
try:
|
767
858
|
task_mgr = driver.DriverManager(
|
768
859
|
namespace="tasks",
|
@@ -56,12 +56,12 @@ runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
|
|
56
56
|
runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
|
57
57
|
runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
|
58
58
|
runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
|
59
|
-
runnable/sdk.py,sha256=
|
59
|
+
runnable/sdk.py,sha256=Cl6wVJj_pBnHmcszf-kh4nVqbiQaIruGJn06cm9epm4,35097
|
60
60
|
runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
|
61
|
-
runnable/tasks.py,sha256=
|
61
|
+
runnable/tasks.py,sha256=OW9pzjEKMRFpB256KJm__jWwsF37gs-tkIUcfnOTJwA,32382
|
62
62
|
runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
|
63
|
-
runnable-0.
|
64
|
-
runnable-0.
|
65
|
-
runnable-0.
|
66
|
-
runnable-0.
|
67
|
-
runnable-0.
|
63
|
+
runnable-0.34.0a1.dist-info/METADATA,sha256=LphmhidZHZotusfCyam7DBBHze8nbJ85aVxUhnbCGyc,10170
|
64
|
+
runnable-0.34.0a1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
65
|
+
runnable-0.34.0a1.dist-info/entry_points.txt,sha256=wKfW6aIWMQFlwrwpPBVWlMQDcxQmOupDKNkKyXoPFV4,1917
|
66
|
+
runnable-0.34.0a1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
67
|
+
runnable-0.34.0a1.dist-info/RECORD,,
|
File without changes
|
File without changes
|