torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import logging
|
|
9
|
+
import random
|
|
10
|
+
import time
|
|
11
|
+
|
|
12
|
+
from torchx import specs
|
|
13
|
+
from torchx.runner import get_runner, Runner
|
|
14
|
+
|
|
15
|
+
# Configure the root logger at INFO level (basicConfig installs a default
# stream handler) for this CLI script.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

# torchx scheduler name passed to ``runner.list`` when enumerating jobs.
_SCHEDULER = "slurm"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def kill_all(runner: Runner) -> None:
    """Cancel every RUNNING job on the scheduler whose name contains ``ft_``.

    Each cancelled job's app handle is printed before it is cancelled.
    """
    targets = [
        job.app_handle
        for job in runner.list(_SCHEDULER)
        if job.state == specs.AppState.RUNNING and "ft_" in job.name
    ]
    for handle in targets:
        print(f"killing {handle}")
        runner.cancel(handle)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def kill_one(runner: Runner) -> None:
    """Cancel one randomly chosen RUNNING ``ft_`` job.

    Jobs whose name contains ``ft_0`` are never chosen. If no eligible job is
    running, a message is printed and nothing is cancelled.
    """
    jobs = runner.list(_SCHEDULER)
    candidates = [
        job.app_handle
        for job in jobs
        if job.state == specs.AppState.RUNNING
        and "ft_" in job.name
        and "ft_0" not in job.name
    ]
    if not candidates:
        # random.choice([]) raises IndexError; a moment with no eligible
        # replicas (e.g. all are restarting) should not crash the punisher.
        print("no candidate jobs to kill")
        return
    choice = random.choice(candidates)
    print(f"killing {choice=} {candidates=}")
    runner.cancel(choice)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def kill_loop(runner: Runner, args: argparse.Namespace) -> None:
    """Inject ``args.num_failures`` failures, pausing a random time after each.

    Each pause is drawn uniformly from ``[0, 2 * args.mtbf_secs)``, so the
    mean time between injected failures is ``args.mtbf_secs``.
    """
    for _ in range(args.num_failures):
        kill_one(runner)
        dur = random.random() * (2 * args.mtbf_secs)
        print(f"sleeping for {dur=} {args.mtbf_secs=}")
        # Sleep for the sampled duration; sleeping for the fixed mean would
        # make the random draw above meaningless.
        time.sleep(dur)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def main() -> None:
    """Entry point for the slurm failure-injection CLI.

    Sub-commands:
      * ``kill_loop``: repeatedly cancel a replica (``--num-failures`` times),
        pausing between failures according to ``--mtbf-secs``.
      * ``kill_one``: cancel one randomly chosen ``ft_`` job.
      * ``kill_all``: cancel every running job whose name contains ``ft_``.

    Prints usage and returns when no sub-command is given.
    """
    parser = argparse.ArgumentParser(description="CLI tool to inject failures on slurm")
    commands = parser.add_subparsers(dest="command", help="Available commands")

    loop_cmd = commands.add_parser("kill_loop", help="Kill jobs in a loop")
    loop_cmd.add_argument(
        "--mtbf-secs", type=float, default=5, help="Mean time between failures"
    )
    loop_cmd.add_argument(
        "--num-failures", type=int, default=1, help="Number of failures to inject"
    )
    commands.add_parser("kill_one", help="Kill a single job")
    commands.add_parser("kill_all", help="Kill all jobs")

    args = parser.parse_args()
    if args.command is None:
        parser.print_help()
        return

    # Map each sub-command name to a callable taking the runner.
    dispatch = {
        "kill_loop": lambda active_runner: kill_loop(active_runner, args),
        "kill_one": kill_one,
        "kill_all": kill_all,
    }
    with get_runner() as runner:
        handler = dispatch.get(args.command)
        if handler is not None:
            handler(runner)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import time
|
|
11
|
+
|
|
12
|
+
from torchx import specs
|
|
13
|
+
from torchx.components.dist import ddp
|
|
14
|
+
from torchx.runner import get_runner, Runner
|
|
15
|
+
|
|
16
|
+
# Configure the root logger at INFO level (basicConfig installs a default
# stream handler) for this launcher script.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

# torchx scheduler used both to submit replicas (``runner.run``) and to list
# running jobs (``runner.list``).
_SCHEDULER = "slurm"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _make_app(replica_id: int, cli_args: argparse.Namespace) -> specs.AppDef:
    """Build the torchx ``ddp`` AppDef for one fault-tolerant torchtitan replica.

    Args:
        replica_id: Index of this replica. It is used in the app name
            (``ft_<replica_id>``) and passed as ``--fault_tolerance.replica_id``.
            Replica 0 additionally enables wandb metrics and checkpoints every
            100 steps.
        cli_args: Parsed CLI arguments (``nodes``, ``nproc_per_node``,
            ``replica_count``, ``process_group``, ``enable_semi_sync``,
            ``semi_sync_method``, ``num_fragments``).

    Returns:
        An AppDef that runs ``./torchtitan/train.py`` as a
        ``nodes x nproc_per_node`` job, with its first role renamed to the app
        name.
    """
    args = [
        "--comm.trace_buf_size=0",
        "--comm.train_timeout_seconds=60",
        "--metrics.log_freq=1",
        "--profiling.enable_profiling",
        "--experimental.custom_args_module=torchtitan.components.ft.config",
        "--job.config_file=./torchtitan/models/llama3/train_configs/llama3_8b.toml",
        "--model.name=llama3_ft",
        "--training.dataset=c4",
        "--training.steps=10000",
        "--training.local_batch_size=2",
        f"--parallelism.data_parallel_shard_degree={cli_args.nodes * cli_args.nproc_per_node}",
        "--fault_tolerance.enable",
        f"--fault_tolerance.replica_id={replica_id}",
        f"--fault_tolerance.group_size={cli_args.replica_count}",
        f"--fault_tolerance.process_group={cli_args.process_group}",
        f"--fault_tolerance.process_group_timeout_ms={600 * 1000}",
    ]

    if cli_args.enable_semi_sync:
        args += [
            f"--fault_tolerance.semi_sync_method={cli_args.semi_sync_method}",
        ]

        # DiLoCo-specific flags only apply when a semi-sync method is active;
        # emitting them with semi-sync disabled would configure a method that
        # is not in use.
        if cli_args.semi_sync_method == "diloco":
            args += [
                "--fault_tolerance.sync_steps=20",
                "--fault_tolerance.fragment_sync_delay=1",
                f"--fault_tolerance.num_fragments={cli_args.num_fragments}",
            ]

    if replica_id == 0:
        args += [
            "--metrics.enable-wandb",
            "--checkpoint.interval=100",
        ]

    env = {}

    # use agent store in torchelastic to avoid TCPStore init race condition
    env["TORCH_SHARE_RDZV_TCP_STORE"] = "1"

    # Enable PyTorch's CUDA Sanitizer (a debugging aid) in the training
    # processes. The variable name must not contain "=".
    env["TORCH_CUDA_SANITIZER"] = "1"

    # NCCL envs for debugging
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "ALL"
    env["NCCL_PROTO"] = "Simple"

    # gloo: forward the launcher's interface selection, if any.
    gloo_socket_ifname = os.environ.get("GLOO_SOCKET_IFNAME")
    if gloo_socket_ifname is not None:
        env["GLOO_SOCKET_IFNAME"] = gloo_socket_ifname

    # application log levels
    env["LOGLEVEL"] = "INFO"
    # NOTE(review): Rust logging crates conventionally read RUST_LOG (no
    # trailing "S"); confirm which variable torchft's Rust components read.
    env["RUST_LOGS"] = "INFO"
    env["TORCH_CPP_LOG_LEVEL"] = "INFO"

    # application timeouts
    env["TORCHFT_QUORUM_TIMEOUT_SEC"] = "900"
    env["TORCHFT_TIMEOUT_SEC"] = "600"
    env["TORCHFT_QUORUM_RETRIES"] = "0"

    env["TORCHFT_LIGHTHOUSE"] = os.environ.get(
        "TORCHFT_LIGHTHOUSE", "http://slurm-head-node-0:29510"
    )

    env["WANDB_PROJECT"] = "torchft"

    app = ddp(
        *args,
        name=f"ft_{replica_id}",
        env=env,
        script="./torchtitan/train.py",
        gpu=cli_args.nproc_per_node,
        j=f"{cli_args.nodes}x{cli_args.nproc_per_node}",
    )
    # Rename the role to the app name (ft_<replica_id>).
    app.roles[0].name = app.name
    return app
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def start_replica(
    runner: Runner, replica_id: int, args: argparse.Namespace
) -> specs.AppHandle:
    """Submit replica ``replica_id`` to the slurm scheduler.

    Returns:
        The app handle returned by ``runner.run``.
    """
    replica_app = _make_app(replica_id, args)
    return runner.run(replica_app, scheduler=_SCHEDULER)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def monitor(runner: Runner, args: argparse.Namespace) -> None:
    """Print the status of each expected replica and relaunch the dead ones.

    Running jobs are matched to replica ids by parsing names of the form
    ``<prefix>_<replica_id>-...`` (jobs without ``ft_`` in their name are
    ignored). Every id in ``range(args.replica_count)`` without a matching
    running job is launched via ``start_replica``.
    """
    jobs = runner.list(_SCHEDULER)
    jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]

    active_replicas = {}

    for job in jobs:
        if "ft_" not in job.name:
            continue
        name, _, _ = job.name.partition("-")
        _, _, replica_id_str = name.partition("_")
        try:
            replica_id = int(replica_id_str)
        except ValueError:
            # A job whose name merely contains "ft_" (e.g. "draft_x") is not
            # one of ours; it must not crash the monitor loop.
            print(f"skipping job with unparseable name {job.name!r}")
            continue
        active_replicas[replica_id] = job

    to_launch = []
    for replica_id in range(args.replica_count):
        job = active_replicas.get(replica_id)
        if job is not None:
            print(f" - {replica_id=:2d}: ALIVE {job.app_handle}")
        else:
            print(f" - {replica_id=:2d}: DEAD")
            to_launch.append(replica_id)

    for replica_id in to_launch:
        app_handle = start_replica(
            runner,
            replica_id,
            args,
        )
        print(f"launched {replica_id=}: {app_handle=}")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def str_to_bool(value: str) -> bool:
    """Parse a boolean command-line value.

    Accepts, case-insensitively and ignoring surrounding whitespace,
    ``1/true/t/yes/y/on`` and ``0/false/f/no/n/off``. This is needed because
    ``argparse`` with ``type=bool`` calls ``bool(value)``, which is ``True``
    for every non-empty string (including ``"False"``). A flag parsed that way
    can never be turned off.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def main() -> None:
    """Parse CLI flags, chdir into the torchtitan workspace, and keep replicas alive.

    Calls ``monitor`` every 10 seconds to relaunch any replica that is not
    running. This loop never returns.
    """
    parser = argparse.ArgumentParser(
        description="CLI tool to launch data parallel replicas on slurm"
    )

    parser.add_argument(
        "--workspace-dir",
        type=str,
        # os.chdir(None) would fail later with an opaque TypeError.
        required=True,
        help="Location of torchtitan folder",
    )

    parser.add_argument(
        "--nodes",
        type=int,
        default=10,
        help="Number of nodes per replica",
    )

    parser.add_argument(
        "--nproc-per-node",
        type=int,
        default=10,
        help="Number of ranks per node",
    )

    parser.add_argument(
        "--replica-count",
        type=int,
        default=10,
        help="Number of data parallel replicas",
    )

    parser.add_argument(
        "--process-group",
        type=str,
        default="gloo",
        help="The process group to use for data parallel",
    )

    parser.add_argument(
        "--enable-semi-sync",
        # Accepts "--enable-semi-sync false" as well as a bare
        # "--enable-semi-sync" (which means True).
        type=str_to_bool,
        nargs="?",
        const=True,
        default=True,
        help="Whether to enable semi-sync method for data parallel",
    )

    parser.add_argument(
        "--semi-sync-method",
        type=str,
        default="diloco",
        help="The semi-sync method to use for data parallel. Options: diloco, local_sgd",
    )

    parser.add_argument(
        "--num-fragments",
        type=int,
        default=2,
        help="The number of fragments to use for data parallel. Only used for diloco semi-sync method",
    )

    args = parser.parse_args()

    # Replica apps reference ./torchtitan/... paths relative to the workspace.
    os.chdir(args.workspace_dir)

    with get_runner() as runner:
        while True:
            monitor(runner, args)
            time.sleep(10)


if __name__ == "__main__":
    main()
|
torchft/fsdp_test.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import multiprocessing
|
|
8
|
+
import os
|
|
9
|
+
import unittest
|
|
10
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
11
|
+
from unittest.mock import Mock
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
import torch.distributed as dist
|
|
15
|
+
from torch import nn
|
|
16
|
+
from torch._C._distributed_c10d import ReduceOp
|
|
17
|
+
from torch.distributed._composable.fsdp import FSDPModule, fully_shard
|
|
18
|
+
from torch.distributed.tensor import init_device_mesh
|
|
19
|
+
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
|
|
20
|
+
|
|
21
|
+
from torchft.manager import Manager
|
|
22
|
+
from torchft.process_group import ProcessGroupGloo
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FSDPTest(unittest.TestCase):
    @staticmethod
    def _test_fsdp(
        world_size: int,
        rank: int,
        dp_replicate: int = 2,
        dp_shard: int = 2,
        tp: int = 1,
    ) -> None:
        """Run one rank of an FSDP (optionally FSDP + TP) forward/backward pass.

        Ranks are split into ``dp_replicate`` groups of
        ``world_size // dp_replicate`` consecutive ranks. Each group
        rendezvouses on its own ``MASTER_PORT`` and builds a ``(dp_shard, tp)``
        CUDA device mesh. The FSDP all-reduce hook calls ``dist.all_reduce``
        on a ``Mock`` standing in for ``ProcessGroupGloo``.
        """
        torch.cuda.set_device(rank)

        group_size = world_size // dp_replicate
        group, group_rank = divmod(rank, group_size)

        # Each replicate group gets its own rendezvous port.
        os.environ.update(
            {
                "MASTER_ADDR": "127.0.0.1",
                "MASTER_PORT": str(12346 + group),
                "RANK": str(group_rank),
                "WORLD_SIZE": str(group_size),
            }
        )

        manager = Mock(spec=Manager)
        pg: ProcessGroupGloo = Mock(spec=ProcessGroupGloo)
        device_mesh = init_device_mesh(
            device_type="cuda",
            mesh_shape=(dp_shard, tp),
            mesh_dim_names=("dp_shard", "tp"),
        )
        manager.num_participants.return_value = 1
        model = nn.Linear(128, 128).cuda()
        batch = torch.randn(4, 128).cuda()

        fsdp_mesh = device_mesh["dp_shard"]

        def all_reduce_hook(output: torch.Tensor) -> None:
            dist.all_reduce(output, group=pg, op=ReduceOp.AVG)

        def apply_set_all_reduce_hook(m: nn.Module) -> None:
            assert isinstance(m, FSDPModule)
            m.set_all_reduce_hook(all_reduce_hook)

        if tp > 1:
            model = parallelize_module(model, device_mesh["tp"], ColwiseParallel())
        shard_model = fully_shard(model, mesh=fsdp_mesh)
        shard_model.apply(apply_set_all_reduce_hook)
        shard_model(batch).mean().backward()

    def _run_ranks(self, world_size: int, **mesh_kwargs: int) -> None:
        """Run ``_test_fsdp`` for every rank in spawned processes; re-raise failures."""
        ctx = multiprocessing.get_context("spawn")
        with ProcessPoolExecutor(max_workers=world_size, mp_context=ctx) as pool:
            pending = [
                pool.submit(self._test_fsdp, world_size, rank, **mesh_kwargs)
                for rank in range(world_size)
            ]
            for fut in pending:
                fut.result()

    # pyre-ignore[56]: Pyre was not able to infer the type of argument
    @unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
    def test_fsdp(self) -> None:
        # 2 replicate groups x 2-way FSDP sharding.
        self._run_ranks(4)

    # pyre-ignore[56]: Pyre was not able to infer the type of argument
    @unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
    def test_fsdp_tp(self) -> None:
        # Single group: 2-way FSDP sharding x 2-way tensor parallelism.
        self._run_ranks(4, dp_replicate=1, dp_shard=2, tp=2)
|