runnable 0.29.0-py3-none-any.whl → 0.30.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
extensions/job_executor/local.py CHANGED
@@ -12,6 +12,12 @@ logger = logging.getLogger(defaults.LOGGER_NAME)
 class LocalJobExecutor(GenericJobExecutor):
     """
     The LocalJobExecutor is a job executor that runs the job locally.
+
+    Configuration:
+
+    pipeline-executor:
+      type: local
+
     """
 
     service_name: str = "local"
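For orientation, a minimal sketch of pointing a job at this executor. It assumes runnable's job SDK exposes `PythonJob` and that a configuration file is picked up via the `RUNNABLE_CONFIGURATION_FILE` environment variable; both the class usage and the file name `local-job.yaml` are illustrative assumptions, not content from this diff:

```python
import os

from runnable import PythonJob  # assumed SDK entry point for jobs


def train():
    print("running locally")


# hypothetical config file containing the block from the docstring above
os.environ["RUNNABLE_CONFIGURATION_FILE"] = "local-job.yaml"

job = PythonJob(function=train)
job.execute()  # the LocalJobExecutor runs the job in the current process
```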
extensions/nodes/torch.py CHANGED
@@ -115,7 +115,9 @@ class TorchNode(DistributedNode, TorchConfig):
         map_variable: TypeMapVariable = None,
         attempt_number: int = 1,
     ) -> StepLog:
-        assert map_variable is None, "TorchNode does not support map_variable"
+        assert (
+            map_variable is None or not map_variable
+        ), "TorchNode does not support map_variable"
 
         step_log = self._context.run_log_store.get_step_log(
             self._get_step_log_name(map_variable), self._context.run_id
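The relaxed assertion accepts any falsy `map_variable` (such as an empty dict handed down by an executor), not only `None`. A standalone illustration of the behavioral difference:

```python
# Standalone illustration, not from the package.
map_variable: dict[str, str] | None = {}

# New check: passes, because an empty dict is falsy.
assert map_variable is None or not map_variable

# Old check: `assert map_variable is None` would raise AssertionError
# for the same empty dict.
```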
@@ -130,6 +132,8 @@ class TorchNode(DistributedNode, TorchConfig):
             self._context.parameters_file or ""
         )
         os.environ["RUNNABLE_TORCH_RUN_ID"] = self._context.run_id
+        # retrieve the master address and port from the parameters
+        # default to localhost and 29500
         launcher = elastic_launch(
             launch_config,
             training_subprocess,
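The new comments point at rendezvous plumbing that torchrun-style launchers conventionally read from the standard `torch.distributed` environment variables. A sketch of that convention, using the usual PyTorch variable names as an assumption rather than values taken from this diff:

```python
import os

# Conventional torch.distributed rendezvous settings: fall back to
# localhost:29500 when no master address/port is configured.
master_addr = os.environ.get("MASTER_ADDR", "localhost")
master_port = int(os.environ.get("MASTER_PORT", "29500"))

rdzv_endpoint = f"{master_addr}:{master_port}"  # e.g. fed into a LaunchConfig
print(rdzv_endpoint)
```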
@@ -161,9 +165,17 @@ class TorchNode(DistributedNode, TorchConfig):
 
         return step_log
 
-    # TODO: Not sure we need these methods
     def fan_in(self, map_variable: dict[str, str | int | float] | None = None):
-        assert map_variable is None, "TorchNode does not support map_variable"
+        # Destroy the service
+        # Destroy the statefulset
+        assert (
+            map_variable is None or not map_variable
+        ), "TorchNode does not support map_variable"
 
     def fan_out(self, map_variable: dict[str, str | int | float] | None = None):
-        assert map_variable is None, "TorchNode does not support map_variable"
+        # Create a service
+        # Create a statefulset
+        # Gather the IPs and set them as parameters downstream
+        assert (
+            map_variable is None or not map_variable
+        ), "TorchNode does not support map_variable"
extensions/pipeline_executor/argo.py CHANGED
@@ -27,6 +27,7 @@ from extensions.nodes.nodes import (
     SuccessNode,
     TaskNode,
 )
+from extensions.nodes.torch import TorchNode
 from extensions.pipeline_executor import GenericPipelineExecutor
 from runnable import defaults, utils
 from runnable.defaults import TypeMapVariable
@@ -370,6 +371,89 @@ class CustomVolume(BaseModelWIthConfig):
 
 
 class ArgoExecutor(GenericPipelineExecutor):
+    """
+    Executes the pipeline using Argo Workflows.
+
+    The defaults configuration is kept similar to the
+    [Argo Workflow spec](https://argo-workflows.readthedocs.io/en/latest/fields/#workflow).
+
+    Configuration:
+
+    ```yaml
+    pipeline-executor:
+      type: argo
+      config:
+        pvc_for_runnable: "my-pvc"
+        custom_volumes:
+          - mount_path: "/tmp"
+            persistent_volume_claim:
+              claim_name: "my-pvc"
+              read_only: false/true
+        expose_parameters_as_inputs: true/false
+        secrets_from_k8s:
+          - key1
+          - key2
+          - ...
+        output_file: "argo-pipeline.yaml"
+        log_level: "DEBUG"/"INFO"/"WARNING"/"ERROR"/"CRITICAL"
+        defaults:
+          image: "my-image"
+          activeDeadlineSeconds: 86400
+          failFast: true
+          nodeSelector:
+            label: value
+          parallelism: 1
+          retryStrategy:
+            backoff:
+              duration: "2m"
+              factor: 2
+              maxDuration: "1h"
+            limit: 0
+            retryPolicy: "Always"
+          timeout: "1h"
+          tolerations:
+          imagePullPolicy: "Always"/"IfNotPresent"/"Never"
+          resources:
+            limits:
+              memory: "1Gi"
+              cpu: "250m"
+              gpu: 0
+            requests:
+              memory: "1Gi"
+              cpu: "250m"
+          env:
+            - name: "MY_ENV"
+              value: "my-value"
+            - name: secret_env
+              secretName: "my-secret"
+              secretKey: "my-key"
+        overrides:
+          key1:
+            ... similar structure to defaults
+
+        argoWorkflow:
+          metadata:
+            annotations:
+              key1: value1
+              key2: value2
+            generateName: "my-workflow"
+            labels:
+              key1: value1
+
+    ```
+
+    As of now, ```runnable``` needs a pvc to store the logs and the catalog; provided by ```pvc_for_runnable```.
+    - ```custom_volumes``` can be used to mount additional volumes to the container.
+
+    - ```expose_parameters_as_inputs``` can be used to expose the initial parameters as inputs to the workflow.
+    - ```secrets_from_k8s``` can be used to expose the secrets from the k8s secret store.
+    - ```output_file``` is the file where the argo pipeline will be dumped.
+    - ```log_level``` is the log level for the containers.
+    - ```defaults``` is the default configuration for all the containers.
+
+
+    """
+
     service_name: str = "argo"
     _is_local: bool = False
     mock: bool = False
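The docstring above is the executor's full YAML surface. A minimal end-to-end sketch of driving it from the SDK follows; it assumes the configuration file is supplied via the `RUNNABLE_CONFIGURATION_FILE` environment variable, and `argo-config.yaml` is a hypothetical file shaped like the block above:

```python
import os

from runnable import Pipeline, PythonTask


def hello():
    print("hello from argo")


# hypothetical config file with `pipeline-executor: type: argo`
os.environ["RUNNABLE_CONFIGURATION_FILE"] = "argo-config.yaml"

task = PythonTask(name="hello", function=hello, terminate_with_success=True)
Pipeline(steps=[task]).execute()  # transpiles the DAG to `output_file`
```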
@@ -510,6 +594,7 @@ class ArgoExecutor(GenericPipelineExecutor):
             isinstance(node, TaskNode)
             or isinstance(node, StubNode)
             or isinstance(node, SuccessNode)
+            or isinstance(node, TorchNode)
         )
 
         node_override = None
@@ -522,7 +607,7 @@ class ArgoExecutor(GenericPipelineExecutor):
 
         effective_settings = self.defaults.model_dump()
         if node_override:
-            effective_settings.update(node_override.model_dump())
+            effective_settings.update(node_override.model_dump(exclude_none=True))
 
         inputs = inputs or Inputs(parameters=[])
 
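The switch to `model_dump(exclude_none=True)` keeps unset (None) fields of a node override from clobbering populated defaults when the two dicts are merged. A standalone pydantic illustration of the effect:

```python
from typing import Optional

from pydantic import BaseModel


class Override(BaseModel):
    image: Optional[str] = None   # not set on this node
    parallelism: Optional[int] = 2


defaults = {"image": "my-image", "parallelism": 1}

clobbered = {**defaults, **Override().model_dump()}
# {'image': None, 'parallelism': 2}  <- image lost

effective = {**defaults, **Override().model_dump(exclude_none=True)}
# {'image': 'my-image', 'parallelism': 2}
print(clobbered, effective)
```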
@@ -573,7 +658,7 @@ class ArgoExecutor(GenericPipelineExecutor):
     def _set_env_vars_to_task(
         self, working_on: BaseNode, container_template: CoreContainerTemplate
     ):
-        if not isinstance(working_on, TaskNode):
+        if not isinstance(working_on, TaskNode) or isinstance(working_on, TorchNode):
             return
 
         global_envs: dict[str, str] = {}
@@ -792,6 +877,26 @@ class ArgoExecutor(GenericPipelineExecutor):
 
                 self._templates.append(composite_template)
 
+            case "torch":
+                assert isinstance(working_on, TorchNode)
+                # TODO: Need to add multi-node functionality
+                # Check notes on the torch node
+
+                template_of_container = self._create_container_template(
+                    working_on,
+                    task_name=task_name,
+                    inputs=Inputs(parameters=parameters),
+                )
+                assert template_of_container.container is not None
+
+                if working_on.node_type == "task":
+                    self._expose_secrets_to_task(
+                        working_on,
+                        container_template=template_of_container.container,
+                    )
+
+                self._templates.append(template_of_container)
+
         self._handle_failures(
             working_on,
             dag,
extensions/pipeline_executor/local.py CHANGED
@@ -18,8 +18,11 @@ class LocalExecutor(GenericPipelineExecutor):
     Also ensure that the local compute is good enough for the compute to happen of all the steps.
 
     Example config:
-    execution:
+
+    ```yaml
+    pipeline-executor:
       type: local
+    ```
 
     """
 
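A minimal sketch of the zero-configuration path, assuming local remains runnable's default pipeline executor when no configuration file is supplied:

```python
from runnable import Pipeline, PythonTask


def hello():
    print("hello")


# With no configuration file, every step runs in the current process.
task = PythonTask(name="hello", function=hello, terminate_with_success=True)
Pipeline(steps=[task]).execute()
```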
extensions/pipeline_executor/local_container.py CHANGED
@@ -3,7 +3,6 @@ from pathlib import Path
 from typing import Dict
 
 from pydantic import Field
-from rich import print
 
 from extensions.pipeline_executor import GenericPipelineExecutor
 from runnable import console, defaults, task_console, utils
@@ -20,31 +19,50 @@ class LocalContainerExecutor(GenericPipelineExecutor):
 
     Ensure that the local compute has enough resources to finish all your jobs.
 
-    The image of the run, could either be provided as default in the configuration of the execution engine
-    i.e.:
-    execution:
-      type: 'local-container'
-      config:
-        docker_image: the image you want the code to run in.
-
-    or default image could be over-ridden for a single node by providing a docker_image in the step config.
-    i.e:
-    dag:
-      steps:
-        step:
-          executor_config:
-            local-container:
-              docker_image: The image that you want that single step to run in.
-    This image would only be used for that step only.
-
-    This mode does not build the docker image with the latest code for you, it is still left for the user to build
-    and ensure that the docker image provided is the correct one.
+    Configuration options:
 
-    Example config:
-    execution:
+    ```yaml
+    pipeline-executor:
       type: local-container
       config:
-        docker_image: The default docker image to use if the node does not provide one.
+        docker_image: <required>
+        auto_remove_container: true/false
+        environment:
+          key: value
+        overrides:
+          alternate_config:
+            docker_image: <required>
+            auto_remove_container: true/false
+            environment:
+              key: value
+    ```
+
+    - ```docker_image```: The default docker image to use for all the steps.
+    - ```auto_remove_container```: Remove the container after execution.
+    - ```environment```: Environment variables to pass to the container.
+
+    Overrides give you the ability to override the default docker image for a single step.
+    A step can then refer to the alternate_config in the task definition.
+
+    Example:
+
+    ```python
+    from runnable import PythonTask
+
+    task = PythonTask(
+        name="alt_task",
+        overrides={
+            "local-container": "alternate_config"
+        }
+    )
+    ```
+
+    In the above example, ```alt_task``` will run in the docker image/configuration
+    as defined in the alternate_config.
+
+    ```runnable``` does not build the docker image for you, it is still left for the user to build
+    and ensure that the docker image provided is the correct one.
+
     """
 
     service_name: str = "local-container"
@@ -221,7 +239,6 @@ class LocalContainerExecutor(GenericPipelineExecutor):
 
         try:
             logger.info(f"Running the command {command}")
-            print(command)
             #  Overrides global config with local
             executor_config = self._resolve_executor_config(node)
 
runnable/executor.py CHANGED
@@ -156,7 +156,7 @@ class BaseJobExecutor(BaseExecutor):
 # TODO: Consolidate execute_node, trigger_node_execution, _execute_node
 class BasePipelineExecutor(BaseExecutor):
     service_type: str = "pipeline_executor"
-    overrides: dict = {}
+    overrides: dict[str, Any] = {}
 
     _context_node: Optional[BaseNode] = PrivateAttr(default=None)
 
runnable/sdk.py CHANGED
@@ -325,7 +325,7 @@ class NotebookTask(BaseTask):
         catalog Optional[Catalog]: The files sync data from/to, refer to Catalog.
 
         secrets List[str]: List of secrets to pass to the task. They are exposed as environment variables
-        and removed after execution.
+            and removed after execution.
 
         overrides (Dict[str, Any]): Any overrides to the command.
         Individual tasks can override the global configuration config by referring to the
@@ -391,7 +391,7 @@ class ShellTask(BaseTask):
         catalog Optional[Catalog]: The files sync data from/to, refer to Catalog.
 
         secrets List[str]: List of secrets to pass to the task. They are exposed as environment variables
-        and removed after execution.
+            and removed after execution.
 
         overrides (Dict[str, Any]): Any overrides to the command.
         Individual tasks can override the global configuration config by referring to the
@@ -460,8 +460,6 @@ class Stub(BaseTraversal):
 
 
 class Torch(BaseTraversal, TorchConfig):
-    # Its a wrapper of a python task
-    # TODO: Is there a way to not sync these with the torch node in extensions?
     function: Callable = Field(exclude=True)
     catalog: Optional[Catalog] = Field(default=None, alias="catalog")
     overrides: Dict[str, Any] = Field(default_factory=dict, alias="overrides")
runnable-0.30.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: runnable
-Version: 0.29.0
+Version: 0.30.1
 Summary: Add your description here
 Author-email: "Vammi, Vijay" <vijay.vammi@astrazeneca.com>
 License-File: LICENSE
runnable-0.30.1.dist-info/RECORD CHANGED
@@ -10,19 +10,19 @@ extensions/job_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3h
 extensions/job_executor/__init__.py,sha256=E2R6GV5cZTlZdqA5SVJ6ajZFh4oruM0k8AKHkpOZ3W8,5772
 extensions/job_executor/k8s.py,sha256=erzw4UOsOf2JSOiQio5stgW_rMryAsIQSBd8wiL6nBY,16214
 extensions/job_executor/k8s_job_spec.yaml,sha256=7aFpxHdO_p6Hkc3YxusUOuAQTD1Myu0yTPX9DrhxbOg,1158
-extensions/job_executor/local.py,sha256=raobGxwoqZN8c-yCsAa0CDuPLWKuyEttB37U5wsqGF4,1968
+extensions/job_executor/local.py,sha256=3v6F8SOaPbCfPVVmU07RFr1wgs8iC8WoSn6Evfi8o3M,2033
 extensions/job_executor/local_container.py,sha256=8-dLhzY34pOVjJ_x0VmeTwVvYkESXBnp4j-XLsSsgBk,6688
 extensions/job_executor/pyproject.toml,sha256=UIEgiCYHTXcRWSByNMFuKJFKgxTBpQqTqyUecIsb_Vc,286
 extensions/nodes/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/nodes/nodes.py,sha256=s9ub1dqy4qHjRQG6YElCdL7rCOTYNs9RUIrStZ6tEB4,28256
 extensions/nodes/pyproject.toml,sha256=YTu-ETN3JNFSkMzzWeOwn4m-O2nbRH-PmiPBALDCUw4,278
-extensions/nodes/torch.py,sha256=kB4a72YMcrxDDzbR5LffODtrdA7vUo9dRJlaVr8KEEM,5570
+extensions/nodes/torch.py,sha256=oYh4ep9J6CS3r04HURJba5m4v8lzNupWUh4PAXvGgi0,5952
 extensions/nodes/torch_config.py,sha256=yDvDADpnLhQsNtfH8qIztLHQ2LhYiOJEWljxpH9GZzs,1222
 extensions/pipeline_executor/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/pipeline_executor/__init__.py,sha256=9ZMHcieSYdTiYyjSkc8eT8yhOlKEUFnrbrdbqdOgvP0,24195
-extensions/pipeline_executor/argo.py,sha256=LlXtzcbJyOossTNd-gdC4xrkPe9qYxRc0czdNzNQzlY,34497
-extensions/pipeline_executor/local.py,sha256=orBIhG8QJesA3YqWhsuczjhhlKD_1s62MRNerxv9_Tg,1858
-extensions/pipeline_executor/local_container.py,sha256=PvMXy-zFTnT9hj7jjz1VkaPVJ9Dkrhnd7Hs6eqLMXZ8,12398
+extensions/pipeline_executor/argo.py,sha256=eyIVZbpecU1cPAwdvt56UFRZW2AqxALcBM_Yfvbvhqw,37958
+extensions/pipeline_executor/local.py,sha256=6oWUJ6b6NvIkpeQJBoCT1hbfX4_6WCB4HzMgHZ4ik1A,1887
+extensions/pipeline_executor/local_container.py,sha256=3kZ2QCsrq_YjH9dcAz8v05knKShQ_JtbIU-IA_-G538,12724
 extensions/pipeline_executor/mocked.py,sha256=0sMmypuvstBIv9uQg-WAcPrF3oOFpeEXNi6N8Nzdnl0,5680
 extensions/pipeline_executor/pyproject.toml,sha256=ykTX7srR10PBYb8LsIwEj8vIPPIEZQ5V_R7VYbZ-ido,291
 extensions/pipeline_executor/retry.py,sha256=6ClFXJYtr0M6nWIZiI-mbUGshobOtVH_KADN8JCfvH0,6881
@@ -40,7 +40,6 @@ extensions/run_log_store/db/integration_FF.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeR
 extensions/secrets/README.md,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 extensions/secrets/dotenv.py,sha256=nADHXI6KJ_LUYOIe5EbtYH-21OBebSNVr0Pjb1GlZ7w,1573
 extensions/secrets/pyproject.toml,sha256=mLJNImNcBlbLKHh-0ugVWT9V83R4RibyyYDtBCSqVF4,282
-extensions/tasks/torch.py,sha256=uNO4qYMawNH5hPecANCiSUQZnUC8yqw3-rxtM526CeA,1955
 runnable/__init__.py,sha256=swvqdCjeddn40o4zjsluyahdVcU0r1arSRrxmRsvFEQ,673
 runnable/catalog.py,sha256=W_erYbLZ-ffuA9RQuWVqz1DUJOuWayf32ne32IDbAbc,4358
 runnable/cli.py,sha256=3BiKSj95h2Drn__YlchMPZ5rBMafuRb2OGIsVpbsO5Y,8788
@@ -49,18 +48,18 @@ runnable/datastore.py,sha256=ZobM1aVkgeUJ2fZYt63IFDsoNzObwc93hdByegS5YKQ,32396
 runnable/defaults.py,sha256=3o9IVGryyCE6PoQTOoaIaHHTbJGEzmdXMcwzOhwAYoI,3518
 runnable/entrypoints.py,sha256=cDbhtmLUWdBh9K6hNusfQpSd5NadcX8V1K2JEDf_YAg,18984
 runnable/exceptions.py,sha256=LFbp0-Qxg2PAMLEVt7w2whhBxSG-5pzUEv5qN-Rc4_c,3003
-runnable/executor.py,sha256=F0gQjJ10VQSlilqaxYm1gRdXpjmx5DqBP0KCPmz85zg,15021
+runnable/executor.py,sha256=J8-Ri9nBZCb-ao6okePb9FUVlhAaPc0ojQ2l48-FUqc,15031
 runnable/graph.py,sha256=poQz5zcvq89ju_u5sYlunQLPbHnXTaUmjcvstPwvT4U,16536
 runnable/names.py,sha256=vn92Kv9ANROYSZX6Z4z1v_WA3WiEdIYmG6KEStBFZug,8134
 runnable/nodes.py,sha256=d1eLttMAcV7CTwTEqOuNwZqItANoLUkXJ73Xp-srlyI,17811
 runnable/parameters.py,sha256=sT3DNGczivP9z7r4Cp_brbudg1z4J-zjmvrq3ppIrVs,5089
 runnable/pickler.py,sha256=ydJ_eti_U1F4l-YacFp7BWm6g5vTn04UXye25S1HVok,2684
-runnable/sdk.py,sha256=6OO_vsRuGSjVhME2AJEljs0cjobjIQKa2E3mGoILkZA,35237
+runnable/sdk.py,sha256=NZVQGaL4Zm2hwloRmqEgp8UPbBg9hY1abQGYnOgniPI,35128
 runnable/secrets.py,sha256=4L_dBFxTgr8r_hHUD6RlZEtqaOHDRsFG5PXO5wlvMI0,2324
 runnable/tasks.py,sha256=Qb1IhVxHv68E7vf3M3YCf7MGRHyjmsEEYBpEpiZ4mRI,29062
 runnable/utils.py,sha256=hBr7oGwGL2VgfITlQCTz-a1iwvvf7Mfl-HY8UdENZac,19929
-runnable-0.29.0.dist-info/METADATA,sha256=1jW9CmQUxqDClGKzNPZGMlq1B7M4y8p9FwQf0Cz1bZg,10115
-runnable-0.29.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-runnable-0.29.0.dist-info/entry_points.txt,sha256=PrjKrlfXPZaV_7hz8orGu4FDnatLqnhPOXljyllszdw,1880
-runnable-0.29.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-runnable-0.29.0.dist-info/RECORD,,
+runnable-0.30.1.dist-info/METADATA,sha256=4Y4D0jyK46LpYoZE53b761BJe95eBvxo5QU3R-_-t0Y,10115
+runnable-0.30.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+runnable-0.30.1.dist-info/entry_points.txt,sha256=PrjKrlfXPZaV_7hz8orGu4FDnatLqnhPOXljyllszdw,1880
+runnable-0.30.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+runnable-0.30.1.dist-info/RECORD,,
extensions/tasks/torch.py DELETED
@@ -1,52 +0,0 @@
-from typing import List, Optional
-
-from pydantic import Field, field_validator
-
-from runnable import defaults
-from runnable.datastore import StepAttempt
-from runnable.defaults import TypeMapVariable
-from runnable.tasks import BaseTaskType
-
-
-def run_torch_task(
-    rank: int = 1,
-    world_size: int = 1,
-    entrypoint: str = "some function",
-    catalog: Optional[dict[str, List[str]]] = None,
-    task_returns: Optional[List[str]] = None,
-    secrets: Optional[list[str]] = None,
-):
-    # Entry point that creates a python job using simpler python types
-    # and executes them. The run_id for the job is set to be run_id_rank
-    # Since the configuration file is passed as an environmental variable,
-    # the job will use the configuration file to get the required information.
-
-    # In pseudocode, the following is done:
-    # Create the catalog object
-    # Create the secrets and other objects required for the PythonJob
-    # Init the process group using:
-    # https://github.com/pytorch/examples/blob/main/imagenet/main.py#L140
-    # Execute the job, the job is expected to use the environmental variables
-    # to identify the rank or can have them as variable in the signature.
-    # Once the job is executed, we destroy the process group
-    pass
-
-
-class TorchTaskType(BaseTaskType):
-    task_type: str = Field(default="torch", serialization_alias="command_type")
-    command: str
-    num_gpus: int = Field(default=1, description="Number of GPUs to use")
-
-    @field_validator("num_gpus")
-    @classmethod
-    def check_if_cuda_is_available(cls, num_gpus: int) -> int:
-        # Import torch and check if cuda is available
-        # validate if the number of gpus is less than or equal to available gpus
-        return num_gpus
-
-    def execute_command(
-        self,
-        map_variable: TypeMapVariable = None,
-    ) -> StepAttempt:
-        # We have to spawn here
-        return StepAttempt(attempt_number=1, status=defaults.SUCCESS)