runnable 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
runnable/sdk.py CHANGED
@@ -34,7 +34,6 @@ from extensions.nodes.nodes import (
34
34
  SuccessNode,
35
35
  TaskNode,
36
36
  )
37
- from extensions.tasks.torch_config import TorchConfig
38
37
  from runnable import console, defaults, entrypoints, exceptions, graph, utils
39
38
  from runnable.executor import BaseJobExecutor, BasePipelineExecutor
40
39
  from runnable.nodes import TraversalNode
@@ -44,7 +43,13 @@ from runnable.tasks import TaskReturns
44
43
logger = logging.getLogger(defaults.LOGGER_NAME)


# Type alias naming every step class of the SDK. The members are quoted
# (string forward references) so the alias can be declared before the
# classes it names are defined further down in this module.
StepType = Union[
    "Stub",
    "PythonTask",
    "NotebookTask",
    "ShellTask",
    "Parallel",
    "Map",
    "TorchTask",
]
49
54
 
50
55
 
@@ -189,36 +194,6 @@ class BaseTask(BaseTraversal):
189
194
  )
190
195
 
191
196
 
192
- class TorchTask(BaseTask, TorchConfig):
193
- # The user will not know the rnnz variables for multi node
194
- # They should be overridden in the environment
195
- function: Callable = Field(exclude=True)
196
-
197
- @field_validator("returns", mode="before")
198
- @classmethod
199
- def serialize_returns(
200
- cls, returns: List[Union[str, TaskReturns]]
201
- ) -> List[TaskReturns]:
202
- assert len(returns) == 0, "Torch tasks cannot return any variables"
203
- return []
204
-
205
- @computed_field
206
- def command_type(self) -> str:
207
- return "torch"
208
-
209
- @computed_field
210
- def command(self) -> str:
211
- module = self.function.__module__
212
- name = self.function.__name__
213
-
214
- return f"{module}.{name}"
215
-
216
- def create_job(self) -> RunnableTask:
217
- self.terminate_with_success = True
218
- node = self.create_node()
219
- return node.executable
220
-
221
-
222
197
  class PythonTask(BaseTask):
223
198
  """
224
199
  An execution node of the pipeline of python functions.
@@ -307,6 +282,27 @@ class PythonTask(BaseTask):
307
282
  return node.executable
308
283
 
309
284
 
285
class TorchTask(BaseTask):
    """
    A step that hands a user script to the "torch" task type.

    Attributes:
        script_to_call: Path of the script to launch, e.g. ``train/script.py``.
        accelerate_config_file: Path of the accelerate configuration file used
            for the launch.
    """

    script_to_call: str
    accelerate_config_file: str

    @computed_field
    def command_type(self) -> str:
        # Selects the "torch" task implementation when the node is built.
        return "torch"

    def create_job(self) -> RunnableTask:
        """Mark the task as terminal, build its node and return the executable."""
        self.terminate_with_success = True
        return self.create_node().executable
304
+
305
+
310
306
  class NotebookTask(BaseTask):
311
307
  """
312
308
  An execution node of the pipeline of notebook.
@@ -978,7 +974,6 @@ class PythonJob(BaseJob):
978
974
 
979
975
  return f"{module}.{name}"
980
976
 
981
- # TODO: can this be simplified to just self.model_dump(exclude_none=True)?
982
977
  def get_task(self) -> RunnableTask:
983
978
  # Piggy bank on existing tasks as a hack
984
979
  task = PythonTask(
@@ -989,9 +984,15 @@ class PythonJob(BaseJob):
989
984
  return task.create_node().executable
990
985
 
991
986
 
992
- class TorchJob(BaseJob, TorchConfig):
993
- function: Callable = Field()
994
- # min and max should always be 1
987
+ class TorchJob(BaseJob):
988
+ # entrypoint: str = Field(default="torch.distributed.run", frozen=True)
989
+ # args_to_torchrun: dict[str, str | bool | int | float] = Field(
990
+ # default_factory=dict
991
+ # ) # For example
992
+ # {"nproc_per_node": 2, "nnodes": 1,}
993
+
994
+ script_to_call: str # For example train/script.py
995
+ accelerate_config_file: str
995
996
 
996
997
  def get_task(self) -> RunnableTask:
997
998
  # Piggy bank on existing tasks as a hack
runnable/tasks.py CHANGED
@@ -354,6 +354,66 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
354
354
  return attempt_log
355
355
 
356
356
 
357
class TorchTaskType(BaseTaskType):
    """
    Task type that runs a user script through the ``accelerate`` launcher.

    Every parameter exposed to the task is forwarded to the script on its
    command line as ``--name value``. Boolean parameters become a bare
    ``--name`` flag when ``True`` and are omitted entirely when ``False``.

    Attributes:
        task_type: Always "torch"; serialised under the ``command_type`` alias.
        accelerate_config_file: Path handed to the launcher as ``config_file``.
        script_to_call: Script handed to the launcher, e.g. ``train/script.py``.
    """

    task_type: str = Field(default="torch", serialization_alias="command_type")
    accelerate_config_file: str

    script_to_call: str  # For example train/script.py

    @staticmethod
    def _build_script_args(params) -> list[str]:
        """
        Turn the task parameters into command line arguments for the script.

        Args:
            params: Mapping of parameter name to a parameter object exposing
                ``.value`` (as yielded by ``execution_context``).

        Returns:
            list[str]: ``["--name", "value", ...]``. ``True`` booleans add only
            ``--name`` and ``False`` booleans add nothing, so a script that
            declares them with argparse ``store_true`` reads the right value.
        """
        script_args: list[str] = []
        for key, param in params.items():
            value = param.value
            if isinstance(value, bool):
                # Emitting "--flag" for False would switch the flag *on*.
                if value:
                    script_args.append(f"--{key}")
                continue
            script_args.extend((f"--{key}", str(value)))
        return script_args

    def execute_command(
        self, map_variable: Dict[str, str | int | float] | None = None
    ) -> StepAttempt:
        """
        Launch ``script_to_call`` with accelerate and record the attempt.

        Args:
            map_variable: Forwarded to ``execution_context`` when resolving the
                task parameters.

        Returns:
            StepAttempt: status SUCCESS when the launch returns normally;
            otherwise FAIL with a message. The error is printed to the task
            console rather than re-raised.
        """
        # Imported lazily: accelerate is only installed with the "torch" extra.
        from accelerate.commands import launch

        attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))

        with (
            self.execution_context(
                map_variable=map_variable, allow_complex=False
            ) as params,
            self.expose_secrets() as _,
        ):
            try:
                script_args = self._build_script_args(params)

                logger.info("Calling the user script with the following parameters:")
                logger.info(script_args)
                out_file = TeeIO()
                # Snapshot argv so it is restored exactly after the launch,
                # instead of being truncated to its first element.
                original_argv = list(sys.argv)
                try:
                    with contextlib.redirect_stdout(out_file):
                        parser = launch.launch_command_parser()
                        # parse_args needs a list of arguments: a bare string
                        # would be iterated one character at a time.
                        args = parser.parse_args([self.script_to_call])
                        args.training_script = self.script_to_call
                        args.config_file = self.accelerate_config_file
                        args.training_script_args = script_args

                        launch.launch_command(args)
                    task_console.print(out_file.getvalue())
                except Exception as e:
                    raise exceptions.CommandCallError(
                        f"Call to script: {self.script_to_call} did not succeed."
                    ) from e
                finally:
                    sys.argv = original_argv

                attempt_log.status = defaults.SUCCESS
            except Exception as _e:
                msg = f"Call to script: {self.script_to_call} did not succeed."
                attempt_log.message = msg
                task_console.print_exception(show_locals=False)
                task_console.log(_e, style=defaults.error_style)

        attempt_log.end_time = str(datetime.now())

        return attempt_log
415
+
416
+
357
417
  class NotebookTaskType(BaseTaskType):
358
418
  """
359
419
  --8<-- [start:notebook_reference]
@@ -747,6 +807,31 @@ class ShellTaskType(BaseTaskType):
747
807
  return attempt_log
748
808
 
749
809
 
810
def convert_binary_to_string(data):
    """
    Recursively replace the values 1 and 0 with the strings "1" and "0".

    Dictionaries and lists are walked recursively and rebuilt, so the input is
    never modified in place. ``create_task`` only makes a shallow copy of its
    kwargs, so mutating nested containers would leak into the caller's
    configuration.

    Booleans are deliberately returned unchanged. ``True == 1`` and
    ``False == 0`` in Python, and turning ``False`` into the truthy string
    ``"0"`` would silently invert a flag.

    Args:
        data (dict, list or any): The value to convert.

    Returns:
        dict, list or any: A converted copy for dicts and lists, "1"/"0" for
        non-boolean values equal to 1/0, otherwise ``data`` unchanged.
    """
    if isinstance(data, dict):
        return {
            key: convert_binary_to_string(value) for key, value in data.items()
        }
    if isinstance(data, list):
        return [convert_binary_to_string(item) for item in data]
    if isinstance(data, bool):
        return data
    if data == 1:
        return "1"
    if data == 0:
        return "0"
    return data  # Return other values unchanged
833
+
834
+
750
835
  def create_task(kwargs_for_init) -> BaseTaskType:
751
836
  """
752
837
  Creates a task object from the command configuration.
@@ -763,6 +848,8 @@ def create_task(kwargs_for_init) -> BaseTaskType:
763
848
  kwargs = kwargs_for_init.copy()
764
849
  command_type = kwargs.pop("command_type", defaults.COMMAND_TYPE)
765
850
 
851
+ kwargs = convert_binary_to_string(kwargs)
852
+
766
853
  try:
767
854
  task_mgr = driver.DriverManager(
768
855
  namespace="tasks",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: runnable
3
- Version: 0.33.1
3
+ Version: 0.34.0
4
4
  Summary: Add your description here
5
5
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
6
6
  License-File: LICENSE
@@ -27,8 +27,8 @@ Requires-Dist: ploomber-engine>=0.0.33; extra == 'notebook'
27
27
  Provides-Extra: s3
28
28
  Requires-Dist: cloudpathlib[s3]; extra == 's3'
29
29
  Provides-Extra: torch
30
+ Requires-Dist: accelerate>=1.5.2; extra == 'torch'
30
31
  Requires-Dist: torch>=2.6.0; extra == 'torch'
31
- Requires-Dist: torchvision>=0.21.0; extra == 'torch'
32
32
  Description-Content-Type: text/markdown
33
33
 
34
34
 
@@ -56,12 +56,12 @@ runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
56
56
  runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
57
57
  runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
58
58
  runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
59
- runnable/sdk.py,sha256=xjMEDg1Uovrp0Kw7RR4GAQIJgG9zcIgueqOzQAaO7Bs,35363
59
+ runnable/sdk.py,sha256=-hsoZctbGKsrfOQW3Z7RqWVGJI4GhbsOjqjMRb2OAUo,35181
60
60
  runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
61
- runnable/tasks.py,sha256=ABRhgiTY8F62pNlqJmVTDjwJwuzp8DqciUEOq1fpt1U,28989
61
+ runnable/tasks.py,sha256=lOtCninvosGI2bNIzblrzNa-lN7TMwel1KQ1g23M85A,32088
62
62
  runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
63
- runnable-0.33.1.dist-info/METADATA,sha256=Cgy39IH3KEP3b2eEVF8u934b_AQ4BWzzQqjb2BV5aKw,10168
64
- runnable-0.33.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
- runnable-0.33.1.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
66
- runnable-0.33.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
- runnable-0.33.1.dist-info/RECORD,,
63
+ runnable-0.34.0.dist-info/METADATA,sha256=E_O8YUEotnppM6EG006CSRPs3XAuwspYAwrgNe5UgRc,10166
64
+ runnable-0.34.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
+ runnable-0.34.0.dist-info/entry_points.txt,sha256=wKfW6aIWMQFlwrwpPBVWlMQDcxQmOupDKNkKyXoPFV4,1917
66
+ runnable-0.34.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
+ runnable-0.34.0.dist-info/RECORD,,
@@ -49,4 +49,4 @@ env-secrets = runnable.secrets:EnvSecretsManager
49
49
  notebook = runnable.tasks:NotebookTaskType
50
50
  python = runnable.tasks:PythonTaskType
51
51
  shell = runnable.tasks:ShellTaskType
52
- torch = extensions.tasks.torch:TorchTaskType
52
+ torch = runnable.tasks:TorchTaskType