runnable 0.34.0a2__py3-none-any.whl → 0.34.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
runnable/sdk.py CHANGED
@@ -283,14 +283,15 @@ class PythonTask(BaseTask):
283
283
 
284
284
 
285
285
  class TorchTask(BaseTask):
286
- entrypoint: str = Field(
287
- alias="entrypoint", default="torch.distributed.run", frozen=True
288
- )
289
- args_to_torchrun: Dict[str, Any] = Field(
290
- default_factory=dict, alias="args_to_torchrun"
291
- )
286
+ # entrypoint: str = Field(
287
+ # alias="entrypoint", default="torch.distributed.run", frozen=True
288
+ # )
289
+ # args_to_torchrun: Dict[str, Any] = Field(
290
+ # default_factory=dict, alias="args_to_torchrun"
291
+ # )
292
292
 
293
293
  script_to_call: str
294
+ accelerate_config_file: str
294
295
 
295
296
  @computed_field
296
297
  def command_type(self) -> str:
@@ -984,13 +985,14 @@ class PythonJob(BaseJob):
984
985
 
985
986
 
986
987
  class TorchJob(BaseJob):
987
- entrypoint: str = Field(default="torch.distributed.run", frozen=True)
988
- args_to_torchrun: dict[str, str | bool | int | float] = Field(
989
- default_factory=dict
990
- ) # For example
988
+ # entrypoint: str = Field(default="torch.distributed.run", frozen=True)
989
+ # args_to_torchrun: dict[str, str | bool | int | float] = Field(
990
+ # default_factory=dict
991
+ # ) # For example
991
992
  # {"nproc_per_node": 2, "nnodes": 1,}
992
993
 
993
994
  script_to_call: str # For example train/script.py
995
+ accelerate_config_file: str
994
996
 
995
997
  def get_task(self) -> RunnableTask:
996
998
  # Piggy bank on existing tasks as a hack
runnable/tasks.py CHANGED
@@ -5,7 +5,6 @@ import io
5
5
  import json
6
6
  import logging
7
7
  import os
8
- import runpy
9
8
  import subprocess
10
9
  import sys
11
10
  from datetime import datetime
@@ -357,16 +356,15 @@ class PythonTaskType(BaseTaskType): # pylint: disable=too-few-public-methods
357
356
 
358
357
  class TorchTaskType(BaseTaskType):
359
358
  task_type: str = Field(default="torch", serialization_alias="command_type")
360
-
361
- entrypoint: str = Field(default="torch.distributed.run", frozen=True)
362
- args_to_torchrun: dict[str, str | bool] = Field(default_factory=dict) # For example
363
- # {"nproc_per_node": 2, "nnodes": 1,}
359
+ accelerate_config_file: str
364
360
 
365
361
  script_to_call: str # For example train/script.py
366
362
 
367
363
  def execute_command(
368
364
  self, map_variable: Dict[str, str | int | float] | None = None
369
365
  ) -> StepAttempt:
366
+ from accelerate.commands import launch
367
+
370
368
  attempt_log = StepAttempt(status=defaults.FAIL, start_time=str(datetime.now()))
371
369
 
372
370
  with (
@@ -376,39 +374,37 @@ class TorchTaskType(BaseTaskType):
376
374
  self.expose_secrets() as _,
377
375
  ):
378
376
  try:
379
- entry_point_args = [self.entrypoint]
380
-
381
- for key, value in self.args_to_torchrun.items():
382
- entry_point_args.append(f"--{key}")
383
- if type(value) is not bool:
384
- entry_point_args.append(str(value))
385
-
386
- entry_point_args.append(self.script_to_call)
377
+ script_args = []
387
378
  for key, value in params.items():
388
- entry_point_args.append(f"--{key}")
389
- if type(value.value) is not bool: # type: ignore
390
- entry_point_args.append(str(value.value)) # type: ignore
379
+ script_args.append(f"--{key}")
380
+ if type(value.value) is not bool:
381
+ script_args.append(str(value.value))
391
382
 
392
383
  # TODO: Check the typing here
393
384
 
394
385
  logger.info("Calling the user script with the following parameters:")
395
- logger.info(entry_point_args)
386
+ logger.info(script_args)
396
387
  out_file = TeeIO()
397
388
  try:
398
389
  with contextlib.redirect_stdout(out_file):
399
- sys.argv = entry_point_args
400
- runpy.run_module(self.entrypoint, run_name="__main__")
390
+ parser = launch.launch_command_parser()
391
+ args = parser.parse_args(self.script_to_call)
392
+ args.training_script = self.script_to_call
393
+ args.config_file = self.accelerate_config_file
394
+ args.training_script_args = script_args
395
+
396
+ launch.launch_command(args)
401
397
  task_console.print(out_file.getvalue())
402
398
  except Exception as e:
403
399
  raise exceptions.CommandCallError(
404
- f"Call to entrypoint {self.entrypoint} with {self.script_to_call} did not succeed."
400
+ f"Call to script{self.script_to_call} did not succeed."
405
401
  ) from e
406
402
  finally:
407
403
  sys.argv = sys.argv[:1]
408
404
 
409
405
  attempt_log.status = defaults.SUCCESS
410
406
  except Exception as _e:
411
- msg = f"Call to entrypoint {self.entrypoint} with {self.script_to_call} did not succeed."
407
+ msg = f"Call to script: {self.script_to_call} did not succeed."
412
408
  attempt_log.message = msg
413
409
  task_console.print_exception(show_locals=False)
414
410
  task_console.log(_e, style=defaults.error_style)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: runnable
3
- Version: 0.34.0a2
3
+ Version: 0.34.0a3
4
4
  Summary: Add your description here
5
5
  Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
6
6
  License-File: LICENSE
@@ -15,6 +15,7 @@ Requires-Dist: rich>=13.9.4
15
15
  Requires-Dist: ruamel-yaml>=0.18.6
16
16
  Requires-Dist: setuptools>=75.6.0
17
17
  Requires-Dist: stevedore>=5.4.0
18
+ Requires-Dist: torchvision>=0.21.0
18
19
  Requires-Dist: typer>=0.15.1
19
20
  Provides-Extra: docker
20
21
  Requires-Dist: docker>=7.1.0; extra == 'docker'
@@ -27,8 +28,8 @@ Requires-Dist: ploomber-engine>=0.0.33; extra == 'notebook'
27
28
  Provides-Extra: s3
28
29
  Requires-Dist: cloudpathlib[s3]; extra == 's3'
29
30
  Provides-Extra: torch
31
+ Requires-Dist: accelerate>=1.5.2; extra == 'torch'
30
32
  Requires-Dist: torch>=2.6.0; extra == 'torch'
31
- Requires-Dist: torchvision>=0.21.0; extra == 'torch'
32
33
  Description-Content-Type: text/markdown
33
34
 
34
35
 
@@ -56,12 +56,12 @@ runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
56
56
  runnable/nodes.py,sha256=QGHMznriEz4AcmntHICBZKrDT6zbc7WD1sV0MgwK10c,16691
57
57
  runnable/parameters.py,sha256=u77CdqqDAbVdzNeBFPNUfGnWPy9-SpBVmwEJ56xmDm8,5289
58
58
  runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
59
- runnable/sdk.py,sha256=Cl6wVJj_pBnHmcszf-kh4nVqbiQaIruGJn06cm9epm4,35097
59
+ runnable/sdk.py,sha256=-hsoZctbGKsrfOQW3Z7RqWVGJI4GhbsOjqjMRb2OAUo,35181
60
60
  runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
61
- runnable/tasks.py,sha256=OW9pzjEKMRFpB256KJm__jWwsF37gs-tkIUcfnOTJwA,32382
61
+ runnable/tasks.py,sha256=lOtCninvosGI2bNIzblrzNa-lN7TMwel1KQ1g23M85A,32088
62
62
  runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
63
- runnable-0.34.0a2.dist-info/METADATA,sha256=DzGQTVqxRAN95MoyRc5TQXG_OC85uf6PH5NGtru3qSg,10170
64
- runnable-0.34.0a2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
- runnable-0.34.0a2.dist-info/entry_points.txt,sha256=wKfW6aIWMQFlwrwpPBVWlMQDcxQmOupDKNkKyXoPFV4,1917
66
- runnable-0.34.0a2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
- runnable-0.34.0a2.dist-info/RECORD,,
63
+ runnable-0.34.0a3.dist-info/METADATA,sha256=AYMw1jtTzhBN_Y2dMJiguAnYwc82LLxa-WHYApUYpCs,10203
64
+ runnable-0.34.0a3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
+ runnable-0.34.0a3.dist-info/entry_points.txt,sha256=wKfW6aIWMQFlwrwpPBVWlMQDcxQmOupDKNkKyXoPFV4,1917
66
+ runnable-0.34.0a3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
+ runnable-0.34.0a3.dist-info/RECORD,,