runnable 0.33.1__py3-none-any.whl → 0.34.0a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
runnable/sdk.py CHANGED
@@ -34,7 +34,6 @@ from extensions.nodes.nodes import (
34
34
  SuccessNode,
35
35
  TaskNode,
36
36
  )
37
- from extensions.tasks.torch_config import TorchConfig
38
37
  from runnable import console, defaults, entrypoints, exceptions, graph, utils
39
38
  from runnable.executor import BaseJobExecutor, BasePipelineExecutor
40
39
  from runnable.nodes import TraversalNode
@@ -44,7 +43,13 @@ from runnable.tasks import TaskReturns
44
43
  logger = logging.getLogger(defaults.LOGGER_NAME)
45
44
 
46
45
  StepType = Union[
47
- "Stub", "PythonTask", "NotebookTask", "ShellTask", "Parallel", "Map", "TorchTask"
46
+ "Stub",
47
+ "PythonTask",
48
+ "NotebookTask",
49
+ "ShellTask",
50
+ "Parallel",
51
+ "Map",
52
+ "TorchTask",
48
53
  ]
49
54
 
50
55
 
@@ -189,36 +194,6 @@ class BaseTask(BaseTraversal):
189
194
  )
190
195
 
191
196
 
192
- class TorchTask(BaseTask, TorchConfig):
193
- # The user will not know the rnnz variables for multi node
194
- # They should be overridden in the environment
195
- function: Callable = Field(exclude=True)
196
-
197
- @field_validator("returns", mode="before")
198
- @classmethod
199
- def serialize_returns(
200
- cls, returns: List[Union[str, TaskReturns]]
201
- ) -> List[TaskReturns]:
202
- assert len(returns) == 0, "Torch tasks cannot return any variables"
203
- return []
204
-
205
- @computed_field
206
- def command_type(self) -> str:
207
- return "torch"
208
-
209
- @computed_field
210
- def command(self) -> str:
211
- module = self.function.__module__
212
- name = self.function.__name__
213
-
214
- return f"{module}.{name}"
215
-
216
- def create_job(self) -> RunnableTask:
217
- self.terminate_with_success = True
218
- node = self.create_node()
219
- return node.executable
220
-
221
-
222
197
  class PythonTask(BaseTask):
223
198
  """
224
199
  An execution node of the pipeline of python functions.
@@ -307,6 +282,26 @@ class PythonTask(BaseTask):
307
282
  return node.executable
308
283
 
309
284
 
285
class TorchTask(BaseTask):
    """
    SDK step that describes a torch launch: an entrypoint module, the
    arguments handed to that entrypoint and the script it should call.

    Attributes:
        entrypoint: Module used as the launcher. Frozen; defaults to
            ``torch.distributed.run``.
        args_to_torchrun: Free-form mapping of launcher arguments.
            Defaults to an empty dict.
        script_to_call: The script passed to the launcher (required).
    """

    entrypoint: str = Field(
        default="torch.distributed.run",
        alias="entrypoint",
        frozen=True,
    )
    args_to_torchrun: Dict[str, Any] = Field(
        alias="args_to_torchrun",
        default_factory=dict,
    )

    script_to_call: str

    @computed_field
    def command_type(self) -> str:
        """Constant command type identifying this step as a torch task."""
        return "torch"

    def create_job(self) -> RunnableTask:
        """
        Mark the step as terminating with success, build its node and
        return the node's executable.
        """
        self.terminate_with_success = True
        return self.create_node().executable
303
+
304
+
310
305
  class NotebookTask(BaseTask):
311
306
  """
312
307
  An execution node of the pipeline of notebook.
@@ -978,7 +973,6 @@ class PythonJob(BaseJob):
978
973
 
979
974
  return f"{module}.{name}"
980
975
 
981
- # TODO: can this be simplified to just self.model_dump(exclude_none=True)?
982
976
  def get_task(self) -> RunnableTask:
983
977
  # Piggy bank on existing tasks as a hack
984
978
  task = PythonTask(
@@ -989,9 +983,14 @@ class PythonJob(BaseJob):
989
983
  return task.create_node().executable
990
984
 
991
985
 
992
- class TorchJob(BaseJob, TorchConfig):
993
- function: Callable = Field()
994
- # min and max should always be 1
986
+ class TorchJob(BaseJob):
987
+ entrypoint: str = Field(default="torch.distributed.run", frozen=True)
988
+ args_to_torchrun: dict[str, str | bool | int | float] = Field(
989
+ default_factory=dict
990
+ ) # For example
991
+ # {"nproc_per_node": 2, "nnodes": 1,}
992
+
993
+ script_to_call: str # For example train/script.py
995
994
 
996
995
  def get_task(self) -> RunnableTask:
997
996
  # Piggy bank on existing tasks as a hack
runnable/tasks.py CHANGED
@@ -5,6 +5,7 @@ import io
5
5
  import json
6
6
  import logging
7
7
  import os
8
+ import runpy
8
9
  import subprocess
9
10
  import sys
10
11
  from datetime import datetime
@@ -354,6 +355,69 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
354
355
  return attempt_log
355
356
 
356
357
 
358
class TorchTaskType(BaseTaskType):
    """
    Task type that runs a user script through a launcher module in-process.

    The command line is assembled as::

        <entrypoint> [--<launcher arg> [value]]... <script_to_call> [--<param> [value]]...

    It is installed as ``sys.argv`` and the entrypoint module is executed
    with ``runpy.run_module(..., run_name="__main__")``.

    Attributes:
        task_type: Serialized as ``command_type``; defaults to ``"torch"``.
        entrypoint: Launcher module name. Frozen; defaults to
            ``torch.distributed.run``.
        args_to_torchrun: Launcher arguments. Boolean values are treated as
            bare flags, everything else as ``--key value``.
        script_to_call: Script handed to the launcher.
    """

    task_type: str = Field(default="torch", serialization_alias="command_type")

    entrypoint: str = Field(default="torch.distributed.run", frozen=True)
    args_to_torchrun: dict[str, str | bool] = Field(default_factory=dict)  # For example
    # {"nproc_per_node": 2, "nnodes": 1,}

    script_to_call: str  # For example train/script.py

    def _flag_tokens(self, key, value) -> list:
        """Render one ``key``/``value`` pair as command line tokens."""
        if isinstance(value, bool):
            # A boolean is a bare switch: present when True, absent when False.
            # Emitting ``--key`` for False would silently turn the flag on.
            return [f"--{key}"] if value else []
        return [f"--{key}", str(value)]

    def _build_entry_point_args(self, params) -> list:
        """Assemble the full argv: launcher, its args, the script, the parameters."""
        entry_point_args = [self.entrypoint]

        for key, value in self.args_to_torchrun.items():
            entry_point_args.extend(self._flag_tokens(key, value))

        entry_point_args.append(self.script_to_call)

        for key, value in params.items():
            # Parameters are wrapper objects; the raw value lives on ``.value``.
            entry_point_args.extend(self._flag_tokens(key, value.value))  # type: ignore

        return entry_point_args

    def execute_command(
        self, map_variable: Dict[str, str | int | float] | None = None
    ) -> StepAttempt:
        """
        Run the entrypoint module with the assembled arguments.

        Args:
            map_variable: Passed through to ``execution_context``.

        Returns:
            StepAttempt: Starts as ``defaults.FAIL`` and is flipped to
            ``defaults.SUCCESS`` only when the entrypoint returns without
            raising. On failure ``message`` carries a short description.
        """
        attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))

        with (
            self.execution_context(
                map_variable=map_variable, allow_complex=False
            ) as params,
            self.expose_secrets() as _,
        ):
            try:
                entry_point_args = self._build_entry_point_args(params)

                logger.info("Calling the user script with the following parameters:")
                logger.info(entry_point_args)
                out_file = TeeIO()
                # Remember the caller's argv so it can be restored verbatim.
                original_argv = sys.argv
                try:
                    with contextlib.redirect_stdout(out_file):
                        sys.argv = entry_point_args
                        runpy.run_module(self.entrypoint, run_name="__main__")
                    task_console.print(out_file.getvalue())
                except Exception as e:
                    # NOTE(review): SystemExit raised by the entrypoint is not
                    # an Exception subclass and is not converted here.
                    raise exceptions.CommandCallError(
                        f"Call to entrypoint {self.entrypoint} with {self.script_to_call} did not succeed."
                    ) from e
                finally:
                    sys.argv = original_argv

                attempt_log.status = defaults.SUCCESS
            except Exception as _e:
                msg = f"Call to entrypoint {self.entrypoint} with {self.script_to_call} did not succeed."
                attempt_log.message = msg
                task_console.print_exception(show_locals=False)
                task_console.log(_e, style=defaults.error_style)

        attempt_log.end_time = str(datetime.now())

        return attempt_log
419
+
420
+
357
421
  class NotebookTaskType(BaseTaskType):
358
422
  """
359
423
  --8<-- [start:notebook_reference]
@@ -747,6 +811,31 @@ class ShellTaskType(BaseTaskType):
747
811
  return attempt_log
748
812
 
749
813
 
814
def convert_binary_to_string(data):
    """
    Recursively convert the numbers 1 and 0 in nested data to "1" and "0".

    Dictionaries and lists are walked recursively and rebuilt, so the input
    is never modified. Booleans are returned unchanged: ``True == 1`` and
    ``False == 0`` in Python, but a boolean is a flag, not a binary number.
    Values that are not ``int`` or ``float`` are returned as they are.

    Args:
        data (dict or any): The input data (dictionary, list, or other).

    Returns:
        dict or any: A copy of the data with 1 and 0 converted to strings.
    """
    if isinstance(data, dict):
        return {key: convert_binary_to_string(value) for key, value in data.items()}

    if isinstance(data, list):
        return [convert_binary_to_string(item) for item in data]

    # bool is a subclass of int, so it must be ruled out before the numeric
    # comparison; non-numeric objects are never compared against 1 or 0.
    if isinstance(data, bool) or not isinstance(data, (int, float)):
        return data

    if data == 1:
        return "1"
    if data == 0:
        return "0"

    return data  # Return other values unchanged
+
838
+
750
839
  def create_task(kwargs_for_init) -> BaseTaskType:
751
840
  """
752
841
  Creates a task object from the command configuration.
@@ -763,6 +852,8 @@ def create_task(kwargs_for_init) -> BaseTaskType:
763
852
  kwargs = kwargs_for_init.copy()
764
853
  command_type = kwargs.pop("command_type", defaults.COMMAND_TYPE)
765
854
 
855
+ kwargs = convert_binary_to_string(kwargs)
856
+
766
857
  try:
767
858
  task_mgr = driver.DriverManager(
768
859
  namespace="tasks",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: runnable
3
- Version: 0.33.1
3
+ Version: 0.34.0a2
4
4
  Summary: Add your description here
5
5
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
6
6
  License-File: LICENSE
@@ -56,12 +56,12 @@ runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
56
56
  runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
57
57
  runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
58
58
  runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
59
- runnable/sdk.py,sha256=xjMEDg1Uovrp0Kw7RR4GAQIJgG9zcIgueqOzQAaO7Bs,35363
59
+ runnable/sdk.py,sha256=Cl6wVJj_pBnHmcszf-kh4nVqbiQaIruGJn06cm9epm4,35097
60
60
  runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
61
- runnable/tasks.py,sha256=ABRhgiTY8F62pNlqJmVTDjwJwuzp8DqciUEOq1fpt1U,28989
61
+ runnable/tasks.py,sha256=OW9pzjEKMRFpB256KJm__jWwsF37gs-tkIUcfnOTJwA,32382
62
62
  runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
63
- runnable-0.33.1.dist-info/METADATA,sha256=Cgy39IH3KEP3b2eEVF8u934b_AQ4BWzzQqjb2BV5aKw,10168
64
- runnable-0.33.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
- runnable-0.33.1.dist-info/entry_points.txt,sha256=uWHbbOSj0jlG54tFHw377xKkfVbjWvb_1Y9L_LgjJ0Q,1925
66
- runnable-0.33.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
- runnable-0.33.1.dist-info/RECORD,,
63
+ runnable-0.34.0a2.dist-info/METADATA,sha256=DzGQTVqxRAN95MoyRc5TQXG_OC85uf6PH5NGtru3qSg,10170
64
+ runnable-0.34.0a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
+ runnable-0.34.0a2.dist-info/entry_points.txt,sha256=wKfW6aIWMQFlwrwpPBVWlMQDcxQmOupDKNkKyXoPFV4,1917
66
+ runnable-0.34.0a2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
+ runnable-0.34.0a2.dist-info/RECORD,,
@@ -49,4 +49,4 @@ env-secrets = runnable.secrets:EnvSecretsManager
49
49
  notebook = runnable.tasks:NotebookTaskType
50
50
  python = runnable.tasks:PythonTaskType
51
51
  shell = runnable.tasks:ShellTaskType
52
- torch = extensions.tasks.torch:TorchTaskType
52
+ torch = runnable.tasks:TorchTaskType