sagemaker-train 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. sagemaker/train/__init__.py +34 -0
  2. sagemaker/train/configs.py +31 -0
  3. sagemaker/train/constants.py +41 -0
  4. sagemaker/train/container_drivers/__init__.py +14 -0
  5. sagemaker/train/container_drivers/common/__init__.py +14 -0
  6. sagemaker/train/container_drivers/common/utils.py +205 -0
  7. sagemaker/train/container_drivers/distributed_drivers/__init__.py +14 -0
  8. sagemaker/train/container_drivers/distributed_drivers/basic_script_driver.py +81 -0
  9. sagemaker/train/container_drivers/distributed_drivers/mpi_driver.py +105 -0
  10. sagemaker/train/container_drivers/distributed_drivers/mpi_utils.py +302 -0
  11. sagemaker/train/container_drivers/distributed_drivers/torchrun_driver.py +129 -0
  12. sagemaker/train/container_drivers/scripts/__init__.py +14 -0
  13. sagemaker/train/container_drivers/scripts/environment.py +305 -0
  14. sagemaker/train/defaults.py +630 -0
  15. sagemaker/train/distributed.py +181 -0
  16. sagemaker/train/local/__init__.py +0 -0
  17. sagemaker/train/local/data.py +405 -0
  18. sagemaker/train/local/entities.py +97 -0
  19. sagemaker/train/local/local_container.py +613 -0
  20. sagemaker/train/model_trainer.py +1471 -0
  21. sagemaker/train/modules/model_trainer.py +1030 -0
  22. sagemaker/train/remote_function/__init__.py +34 -0
  23. sagemaker/train/remote_function/checkpoint_location.py +47 -0
  24. sagemaker/train/remote_function/client.py +30 -0
  25. sagemaker/train/remote_function/core/__init__.py +27 -0
  26. sagemaker/train/remote_function/core/_custom_dispatch_table.py +56 -0
  27. sagemaker/train/remote_function/core/pipeline_variables.py +30 -0
  28. sagemaker/train/remote_function/core/serialization.py +30 -0
  29. sagemaker/train/remote_function/core/stored_function.py +30 -0
  30. sagemaker/train/remote_function/custom_file_filter.py +128 -0
  31. sagemaker/train/remote_function/errors.py +30 -0
  32. sagemaker/train/remote_function/invoke_function.py +172 -0
  33. sagemaker/train/remote_function/job.py +30 -0
  34. sagemaker/train/remote_function/logging_config.py +38 -0
  35. sagemaker/train/remote_function/runtime_environment/__init__.py +14 -0
  36. sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py +602 -0
  37. sagemaker/train/remote_function/runtime_environment/mpi_utils_remote.py +252 -0
  38. sagemaker/train/remote_function/runtime_environment/runtime_environment_manager.py +467 -0
  39. sagemaker/train/remote_function/runtime_environment/spark_app.py +18 -0
  40. sagemaker/train/remote_function/spark_config.py +30 -0
  41. sagemaker/train/sm_recipes/__init__.py +0 -0
  42. sagemaker/train/sm_recipes/training_recipes.json +17 -0
  43. sagemaker/train/sm_recipes/utils.py +332 -0
  44. sagemaker/train/templates.py +102 -0
  45. sagemaker/train/tuner.py +1418 -0
  46. sagemaker/train/types.py +19 -0
  47. sagemaker/train/utils.py +236 -0
  48. sagemaker_train-1.0.dist-info/METADATA +193 -0
  49. sagemaker_train-1.0.dist-info/RECORD +52 -0
  50. sagemaker_train-1.0.dist-info/WHEEL +5 -0
  51. sagemaker_train-1.0.dist-info/licenses/LICENSE +201 -0
  52. sagemaker_train-1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,34 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """SageMaker Python SDK Train Module."""
14
+ from __future__ import absolute_import
15
+
16
+ # Lazy imports to avoid circular dependencies
17
+ # Session and get_execution_role are available from sagemaker.core.helper.session_helper
18
+ # Import them directly from there if needed, or use lazy import pattern
19
+
20
def __getattr__(name):
    """Resolve selected public names lazily on first attribute access (PEP 562).

    Importing these names eagerly at module load would create circular
    imports, so each is imported from its home module only when requested.

    Args:
        name (str): Attribute being looked up on this package.

    Returns:
        The requested object (``Session``, ``get_execution_role``,
        ``ModelTrainer`` or ``logger``).

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    if name == "Session":
        from sagemaker.core.helper.session_helper import Session

        return Session

    if name == "get_execution_role":
        from sagemaker.core.helper.session_helper import get_execution_role

        return get_execution_role

    if name == "ModelTrainer":
        from sagemaker.train.model_trainer import ModelTrainer

        return ModelTrainer

    if name == "logger":
        from sagemaker.core.utils.utils import logger

        return logger

    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
@@ -0,0 +1,31 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """
14
+ DEPRECATED: This module has been moved to sagemaker.core.training.configs
15
+
16
+ This is a backward compatibility shim. Please update your imports to:
17
+ from sagemaker.core.training.configs import ...
18
+ """
19
+ from __future__ import absolute_import
20
+
21
+ import warnings
22
+
23
+ # Backward compatibility: re-export from core
24
+ from sagemaker.core.training.configs import * # noqa: F401, F403
25
+
26
# Runs when this backward-compatibility shim is imported. stacklevel=2 points
# the warning at the importing code rather than at this module.
warnings.warn(
    message=(
        "sagemaker.train.configs has been moved to sagemaker.core.training.configs. "
        "Please update your imports. This shim will be removed in a future version."
    ),
    category=DeprecationWarning,
    stacklevel=2,
)
@@ -0,0 +1,41 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """
14
+ DEPRECATED: This module has been moved to sagemaker.core.training.constants
15
+
16
+ This is a backward compatibility shim. Please update your imports to:
17
+ from sagemaker.core.training.constants import ...
18
+ """
19
+ from __future__ import absolute_import
20
+
21
+ import os
22
+
23
# Directory name for the user's source code and its path inside the container.
SM_CODE = "code"
SM_CODE_CONTAINER_PATH = f"/opt/ml/input/data/{SM_CODE}"

# Directory name for the SageMaker driver scripts, its path inside the
# container, and the local copy shipped next to this module.
SM_DRIVERS = "sm_drivers"
SM_DRIVERS_CONTAINER_PATH = f"/opt/ml/input/data/{SM_DRIVERS}"
SM_DRIVERS_LOCAL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "container_drivers")

# File names written alongside the drivers.
SOURCE_CODE_JSON = "sourcecode.json"
DISTRIBUTED_JSON = "distributed.json"
TRAIN_SCRIPT = "sm_train.sh"

# Default container command: via bash -c, make the generated training script
# executable and then run it.
DEFAULT_CONTAINER_ENTRYPOINT = ["/bin/bash"]
DEFAULT_CONTAINER_ARGUMENTS = [
    "-c",
    "chmod +x {0}/{1} && {0}/{1}".format(SM_DRIVERS_CONTAINER_PATH, TRAIN_SCRIPT),
]
@@ -0,0 +1,14 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Sagemaker modules container drivers directory."""
14
+ from __future__ import absolute_import
@@ -0,0 +1,14 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Sagemaker modules container drivers - common directory."""
14
+ from __future__ import absolute_import
@@ -0,0 +1,205 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module provides utility functions for the container drivers."""
14
+ from __future__ import absolute_import
15
+
16
+ import os
17
+ import logging
18
+ import sys
19
+ import subprocess
20
+ import traceback
21
+ import json
22
+
23
+ from typing import List, Dict, Any, Tuple, IO, Optional
24
+
25
# Initialize logger
SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20)


def _resolve_log_level(level: Any) -> Any:
    """Normalize an ``SM_LOG_LEVEL`` value for ``Logger.setLevel``.

    Numeric values (``20`` or ``"20"``) are converted to ``int``. Anything else
    (e.g. ``"debug"``) is returned as an upper-cased level name, which
    ``Logger.setLevel`` accepts directly. The original ``int(...)`` call raised
    ``ValueError`` at import time for level names.
    """
    text = str(level).strip()
    try:
        return int(text)
    except ValueError:
        return text.upper()


logger = logging.getLogger(__name__)
console_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(console_handler)
logger.setLevel(_resolve_log_level(SM_LOG_LEVEL))

FAILURE_FILE = "/opt/ml/output/failure"
# SageMaker training logs live in the "/aws/sagemaker/TrainingJobs" log group
# (note the leading slash).
DEFAULT_FAILURE_MESSAGE = """
Training Execution failed.
For more details, see CloudWatch logs at '/aws/sagemaker/TrainingJobs'.
TrainingJob - {training_job_name}
"""

USER_CODE_PATH = "/opt/ml/input/data/code"
SOURCE_CODE_JSON = "/opt/ml/input/data/sm_drivers/sourcecode.json"
DISTRIBUTED_JSON = "/opt/ml/input/data/sm_drivers/distributed.json"

HYPERPARAMETERS_JSON = "/opt/ml/input/config/hyperparameters.json"

SM_EFA_NCCL_INSTANCES = [
    "ml.g4dn.8xlarge",
    "ml.g4dn.12xlarge",
    "ml.g5.48xlarge",
    "ml.p3dn.24xlarge",
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "ml.p5.48xlarge",
    "ml.trn1.32xlarge",
]

SM_EFA_RDMA_INSTANCES = [
    "ml.p4d.24xlarge",
    "ml.p4de.24xlarge",
    "ml.trn1.32xlarge",
]
61
+
62
+
63
def write_failure_file(message: Optional[str] = None):
    """Write ``message`` to ``FAILURE_FILE`` unless that file already exists.

    Args:
        message (Optional[str]): Failure details to record. When ``None`` or
            empty, ``DEFAULT_FAILURE_MESSAGE`` formatted with the training job
            name is written instead.
    """
    if not message:
        # .get() so a missing TRAINING_JOB_NAME cannot raise KeyError while
        # we are already in the middle of reporting another failure.
        message = DEFAULT_FAILURE_MESSAGE.format(
            training_job_name=os.environ.get("TRAINING_JOB_NAME", "<unknown>")
        )
    # Keep the first recorded failure; later calls never overwrite it.
    if not os.path.exists(FAILURE_FILE):
        # Explicit UTF-8 so tracebacks containing non-ASCII text can be
        # written even when the container locale defaults to ASCII.
        with open(FAILURE_FILE, "w", encoding="utf-8") as f:
            f.write(message)
70
+
71
+
72
def _read_json_file(json_path: str):
    """Load a JSON file, returning ``{}`` when it is missing or its content is falsy."""
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f) or {}
    except FileNotFoundError:
        return {}


def read_source_code_json(source_code_json: str = SOURCE_CODE_JSON):
    """Read the source code config json file.

    Args:
        source_code_json (str): Path of the JSON file to read.

    Returns:
        The parsed JSON content, or ``{}`` if the file does not exist or its
        content is falsy (e.g. ``null``).
    """
    return _read_json_file(source_code_json)


def read_distributed_json(distributed_json: str = DISTRIBUTED_JSON):
    """Read the distribution config json file.

    Args:
        distributed_json (str): Path of the JSON file to read.

    Returns:
        The parsed JSON content, or ``{}`` if the file does not exist or its
        content is falsy (e.g. ``null``).
    """
    return _read_json_file(distributed_json)


def read_hyperparameters_json(hyperparameters_json: str = HYPERPARAMETERS_JSON):
    """Read the hyperparameters config json file.

    Args:
        hyperparameters_json (str): Path of the JSON file to read.

    Returns:
        The parsed JSON content, or ``{}`` if the file does not exist or its
        content is falsy (e.g. ``null``).
    """
    return _read_json_file(hyperparameters_json)
100
+
101
+
102
def get_process_count(process_count: Optional[int] = None) -> int:
    """Decide how many processes to launch on each node.

    An explicit, truthy ``process_count`` wins. Otherwise, the first non-zero
    value of ``SM_NUM_GPUS`` then ``SM_NUM_NEURONS`` is used, falling back to 1.
    """
    if process_count:
        return process_count
    for env_name in ("SM_NUM_GPUS", "SM_NUM_NEURONS"):
        detected = int(os.environ.get(env_name, 0))
        if detected:
            return detected
    return 1
110
+
111
+
112
def hyperparameters_to_cli_args(hyperparameters: Dict[str, Any]) -> List[str]:
    """Flatten a hyperparameter mapping into ``--key value`` CLI tokens.

    Each value is JSON-decoded when possible (``safe_deserialize``) and then
    re-encoded without quoting plain strings (``safe_serialize``). For example,
    ``{"lr": "0.1", "tags": ["a"]}`` becomes ``["--lr", "0.1", "--tags", '["a"]']``.
    """
    return [
        token
        for key, value in hyperparameters.items()
        for token in (f"--{key}", safe_serialize(safe_deserialize(value)))
    ]
120
+
121
+
122
def safe_deserialize(data: Any) -> Any:
    """Decode ``data`` from JSON when it is a JSON-encoded string.

    Non-string inputs are returned untouched. Strings that are not valid
    JSON are also returned untouched.

    Returns:
        Any: The decoded value, or ``data`` itself when no decoding applies.
    """
    if isinstance(data, str):
        try:
            return json.loads(data)
        except json.JSONDecodeError:
            pass
    return data
140
+
141
+
142
def safe_serialize(data: Any) -> str:
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    3. If `data` cannot be serialized (e.g., a custom object or a structure that
       contains itself), it returns the string representation of the data
       using `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation of the input.
    """
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: unsupported type. ValueError: json.dumps raises it for
        # circular references. Both fall back to str() as documented.
        return str(data)
164
+
165
+
166
def get_python_executable() -> str:
    """Return ``sys.executable``, the interpreter running this driver.

    Child commands built with this path run under the same Python environment.
    """
    interpreter_path = sys.executable
    return interpreter_path
169
+
170
+
171
def log_subprocess_output(pipe: IO[bytes]):
    """Stream a subprocess's byte output to ``logger`` one line at a time.

    Args:
        pipe (IO[bytes]): Binary stream to read with ``readline`` until EOF.
    """
    for raw_line in iter(pipe.readline, b""):
        # errors="replace": a single non-UTF-8 byte in the user's output must
        # not raise and abort log streaming mid-job.
        # rstrip() removes the line terminator but keeps leading indentation
        # (e.g. nested traceback frames stay readable).
        logger.info(raw_line.decode("utf-8", errors="replace").rstrip())
175
+
176
+
177
def execute_commands(commands: List[str]) -> Tuple[int, str]:
    """Run ``commands`` as a subprocess, streaming its combined output to the log.

    Args:
        commands (List[str]): Program and arguments, passed to
            ``subprocess.Popen`` as a list.

    Returns:
        Tuple[int, str]: ``(0, "")`` on success. Otherwise, the non-zero exit
        code and a formatted traceback of the ``CalledProcessError`` describing it.

    Raises:
        OSError: Propagated from ``subprocess.Popen`` if the program cannot be
            started (only ``CalledProcessError`` is caught here).
    """
    try:
        child = subprocess.Popen(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        # stderr is merged into stdout, so a single reader sees all output in order.
        with child.stdout:
            log_subprocess_output(child.stdout)
        return_code = child.wait()
        if return_code == 0:
            return return_code, ""
        raise subprocess.CalledProcessError(return_code, commands)
    except subprocess.CalledProcessError as err:
        # Capture the traceback so callers can record it as failure details.
        error_traceback = traceback.format_exc()
        print(f"Command failed with exit code {err.returncode}. Traceback: {error_traceback}")
        return err.returncode, error_traceback
196
+
197
+
198
def is_worker_node() -> bool:
    """Return True when ``SM_CURRENT_HOST`` differs from ``SM_MASTER_ADDR``.

    If both variables are unset they compare equal (``None == None``), so the
    node is not considered a worker.
    """
    current_host = os.environ.get("SM_CURRENT_HOST")
    master_host = os.environ.get("SM_MASTER_ADDR")
    return current_host != master_host
201
+
202
+
203
def is_master_node() -> bool:
    """Return True when ``SM_CURRENT_HOST`` equals ``SM_MASTER_ADDR``.

    If both variables are unset they compare equal (``None == None``), so the
    node is treated as the master.
    """
    current_host = os.environ.get("SM_CURRENT_HOST")
    master_host = os.environ.get("SM_MASTER_ADDR")
    return current_host == master_host
@@ -0,0 +1,14 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """Sagemaker modules container drivers - drivers directory."""
14
+ from __future__ import absolute_import
@@ -0,0 +1,81 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module is the entry point for the Basic Script Driver."""
14
+ from __future__ import absolute_import
15
+
16
+ import os
17
+ import sys
18
+ import json
19
+ import shlex
20
+
21
+ from pathlib import Path
22
+ from typing import List
23
+
24
# Put the container_drivers directory (two levels above this file) first on
# sys.path so the sibling ``common`` package imports when run as a script.
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))
25
+
26
+ from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
27
+ logger,
28
+ get_python_executable,
29
+ write_failure_file,
30
+ hyperparameters_to_cli_args,
31
+ execute_commands,
32
+ )
33
+
34
+
35
def create_commands() -> List[str]:
    """Build the command that runs the user's entry script.

    Reads ``SM_ENTRY_SCRIPT`` (script path) and ``SM_HPS`` (JSON-encoded
    hyperparameters) from the environment. Hyperparameters become
    ``--key value`` arguments.

    Returns:
        List[str]: The argv to execute. ``.py`` scripts run under the current
        interpreter. ``.sh`` scripts are made executable and run through
        ``/bin/sh -c``.

    Raises:
        KeyError: If ``SM_ENTRY_SCRIPT`` or ``SM_HPS`` is not set.
        ValueError: If the entry script is neither a ``.py`` nor a ``.sh`` file.
    """
    entry_script = os.environ["SM_ENTRY_SCRIPT"]
    hyperparameters = json.loads(os.environ["SM_HPS"])
    python_executable = get_python_executable()

    args = hyperparameters_to_cli_args(hyperparameters)
    if entry_script.endswith(".py"):
        commands = [python_executable, entry_script, *args]
    elif entry_script.endswith(".sh"):
        # os.path.join(".", p) yields "./p" for relative paths (so the shell
        # does not search PATH) but returns absolute paths unchanged, which a
        # hard-coded "./" prefix would break.
        # The path is shell-quoted because it is interpolated into a shell
        # string, just like the arguments.
        script_path = shlex.quote(os.path.join(".", entry_script))
        args_str = " ".join(shlex.quote(arg) for arg in args)
        commands = [
            "/bin/sh",
            "-c",
            f"chmod +x {script_path} && {script_path} {args_str}",
        ]
    else:
        raise ValueError(
            f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported."
        )
    return commands
57
+
58
+
59
def main():
    """Entry point for the Basic Script Driver.

    Execution lifecycle:
        1. Build the command from ``SM_ENTRY_SCRIPT`` and ``SM_HPS``
           (hyperparameters become command line arguments).
        2. Execute it, streaming its output to the log.
        3. On a non-zero exit code, write the failure file and exit with
           that code.
    """
    command = create_commands()
    logger.info(f"Executing command: {' '.join(command)}")

    exit_code, error_traceback = execute_commands(command)
    if exit_code == 0:
        return
    write_failure_file(error_traceback)
    sys.exit(exit_code)


if __name__ == "__main__":
    main()
@@ -0,0 +1,105 @@
1
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ """This module is the entry point for the MPI driver script."""
14
+ from __future__ import absolute_import
15
+
16
+ import os
17
+ import sys
18
+ import json
19
+
20
+ from sagemaker.train.container_drivers.distributed_drivers.mpi_utils import (
21
+ start_sshd_daemon,
22
+ bootstrap_master_node,
23
+ bootstrap_worker_node,
24
+ get_mpirun_command,
25
+ write_status_file_to_workers,
26
+ write_env_vars_to_file,
27
+ )
28
+
29
+
30
+ from sagemaker.train.container_drivers.common.utils import (
31
+ logger,
32
+ hyperparameters_to_cli_args,
33
+ get_process_count,
34
+ execute_commands,
35
+ write_failure_file,
36
+ )
37
+
38
+
39
def main():
    """Main function for the MPI driver script.

    The MPI Driver is responsible for setting up the MPI environment,
    generating the correct mpi commands, and launching the MPI job.

    Execution Lifecycle:
    1. Setup General Environment Variables at /etc/environment
    2. Start SSHD Daemon
    3. Bootstrap Worker Nodes
        a. Wait to establish connection with Master Node
        b. Wait for Master Node to write status file
    4. Bootstrap Master Node
        a. Wait to establish connection with Worker Nodes
        b. Generate MPI Command
        c. Execute MPI Command with user script provided in `entry_script`
        d. Write status file to Worker Nodes
    5. Exit
    """
    entry_script = os.environ["SM_ENTRY_SCRIPT"]
    distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
    hyperparameters = json.loads(os.environ["SM_HPS"])

    sm_current_host = os.environ["SM_CURRENT_HOST"]
    sm_hosts = json.loads(os.environ["SM_HOSTS"])
    sm_master_addr = os.environ["SM_MASTER_ADDR"]

    write_env_vars_to_file()
    start_sshd_daemon()

    if sm_current_host != sm_master_addr:
        # Per the lifecycle above, workers only wait for the master;
        # mpirun is launched from the master node.
        bootstrap_worker_node(sm_master_addr)
        return

    worker_hosts = [host for host in sm_hosts if host != sm_master_addr]
    bootstrap_master_node(worker_hosts)

    host_count = int(os.environ["SM_HOST_COUNT"])
    # A missing or null "process_count_per_node" falls back to auto-detection.
    process_count = get_process_count(int(distributed_config.get("process_count_per_node") or 0))

    # Reuse the already-parsed SM_HOSTS; annotate each host with its slot
    # count when more than one process runs per node.
    if process_count > 1:
        host_list = [f"{host}:{process_count}" for host in sm_hosts]
    else:
        host_list = list(sm_hosts)

    mpi_command = get_mpirun_command(
        host_count=host_count,
        host_list=host_list,
        num_processes=process_count,
        additional_options=distributed_config.get("mpi_additional_options") or [],
        entry_script_path=entry_script,
    )
    mpi_command += hyperparameters_to_cli_args(hyperparameters)

    logger.info(f"Executing command: {' '.join(mpi_command)}")
    try:
        exit_code, error_traceback = execute_commands(mpi_command)
    finally:
        # Always notify the workers, even if mpirun could not be started,
        # so they are not left waiting for the status file.
        write_status_file_to_workers(worker_hosts)

    if exit_code != 0:
        write_failure_file(error_traceback)
        sys.exit(exit_code)


if __name__ == "__main__":
    main()