qadence 1.10.3__py3-none-any.whl → 1.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,209 @@
1
+ from __future__ import annotations
2
+
3
+
4
+ from logging import getLogger
5
+
6
+ import os
7
+ import torch
8
+ import torch.distributed as dist
9
+ from qadence.ml_tools.train_utils.execution import BaseExecution, detect_execution
10
+ from qadence.types import ExecutionType
11
+
12
+ # Initialize the logger for this module
13
+ logger = getLogger("ml_tools")
14
+
15
+
16
+ class Distributor:
17
+ """
18
+ Class to set up and manage distributed training.
19
+
20
+ This class uses the detect_execution() function to determine the current launch execution
21
+ (e.g., torchrun, default). It provides methods to set up processes and to start and clean up the
22
+ PyTorch distributed process group.
23
+
24
+ The execution configures environment variables required for distributed training (such as rank, world size,
25
+ master address, and master port), and sets the appropriate computation device (CPU or GPU).
26
+
27
+ Attributes:
28
+ nprocs (int): Number of processes to launch for distributed training.
29
+ execution (BaseExecution): Detected execution instance for process launch (e.g., "torchrun","default").
30
+ execution_type (ExecutionType): Type of execution used.
31
+ rank (int): Global rank of the process (to be set during environment setup).
32
+ world_size (int): Total number of processes (to be set during environment setup).
33
+ local_rank (int | None): Local rank on the node (to be set during environment setup).
34
+ master_addr (str): Master node address (to be set during environment setup).
35
+ master_port (str): Master node port (to be set during environment setup).
36
+ node_rank (int): Rank of the node on the cluster setup.
37
+ """
38
+
39
+ # -----------------------------------------------------------------------------
40
+ # HEAD level methods
41
+ # -----------------------------------------------------------------------------
42
+ def __init__(
43
+ self,
44
+ nprocs: int,
45
+ compute_setup: str,
46
+ log_setup: str,
47
+ backend: str,
48
+ dtype: torch.dtype | None = None,
49
+ ) -> None:
50
+ """
51
+ Initialize the Distributor.
52
+
53
+ Args:
54
+ nprocs (int): Number of processes to launch.
55
+ compute_setup (str): Compute device setup; options are "auto" (default), "gpu", or "cpu".
56
+ - "auto": Uses GPU if available, otherwise CPU.
57
+ - "gpu": Forces GPU usage, raising an error if no CUDA device is available.
58
+ - "cpu": Forces CPU usage.
59
+ log_setup (str): Logging device setup; options are "auto", "cpu" (default).
60
+ - "auto": Uses same device to log as used for computation.
61
+ - "cpu": Forces CPU logging.
62
+ backend (str): Backend to use for distributed communication (default: "nccl").
63
+ dtype (torch.dtype | None): Data type for controlling numerical precision. Default is None.
64
+ """
65
+ self._nprocs: int
66
+ self.rank: int
67
+ self.world_size: int
68
+ self.local_rank: int | None
69
+ self.master_addr: str
70
+ self.master_port: str
71
+ self.execution: BaseExecution
72
+
73
+ self.execution, self.execution_type = detect_execution(
74
+ compute_setup, log_setup, backend, dtype
75
+ )
76
+
77
+ self._config_nprocs = nprocs
78
+ if self.execution_type == ExecutionType.TORCHRUN:
79
+ # torchrun already spawns multiple processes with the required env variables
80
+ self.nprocs = 1
81
+ else:
82
+ self.nprocs = nprocs
83
+
84
+ # -----------------------------------------------------------------------------
85
+ # PROCESS level methods
86
+ # -----------------------------------------------------------------------------
87
+ def setup_process(self, process_rank: int) -> None:
88
+ """
89
+ Sets up the distributed training environment for a given process.
90
+
91
+ Each process sets up a rank, local_rank, and world size. If there are multiple processes
92
+ (based on the world size), a master_addr and master_port are also assigned.
93
+ Setting up a process also sets up its device, selected based on the 'compute_setup'
94
+ argument in TrainConfig. For compute_setup = "auto", GPUs are selected if available.
95
+ The selected device can be
96
+ - "cpu": for CPU-based computation.
97
+ - "cuda:n": a GPU chosen by the distributed setup, where n is the local_rank of the GPU.
98
+ This also sets up the logging device for each process. If log_setup is "auto",
99
+ log_device is the same as device; otherwise it is "cpu".
100
+
101
+ This method initializes the distributed process group and logs relevant details.
102
+
103
+ Args:
104
+ process_rank (int): The rank of the process in the distributed setting.
105
+ """
106
+ self.setup_process_rank_environment(process_rank)
107
+
108
+ logger.info("Initializing Accelerator")
109
+ logger.info("=============================")
110
+ logger.info(
111
+ " Node, Device : %s, %s",
112
+ str(self.execution.node_name),
113
+ self.execution.device,
114
+ )
115
+ logger.info(
116
+ " Rank, Local Rank, World Size: %s, %s, %s",
117
+ str(self.rank),
118
+ str(self.local_rank),
119
+ str(self.world_size),
120
+ )
121
+ logger.info(" Master Address, Master Port : %s, %s", self.master_addr, self.master_port)
122
+
123
+ self.start_process_group()
124
+ if self.rank == 0:
125
+ self._log_warnings() # log the warnings only from the main process
126
+
127
+ def setup_process_rank_environment(self, process_rank: int) -> dict[str, int | None]:
128
+ """
129
+ Set up the process for distributed training, especially useful when processes are spawned.
130
+
131
+ Set up environment variables and the computation device for distributed processing.
132
+
133
+ This method sets environment variables for a spawned process based on the provided process rank.
134
+ This method retrieves the global rank, world size, and local rank using helper methods,
135
+ sets the corresponding environment variables, and if running in a multi-process setting,
136
+ sets up the master address and port for distributed communication. Finally, it configures
137
+ the computation device based on the specified compute setup.
138
+ This method sets:
139
+ rank (int): Global rank of the process (to be set during environment setup).
140
+ world_size (int): Total number of processes (to be set during environment setup).
141
+ local_rank (int | None): Local rank on the node (to be set during environment setup).
142
+ master_addr (str): Master node address (to be set during environment setup).
143
+ master_port (str): Master node port (to be set during environment setup).
144
+ node_rank (int): Rank of the node on the cluster setup.
145
+ node_name (str): Name of the node on the cluster setup.
146
+
147
+ Args:
148
+ process_rank (int): The rank to assign to the process (used in spawn scenarios).
149
+
150
+ Returns:
151
+ dict[str, int | None]: A dictionary containing the global rank, world size, and local rank.
152
+ """
153
+ # set the process based variables
154
+ self.local_rank = self.execution.get_local_rank(process_rank)
155
+ self.world_size = self.execution.get_world_size(process_rank, self.nprocs)
156
+ self.rank = self.execution.get_rank(process_rank)
157
+ self.execution.set_device(self.local_rank)
158
+ # Set environment variables for distributed training
159
+ os.environ["RANK"] = str(self.rank)
160
+ os.environ["WORLD_SIZE"] = str(self.world_size)
161
+ os.environ["LOCAL_RANK"] = str(self.local_rank)
162
+ os.environ["MASTER_ADDR"] = self.master_addr = self.execution.get_master_addr()
163
+ os.environ["MASTER_PORT"] = self.master_port = self.execution.get_master_port()
164
+
165
+ return {"RANK": self.rank, "WORLD_SIZE": self.world_size, "LOCAL_RANK": self.local_rank}
166
+
167
+ def start_process_group(self) -> None:
168
+ """
169
+ Initialize the PyTorch distributed process group for multi-process training.
170
+
171
+ If the world size is greater than 1, this method initializes the process group with the specified
172
+ backend, rank, and world size. For the master process (rank 0), it logs configuration details such
173
+ as the total number of nodes, processes, master address, and master port. Finally, it synchronizes
174
+ all processes with a barrier.
175
+ """
176
+ if self.world_size and self.world_size > 1:
177
+ dist.init_process_group(
178
+ backend=self.execution.backend, rank=self.rank, world_size=self.world_size
179
+ )
180
+ if self.rank == 0:
181
+ logger.info("Starting Distributed Process Group")
182
+ logger.info("=============================")
183
+ logger.info(" Total Nodes : %d", int(os.environ.get("SLURM_NNODES", 1)))
184
+ logger.info(" Total Processes : %d", self.world_size)
185
+ logger.info(" Master Address : %s", self.master_addr)
186
+ logger.info(" Master Port : %s", self.master_port)
187
+ dist.barrier()
188
+
189
+ def finalize(self) -> None:
190
+ """
191
+ Clean up the PyTorch distributed process group after training is complete.
192
+
193
+ If the distributed process group has been initialized, it is destroyed. Additionally, the master
194
+ process (rank 0) logs that the process group is being killed.
195
+ """
196
+ if dist.is_initialized():
197
+ dist.destroy_process_group()
198
+ if self.rank == 0:
199
+ logger.info("Killing Distributed Process Group")
200
+
201
+ def _log_warnings(self) -> None:
202
+
203
+ if self.execution_type == ExecutionType.TORCHRUN:
204
+ logger.info(
205
+ f"Process was launched using `torchrun`, "
206
+ "processes spawned will be set based on `torchrun` setup."
207
+ )
208
+ logger.info(f"User sepcifed `nprocs`={self._config_nprocs}")
209
+ logger.info(f"Total processes spawned={self.world_size}")
@@ -0,0 +1,421 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import random
5
+ import socket
6
+ import subprocess
7
+ from abc import ABC, abstractmethod
8
+ from logging import getLogger
9
+
10
+ import torch
11
+
12
+ from qadence.types import ExecutionType
13
+
14
+ logger = getLogger("ml_tools")
15
+
16
+
17
+ class BaseExecution(ABC):
18
+ """
19
+ Class to set up and manage execution of the processes in different environments.
20
+
21
+ This is a abstract base class, and inherited classes should implement methods to get rank,
22
+ local rank, world size, master addr and master port.
23
+
24
+ It configures environment variables required for distributed training (such as rank, world size,
25
+ master address, and master port), and sets the appropriate computation device (CPU or GPU).
26
+
27
+ Attributes:
28
+ backend (str): The backend used for distributed communication (e.g., "nccl", "gloo").
29
+ It should be one of the backends supported by torch.distributed
30
+ compute_setup (str): Desired computation device setup.
31
+ log_setup (str): Desired logging device setup.
32
+ device (str | None): Computation device, e.g., "cpu" or "cuda:<local_rank>".
33
+ log_device (str | None): Logging device, e.g., "cpu" or "cuda:<local_rank>".
34
+ dtype (torch.dtype | None): Data type for controlling numerical precision (e.g., torch.float32).
35
+ data_dtype (torch.dtype | None): Data type for controlling datasets precision (e.g., torch.float16).
36
+ node_rank (int): Rank of the node on the cluster setup.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ compute_setup: str,
42
+ log_setup: str,
43
+ backend: str,
44
+ dtype: torch.dtype | None = None,
45
+ ) -> None:
46
+ """
47
+ Initialize the BaseExecution.
48
+
49
+ Args:
50
+ compute_setup (str): Compute device setup; options are "auto" (default), "gpu", or "cpu".
51
+ - "auto": Uses GPU if available, otherwise CPU.
52
+ - "gpu": Forces GPU usage, raising an error if no CUDA device is available.
53
+ - "cpu": Forces CPU usage.
54
+ log_setup (str): Logging device setup; options are "auto", "cpu" (default).
55
+ - "auto": Uses same device to log as used for computation.
56
+ - "cpu": Forces CPU logging.
57
+ backend (str): Backend to use for distributed communication (default: "nccl").
58
+ dtype (torch.dtype | None): Data type for controlling numerical precision. Default is None.
59
+ """
60
+ self.compute_setup = compute_setup
61
+ self.log_setup = log_setup
62
+ self.backend = backend
63
+ self.compute: str
64
+ self.device: str
65
+ self.log_device: str
66
+
67
+ self.dtype: torch.dtype | None = dtype
68
+ self.data_dtype: torch.dtype | None = None
69
+ if self.dtype:
70
+ if self.dtype == torch.complex128:
71
+ self.data_dtype = torch.float64
72
+ elif self.dtype == torch.complex64:
73
+ self.data_dtype = torch.float32
74
+ elif self.dtype == torch.complex32:
75
+ self.data_dtype = torch.float16
76
+ else:
77
+ self.data_dtype = self.dtype
78
+
79
+ self._set_cluster_variables()
80
+ self._set_compute()
81
+ # We assign an available host/port in the __init__ so that spawned subprocesses can use it.
82
+ self._available_host: str = "localhost"
83
+ self._available_port: str = self._find_available_port()
84
+ self.device = "cpu" # set the initial device to cpu, it will change to correct device when process runs.
85
+ self.log_device = "cpu"
86
+
87
+ @abstractmethod
88
+ def get_rank(self, process_rank: int) -> int:
89
+ """Retrieve the global rank of the current process.
90
+
91
+ Implemented in the inherited class.
92
+
93
+ Args:
94
+ process_rank (int): The rank to assign to the process.
95
+
96
+ Returns:
97
+ int: The global rank of the process.
98
+ """
99
+ pass
100
+
101
+ @abstractmethod
102
+ def get_local_rank(self, process_rank: int) -> int | None:
103
+ """
104
+ Retrieve the local rank of the current process.
105
+
106
+ Args:
107
+ process_rank (int): The rank to assign to the process.
108
+
109
+ Returns:
110
+ int | None: The local rank. Is None for cpu setups.
111
+ """
112
+ pass
113
+
114
+ @abstractmethod
115
+ def get_world_size(self, process_rank: int, nprocs: int) -> int:
116
+ """Retrieve the total number of processes in the distributed training job.
117
+
118
+ Implemented in the inherited class.
119
+
120
+ Args:
121
+ process_rank (int): The rank to assign to the process.
122
+ nprocs (int): Number of processes to launch.
123
+
124
+ Returns:
125
+ int: The total number of processes (world size).
126
+ """
127
+ pass
128
+
129
+ @abstractmethod
130
+ def get_master_addr(self) -> str:
131
+ """Return the master node address.
132
+
133
+ Implemented in the inherited class.
134
+
135
+ Returns:
136
+ str: The master address.
137
+ """
138
+ pass
139
+
140
+ @abstractmethod
141
+ def get_master_port(self) -> str:
142
+ """Return the master node port.
143
+
144
+ Implemented in the inherited class.
145
+
146
+ Returns:
147
+ str: The master port.
148
+ """
149
+ pass
150
+
151
+ def _set_cluster_variables(self) -> None:
152
+ """
153
+ Sets the initial default variables for the cluster.
154
+
155
+ For now it only supports SLURM clusters, and should be extended to others
156
+ when needed.
157
+ """
158
+ self.job_id = str(os.environ.get("SLURM_JOB_ID", "Unknown"))
159
+ self.num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))
160
+ self.node_list = os.environ.get("SLURM_JOB_NODELIST", "Unknown")
161
+ self.node_rank = int(os.environ.get("SLURM_NODEID", 0))
162
+ self.node_name = os.environ.get("SLURMD_NODENAME", "Unknown")
163
+ # currently we only do this for GPUs
164
+ # TODO: extend support to TPUs, CPUs, etc.
165
+ self.cores_per_node: int = (
166
+ int(torch.cuda.device_count()) if torch.cuda.is_available() else 1
167
+ )
168
+
169
+ def _set_compute(self) -> None:
170
+ """
171
+ Set the compute (cpu or gpu) for the current process based on the compute setup.
172
+
173
+ The method checks for CUDA availability and selects the appropriate device.
174
+ If compute_setup is set to "gpu" but CUDA is unavailable, a RuntimeError is raised.
175
+
176
+ Raises:
177
+ RuntimeError: If compute_setup is "gpu" but no CUDA devices are available.
178
+ """
179
+ if self.compute_setup == "gpu":
180
+ if not torch.cuda.is_available():
181
+ raise RuntimeError("Compute setup set to 'gpu' but no CUDA devices are available.")
182
+ self.compute = "gpu"
183
+ elif self.compute_setup == "auto":
184
+ self.compute = "gpu" if torch.cuda.is_available() else "cpu"
185
+ else:
186
+ self.compute = "cpu"
187
+
188
+ def set_device(self, local_rank: int | None) -> None:
189
+ """Set the computation device (cpu or cuda:<n>) for the current process based on the compute setup."""
190
+ if self.compute == "gpu":
191
+ self.device = f"cuda:{local_rank}"
192
+ torch.cuda.set_device(local_rank)
193
+ else:
194
+ self.device = "cpu"
195
+
196
+ if self.log_setup == "auto":
197
+ self.log_device = self.device
198
+ elif self.log_setup == "cpu":
199
+ self.log_device = "cpu"
200
+ else:
201
+ raise ValueError(f"log_setup {self.log_setup} not supported. Choose 'auto' or 'cpu'.")
202
+
203
+ def _find_available_port(
204
+ self, start: int = 1024, end: int = 65535, max_attempts: int = 100
205
+ ) -> str:
206
+ """
207
+ Find an available port by trying random ports in the specified range.
208
+
209
+ Args:
210
+ start (int): Start of the port range (default: 1024).
211
+ end (int): End of the port range (default: 65535).
212
+ max_attempts (int): Maximum number of attempts before giving up (default: 100).
213
+ Returns:
214
+ str: An available port number as a string. Raises a RuntimeError if no port is found.
215
+ """
216
+ attempts = 0
217
+ while attempts < max_attempts:
218
+ port = random.randint(start, end)
219
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
220
+ try:
221
+ s.bind(("0.0.0.0", port))
222
+ return str(port)
223
+ except OSError:
224
+ attempts += 1
225
+ raise RuntimeError("Available port not found")
226
+
227
+
228
+ class DefaultExecution(BaseExecution):
229
+ """
230
+ Default execution for SLURM-like environments.
231
+
232
+ Uses SLURM-specific environment variables when available.
233
+ """
234
+
235
+ def get_rank(self, process_rank: int) -> int:
236
+ """
237
+ Retrieve the global rank of the current process.
238
+
239
+ Args:
240
+ process_rank (int): The rank to assign to the process.
241
+
242
+ Returns:
243
+ int: The global rank of the process.
244
+ For CPU setups the given process_rank is returned directly; otherwise the global rank
246
+ is computed from the node rank, the number of devices per node, and the local rank.
246
+ """
247
+ if self.compute == "cpu":
248
+ return int(process_rank)
249
+ local_rank = self.get_local_rank(process_rank)
250
+ multi_node_index = self.node_rank * self.cores_per_node
251
+ return int(multi_node_index + local_rank) if local_rank is not None else int(multi_node_index)
252
+
253
+ def get_local_rank(self, process_rank: int) -> int | None:
254
+ """
255
+ Retrieve the local rank of the current process.
256
+
257
+ Args:
258
+ process_rank (int): The rank to assign to the process.
259
+
260
+ Returns:
261
+ int | None: The local rank. Uses the "LOCAL_RANK" environment variable if set.
262
+ """
263
+ if self.compute == "cpu":
264
+ return None
265
+ return int(os.environ.get("LOCAL_RANK", process_rank))
266
+
267
+ def get_world_size(self, process_rank: int, nprocs: int) -> int:
268
+ """
269
+ Retrieve the total number of processes in the distributed training job.
270
+
271
+ Args:
272
+ process_rank (int): The rank to assign to the process.
273
+ nprocs (int): Number of processes to launch.
274
+
275
+ Returns:
276
+ int: The total number of processes (world size).
277
+ Uses the "WORLD_SIZE" environment variable if set.
278
+ """
279
+ return int(os.environ.get("WORLD_SIZE", nprocs))
280
+
281
+ def get_master_addr(self) -> str:
282
+ """
283
+ Determine the master node's address for distributed training.
284
+
285
+ Returns:
286
+ str: The master address. If the environment variable "MASTER_ADDR" is set, that value is used.
287
+ In a SLURM environment, the first hostname from the SLURM node list is used.
288
+ Defaults to "localhost" if none is found.
289
+ """
290
+ if "MASTER_ADDR" in os.environ:
291
+ return os.environ["MASTER_ADDR"]
292
+ try:
293
+ output = subprocess.check_output(
294
+ ["scontrol", "show", "hostnames", os.environ["SLURM_NODELIST"]]
295
+ )
296
+ return output.splitlines()[0].decode("utf-8").strip()
297
+ except Exception:
298
+ return self._available_host
299
+
300
+ def get_master_port(self) -> str:
301
+ """
302
+ Determine the master node's port for distributed training.
303
+
304
+ Returns:
305
+ str: The master port. Uses the environment variable "MASTER_PORT" if set.
306
+ In a SLURM environment, computes a port based on the SLURM_JOB_ID.
307
+ Otherwise, falls back to a free port found at initialization.
308
+ """
309
+ if "MASTER_PORT" in os.environ:
310
+ return os.environ["MASTER_PORT"]
311
+ if self.job_id == "Unknown":
312
+ return str(self._available_port)
313
+ else:
314
+ # This is needed for Multi-node Slurm clusters
315
+ return str(int(12000 + int(self.job_id) % 5000))
316
+
317
+
318
+ class TorchRunexecution(BaseExecution):
319
+ """
320
+ Execution for torchrun or when using TORCHELASTIC.
321
+
322
+ Expects that environment variables like RANK, LOCAL_RANK, WORLD_SIZE,
323
+ MASTER_ADDR, and MASTER_PORT are already set.
324
+ """
325
+
326
+ def get_rank(self, process_rank: int) -> int:
327
+ """
328
+ Retrieve the global rank of the current process set by torchrun.
329
+
330
+ Args:
331
+ process_rank (int): The rank to assign to the process.
332
+
333
+ Returns:
334
+ int: The global rank of the process.
335
+ """
336
+ if self.compute == "cpu":
337
+ return int(process_rank)
338
+ return int(os.environ.get("RANK", process_rank))
339
+
340
+ def get_local_rank(self, process_rank: int) -> int | None:
341
+ """
342
+ Retrieve the local rank of the current process (its index on the local node).
343
+
344
+ Args:
345
+ process_rank (int): The rank to assign to the process.
346
+
347
+ Returns:
348
+ int | None: The local rank. Uses the "LOCAL_RANK" environment variable if set.
349
+ """
350
+ if self.compute == "cpu":
351
+ return None
352
+ return int(os.environ.get("LOCAL_RANK", process_rank))
353
+
354
+ def get_world_size(self, process_rank: int, nprocs: int) -> int:
355
+ """
356
+ Retrieve the total number of processes in the distributed training job.
357
+
358
+ Args:
359
+ process_rank (int): The rank to assign to the process.
360
+ nprocs (int): Number of processes to launch.
361
+
362
+ Returns:
363
+ int: The total number of processes (world size).
364
+ Uses the "WORLD_SIZE" environment variable if set.
365
+ """
366
+ return int(os.environ.get("WORLD_SIZE", nprocs))
367
+
368
+ def get_master_addr(self) -> str:
369
+ """
370
+ Determine the master node's address for distributed training set by torchrun.
371
+
372
+ Returns:
373
+ str: The master address.
374
+ """
375
+ return os.environ.get("MASTER_ADDR", "localhost")
376
+
377
+ def get_master_port(self) -> str:
378
+ """
379
+ Determine the master node's port for distributed training set by torchrun.
380
+
381
+ Returns:
382
+ str: The master port.
383
+ """
384
+ return os.environ.get("MASTER_PORT", "12364")
385
+
386
+
387
+ def detect_execution(
388
+ compute_setup: str,
389
+ log_setup: str,
390
+ backend: str,
391
+ dtype: torch.dtype | None = None,
392
+ ) -> tuple[BaseExecution, ExecutionType]:
393
+ """
394
+ Detect and return the appropriate execution instance.
395
+
396
+ The execution type is auto-detected from environment variables (torchrun sets TORCHELASTIC_RUN_ID).
397
+
398
+ Args:
399
+ compute_setup (str): Compute device setup; options are "auto" (default), "gpu", or "cpu".
400
+ - "auto": Uses GPU if available, otherwise CPU.
401
+ - "gpu": Forces GPU usage, raising an error if no CUDA device is available.
402
+ - "cpu": Forces CPU usage.
403
+ log_setup (str): Logging device setup; options are "auto", "cpu" (default).
404
+ - "auto": Uses same device to log as used for computation.
405
+ - "cpu": Forces CPU logging.
406
+ backend (str): Backend to use for distributed communication (default: "nccl").
407
+ dtype (torch.dtype | None): Data type for controlling numerical precision. Default is None.
408
+
409
+ Returns:
410
+ tuple[BaseExecution, ExecutionType]: tuple of
411
+ - Instance of the appropriate execution used for launching the code.
412
+ - The corresponding ExecutionType.
413
+ """
414
+ execution = (
415
+ ExecutionType.TORCHRUN if "TORCHELASTIC_RUN_ID" in os.environ else ExecutionType.DEFAULT
416
+ )
417
+
418
+ if execution == ExecutionType.TORCHRUN:
419
+ return TorchRunexecution(compute_setup, log_setup, backend, dtype), execution
420
+ else:
421
+ return DefaultExecution(compute_setup, log_setup, backend, dtype), execution
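For completeness, a minimal sketch of the detection entry point defined in this second hunk (the dtype and backend values are illustrative; detect_execution and ExecutionType are imported exactly as the first hunk does):

import os

import torch

from qadence.ml_tools.train_utils.execution import detect_execution
from qadence.types import ExecutionType

# torchrun exports TORCHELASTIC_RUN_ID (plus RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR,
# MASTER_PORT), so detection resolves to ExecutionType.TORCHRUN; otherwise the
# SLURM-aware DefaultExecution is returned.
execution, execution_type = detect_execution(
    compute_setup="auto", log_setup="cpu", backend="gloo", dtype=torch.complex128
)

if execution_type == ExecutionType.TORCHRUN:
    print("Launched via torchrun:", os.environ["TORCHELASTIC_RUN_ID"])
else:
    print("Default execution, master at", execution.get_master_addr(), execution.get_master_port())

# dtype=torch.complex128 maps the dataset precision (data_dtype) to torch.float64,
# per BaseExecution.__init__ above.
print(execution.dtype, execution.data_dtype)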