torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/__init__.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from torchft.data import DistributedSampler
|
|
8
|
+
from torchft.ddp import DistributedDataParallel
|
|
9
|
+
from torchft.manager import Manager
|
|
10
|
+
from torchft.optim import OptimizerWrapper as Optimizer
|
|
11
|
+
from torchft.otel import setup_logger
|
|
12
|
+
from torchft.process_group import (
|
|
13
|
+
ProcessGroupBabyNCCL,
|
|
14
|
+
ProcessGroupBabyXCCL,
|
|
15
|
+
ProcessGroupGloo,
|
|
16
|
+
ProcessGroupNCCL,
|
|
17
|
+
ProcessGroupXCCL,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# Call setup_logger for each of torchft's named loggers at import time.
for _logger_name in ("torchft_quorums", "torchft_commits", "torchft_errors"):
    setup_logger(_logger_name)
# Keep the loop variable out of the package namespace.
del _logger_name

# Public API re-exported from the torchft package.
__all__ = (
    "DistributedDataParallel",
    "DistributedSampler",
    "Manager",
    "Optimizer",
    "ProcessGroupNCCL",
    "ProcessGroupXCCL",
    "ProcessGroupBabyNCCL",
    "ProcessGroupBabyXCCL",
    "ProcessGroupGloo",
)
|
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import copy
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
from datetime import timedelta
|
|
11
|
+
from typing import Any, Dict
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from torch import nn
|
|
15
|
+
from torch.distributed.tensor import DeviceMesh, DTensor
|
|
16
|
+
|
|
17
|
+
from torchft.local_sgd import DiLoCo
|
|
18
|
+
from torchft.manager import Manager
|
|
19
|
+
from torchft.manager_integ_test import MyModel, Runner
|
|
20
|
+
from torchft.process_group import (
|
|
21
|
+
FakeProcessGroupWrapper,
|
|
22
|
+
ProcessGroupBabyNCCL,
|
|
23
|
+
ProcessGroupGloo,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MultiModel(torch.nn.Module):
    """Base class for test models built from a stack of sub-layers.

    Subclasses populate ``self.layers`` and implement ``get_rand_inputs`` /
    ``get_rand_labels``. The constructor arguments are accepted so that
    subclasses share one signature; this base class does not use them.
    """

    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
        super().__init__()
        # Sub-layers; left empty here and filled in by subclasses.
        self.layers = torch.nn.ModuleList()

    def get_rand_inputs(
        self, batch_size: int, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        """Return a random input batch of ``batch_size`` samples on ``device``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        # A bare ``raise`` outside an ``except`` block only yields an opaque
        # "No active exception to reraise" RuntimeError; NotImplementedError
        # (itself a RuntimeError subclass) states the actual problem.
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_rand_inputs()"
        )

    def get_rand_labels(
        self, batch_size: int, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        """Return random labels for ``batch_size`` samples on ``device``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_rand_labels()"
        )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class MultiMyModel(MultiModel):
    """A ``MultiModel`` made of ``n_layers`` chained ``MyModel`` layers.

    Layer widths alternate: the first layer maps ``in_dim -> out_dim``, the
    second ``out_dim -> in_dim``, and so on. ``self.out_dim`` is the output
    width of the last layer, or ``in_dim`` when no layers are built.
    """

    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
        super().__init__()
        self.in_dim = in_dim

        widths = (in_dim, out_dim)
        final_width = in_dim
        for idx in range(n_layers):
            # Even-indexed layers go in_dim -> out_dim, odd ones the reverse.
            src, dst = widths if idx % 2 == 0 else widths[::-1]
            self.layers.append(MyModel(src, dst))
            final_width = dst

        self.out_dim = final_width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through every layer in order and return the result."""
        out = x
        for module in self.layers:
            out = module(out)
        return out

    def get_rand_inputs(
        self, batch_size: int, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        """Uniform random inputs of shape ``(batch_size, in_dim)``."""
        shape = (batch_size, self.in_dim)
        return torch.rand(shape, device=device)

    def get_rand_labels(
        self, batch_size: int, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        """Random integer class labels in ``[0, out_dim)``, one per sample."""
        n_classes = self.out_dim
        return torch.randint(n_classes, (batch_size,), device=device)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class DiLoCoTrainer:
    """
    A class that encapsulates the DiLoCo training process.

    It builds a ``MultiMyModel`` with one layer per fragment, an AdamW inner
    optimizer, one Nesterov-SGD outer optimizer per layer, a
    ``FakeProcessGroupWrapper`` (handed to the runner's event injector) and a
    torchft ``Manager``; :meth:`train_loop` then drives training under
    ``DiLoCo``.
    """

    def __init__(
        self,
        rank: int,
        store_port: int,
        device: torch.device,
        runner: Runner,
        model_state_dict: dict[str, Any],
        n_fragments: int,
        diloco_args: dict[str, Any],
    ) -> None:
        """
        Initialize the DiLoCoTrainer.

        Args:
            rank: The rank of the current process.
            store_port: The port for the store used by the manager.
            device: The device to use for training.
            runner: The runner instance. It supplies the replica id, world
                size, lighthouse address, extra manager args and the event
                injector.
            model_state_dict: Initial weights loaded into the model.
            n_fragments: Number of model layers, and therefore of outer
                optimizers / DiLoCo fragments.
            diloco_args: Extra keyword arguments forwarded to ``DiLoCo``. The
                dict is copied, so the caller's dict is never modified.
        """
        self.rank: int = rank
        self.store_port: int = store_port
        self.device: torch.device = device
        self.runner: Runner = runner

        # Configuration supplied by the caller.
        self.model_state_dict: Dict[str, Any] = model_state_dict
        self.n_fragments: int = n_fragments
        # Shallow copy: train_loop() fills in defaults (e.g. ``sync_every``) and
        # must not write them back into a dict the caller may share.
        self.diloco_args: dict[str, Any] = dict(diloco_args)

        # Defined up front: the manager created below is given
        # self.state_dict / self.load_state_dict, which read this attribute.
        self.diloco: DiLoCo | None = None

        # Initialize components
        self.model: MultiModel = self.setup_model()
        self.inner_optimizer: torch.optim.Optimizer = self.setup_inner_optimizer()
        self.outer_optimizers: list[torch.optim.Optimizer] = (
            self.setup_outer_optimizers()
        )

        self.pg: FakeProcessGroupWrapper = self.setup_pg()
        # Set up the process group for the event injector
        self.runner.event_injector.set_pg(self.pg)

        self.manager: Manager = self.setup_manager()

        self.device_mesh: None | DeviceMesh = None
        self.setup_distributed()

        self.criterion: nn.CrossEntropyLoss = nn.CrossEntropyLoss()

    def setup_model(self) -> MultiModel:
        """Build the model, load the initial weights and move it to the device."""
        model = MultiMyModel(2, 3, self.n_fragments)
        model.load_state_dict(self.model_state_dict)
        model.to(self.device)
        return model

    def setup_inner_optimizer(self) -> torch.optim.Optimizer:
        """Set up the inner (local) AdamW optimizer over all model parameters."""
        return torch.optim.AdamW(
            self.model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
        )

    def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
        """Set up one outer Nesterov-SGD optimizer per model layer (fragment)."""
        return [
            torch.optim.SGD(layer.parameters(), lr=0.7, momentum=0.9, nesterov=True)
            for layer in self.model.layers
        ]

    def setup_pg(self) -> FakeProcessGroupWrapper:
        """Create the process group wrapped in a ``FakeProcessGroupWrapper``.

        BabyNCCL is used for CUDA devices; Gloo (10s timeout) otherwise.
        """
        if self.device.type == "cuda":
            return FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
        else:
            return FakeProcessGroupWrapper(
                ProcessGroupGloo(timeout=timedelta(seconds=10))
            )

    def setup_manager(self) -> Manager:
        """Create the torchft Manager for this replica/rank."""
        print(
            f"worker {self.runner.replica_id=} {self.rank=} {self.runner.world_size=} starting"
        )

        # Create manager with all arguments passed directly
        return Manager(
            pg=self.pg,
            min_replica_size=2,
            use_async_quorum=False,
            load_state_dict=self.load_state_dict,
            state_dict=self.state_dict,
            replica_id=str(self.runner.replica_id),
            store_addr="localhost",
            store_port=self.store_port,
            rank=self.rank,
            world_size=self.runner.world_size,
            lighthouse_addr=self.runner.lighthouse_address,
            port=19530 + self.runner.replica_id,
            connect_timeout=timedelta(seconds=10),
            quorum_timeout=timedelta(seconds=10),
            timeout=timedelta(seconds=10),
            **self.runner.manager_args,  # type: ignore
        )

    def setup_distributed(self) -> None:
        """Set up distributed training.

        Initializes the default process group if needed, then creates a 1-D
        ``DeviceMesh`` spanning ``runner.world_size`` ranks.
        """
        # Initialize default group for device mesh to work
        if not torch.distributed.is_initialized():
            # TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
            try:
                torch.distributed.init_process_group(
                    init_method="tcp://localhost:0",
                    rank=self.rank,
                    world_size=self.runner.world_size,
                )
            except ValueError:
                # Fallback per the TODO above: publish env:// settings instead.
                # NOTE(review): no group is initialized here; presumably the
                # DeviceMesh constructor below creates it from these variables.
                os.environ["MASTER_ADDR"] = "localhost"
                os.environ["MASTER_PORT"] = "0"
                os.environ["WORLD_SIZE"] = str(self.runner.world_size)
                os.environ["RANK"] = str(self.rank)

        self.device_mesh = DeviceMesh(
            self.device.type,
            torch.arange(self.runner.world_size),
        )

        # Convert model parameters to DTensor
        # NOTE(review): this loop does not change the model. Rebinding the loop
        # variable ``param`` does not replace the module's parameter, and the
        # layers are MyModel instances, so the nn.Linear check presumably never
        # matches. Kept as-is; confirm the intent before relying on DTensors.
        for layer in self.model.layers:
            if isinstance(layer, nn.Linear):
                for param in layer.parameters():
                    param = DTensor.from_local(
                        param,
                        device_mesh=self.device_mesh,
                    )

    def load_state_dict(self, state_dict: Dict[str, Dict[str, object]]) -> None:
        """
        Load the state dictionary (manager recovery callback).

        Args:
            state_dict: Dict with ``"model"`` and ``"inner_optim"`` entries, as
                produced by :meth:`state_dict`.
        """
        assert self.diloco is not None

        self.model.load_state_dict(state_dict["model"])
        self.model.to(self.device)

        self.inner_optimizer.load_state_dict(state_dict["inner_optim"])

    def state_dict(self) -> Dict[str, Dict[str, object]]:
        """
        Get the state dictionary (manager checkpoint callback).

        Returns:
            The model and inner-optimizer state dicts under ``"model"`` and
            ``"inner_optim"``.
        """
        assert self.diloco is not None

        return {
            "model": self.model.state_dict(),
            "inner_optim": self.inner_optimizer.state_dict(),
        }

    def train_loop(self) -> dict[str, Any]:
        """
        Run the training loop.

        Trains on random single-sample batches under ``DiLoCo`` until
        ``manager.current_step()`` reaches 4.

        Returns:
            A deep copy of the manager's state dict, captured the first time
            each step was observed, keyed by ``manager.current_step()``.
        """
        all_state_dicts = {}

        # Ensure sync_every is set in diloco_args (default: 2)
        self.diloco_args.setdefault("sync_every", 2)

        with DiLoCo(
            self.manager,
            list(self.model.layers),
            self.inner_optimizer,
            self.outer_optimizers,
            backup_device=self.device,
            **self.diloco_args,
        ) as self.diloco:
            while True:
                self.runner.event_injector.check(self.rank, self.manager.current_step())

                manager_curr_step = self.manager.current_step()
                if manager_curr_step not in all_state_dicts:
                    # Snapshot the manager state the first time this step is seen
                    all_state_dicts[manager_curr_step] = copy.deepcopy(
                        self.manager._manager_state_dict()
                    )

                batch_size = 1
                inputs = self.model.get_rand_inputs(batch_size, device=self.device)
                labels = self.model.get_rand_labels(batch_size, device=self.device)

                out = self.model(inputs)
                loss = self.criterion(out, labels)

                self.inner_optimizer.zero_grad()
                loss.backward()
                self.inner_optimizer.step()

                # after 4 model updates then break
                if self.manager.current_step() >= 4:
                    break

        return all_state_dicts
|
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import types
|
|
8
|
+
import unittest
|
|
9
|
+
from datetime import timedelta
|
|
10
|
+
from typing import Callable, cast, Dict, List, Optional, Tuple, TypeVar
|
|
11
|
+
|
|
12
|
+
# Generic type variable standing for the value type carried by a Future.
T = TypeVar("T")
|
|
14
|
+
|
|
15
|
+
import parameterized
|
|
16
|
+
import torch
|
|
17
|
+
from torch.distributed.distributed_c10d import Work
|
|
18
|
+
from torch.futures import Future
|
|
19
|
+
|
|
20
|
+
from torchft.manager import _ManagedWork, Manager
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SimpleWork(Work):
    """A simple implementation of torch.distributed.Work for testing.

    The work counts as finished as soon as :meth:`wait` is called. At that point
    its future is resolved with the tensors it was constructed with.
    """

    def __init__(self, tensors: List[torch.Tensor]) -> None:
        super().__init__()
        self._tensors = tensors
        self._future: Future[List[torch.Tensor]] = torch.futures.Future()
        # Set by the first wait(); guards against resolving the future twice.
        self._is_completed: bool = False

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        """Resolve the future with the stored tensors and return ``True``.

        ``timeout`` is accepted for interface compatibility and ignored.
        Repeated calls are no-ops: a ``torch.futures.Future`` may only have its
        result set once, so a second ``set_result`` would raise.
        """
        if not self._is_completed:
            self._is_completed = True
            self._future.set_result(self._tensors)
        return True

    def get_future(self) -> Future[List[torch.Tensor]]:
        """Return the future that :meth:`wait` resolves."""
        return self._future
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TestManagedWork(unittest.TestCase):
    """Tests for ``_ManagedWork`` future-callback behaviour on CPU and CUDA."""

    # (name, device) pairs each test is parameterized over; ``name`` only feeds
    # the generated test-method names.
    _DEVICES = [
        ("cpu", torch.device("cpu")),
        ("cuda", torch.device("cuda:0")),
    ]

    def _skip_if_unavailable(self, device: torch.device) -> None:
        """Skip the current test if CUDA is requested but not available."""
        if device.type == "cuda" and not torch.cuda.is_available():
            self.skipTest("CUDA not available")

    @staticmethod
    def _make_managed_work(tensors: List[torch.Tensor]) -> _ManagedWork:
        """Wrap a ``SimpleWork`` over ``tensors`` in a ``_ManagedWork``.

        The manager is created without running ``Manager.__init__``. Its
        ``wrap_future`` is replaced by a pass-through, so no quorum, store or
        process group is involved.
        """
        work = SimpleWork(list(tensors))

        manager = Manager.__new__(Manager)  # Create instance without calling __init__
        # We're using types.MethodType to attach a method to the manager instance
        # This is just for testing purposes
        manager.wrap_future = types.MethodType(  # type: ignore
            lambda self, fut, default, timeout=None: fut, manager
        )

        return _ManagedWork(manager, work, list(tensors))

    @parameterized.parameterized.expand(_DEVICES)
    def test_callbacks_execute_after_wait(
        self, name: str, device: torch.device
    ) -> None:
        """Test that callbacks are only executed after wait() is called."""
        self._skip_if_unavailable(device)

        tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
        managed_work = self._make_managed_work([tensor])

        # Track callback execution
        callback_executed: bool = False

        def callback(fut: Future[object]) -> List[torch.Tensor]:
            nonlocal callback_executed
            callback_executed = True
            # Multiply tensor by 2 to verify the callback ran
            tensor.mul_(2)
            return [tensor]

        # Add the callback
        fut = managed_work.get_future()
        fut = fut.then(callback)

        # Verify callback hasn't executed yet
        self.assertFalse(callback_executed)
        self.assertEqual(tensor.item(), 1.0)

        # Call wait() which should trigger the callback
        managed_work.wait()

        # Verify callback has executed
        self.assertTrue(callback_executed)
        self.assertEqual(tensor.item(), 2.0)

    @parameterized.parameterized.expand(_DEVICES)
    def test_multiple_callbacks_execute_in_order(
        self, name: str, device: torch.device
    ) -> None:
        """Test that multiple callbacks are executed in the order they were added."""
        self._skip_if_unavailable(device)

        tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
        managed_work = self._make_managed_work([tensor])

        # Track execution order
        execution_order: List[int] = []

        def callback1(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
            execution_order.append(1)
            tensor.add_(1)
            return [tensor]

        def callback2(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
            execution_order.append(2)
            tensor.add_(2)
            return [tensor]

        def callback3(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
            execution_order.append(3)
            tensor.add_(3)
            return [tensor]

        # Add callbacks
        fut = managed_work.get_future()
        fut = cast(Future[list[torch.Tensor]], fut)
        fut = fut.then(callback1)
        fut = fut.then(callback2)
        fut = fut.then(callback3)

        # Verify no callbacks have executed yet
        self.assertEqual(len(execution_order), 0)
        self.assertEqual(tensor.item(), 1.0)

        # Call wait() which should trigger the callbacks
        managed_work.wait()

        # Verify callbacks executed in order
        self.assertEqual(execution_order, [1, 2, 3])

        # Each callback adds to the tensor, so final value should be 1 + 1 + 2 + 3 = 7
        self.assertEqual(tensor.item(), 7.0)

    @parameterized.parameterized.expand(_DEVICES)
    def test_future_then_api(self, name: str, device: torch.device) -> None:
        """Test that the future's then API works correctly with ManagedWork."""
        self._skip_if_unavailable(device)

        tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
        managed_work = self._make_managed_work([tensor])

        # Get the future
        future = managed_work.get_future()

        # Track callback execution
        callback_executed: bool = False

        def callback(fut: Future[object]) -> List[torch.Tensor]:
            nonlocal callback_executed
            callback_executed = True
            # Multiply tensor by 3 to verify the callback ran
            tensor.mul_(3)
            return [tensor]

        # Use the then API
        future = future.then(callback)

        # Verify callback hasn't executed yet
        self.assertFalse(callback_executed)
        self.assertEqual(tensor.item(), 1.0)

        # Call wait() on the managed_work first to set up the future properly
        managed_work.wait()

        # Verify callback has executed
        self.assertTrue(callback_executed)
        self.assertEqual(tensor.item(), 3.0)

    @parameterized.parameterized.expand(_DEVICES)
    def test_callbacks_changing_return_types(
        self, name: str, device: torch.device
    ) -> None:
        """
        Test that callbacks can change return types and that tensors are modified in-place.
        This test demonstrates:
        1. Callbacks changing return types (List[Tensor] -> Dict -> Tuple)
        2. Using Future.value() instead of nonlocal
        3. Verifying tensors are modified in-place for both approaches
        """
        self._skip_if_unavailable(device)

        # Create tensors to work with
        tensor1: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
        tensor2: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) * 2

        # Store original tensor memory addresses to verify in-place modification
        tensor1_address = tensor1.data_ptr()
        tensor2_address = tensor2.data_ptr()

        managed_work = self._make_managed_work([tensor1, tensor2])

        # Get the future
        future = managed_work.get_future()
        future = cast(Future[List[torch.Tensor]], future)

        # First callback: Takes List[Tensor] and returns Dict[str, Tensor]
        # Modifies tensor1 in-place via the enclosing scope
        def callback1(fut: Future[List[torch.Tensor]]) -> Dict[str, torch.Tensor]:
            tensors = fut.value()
            tensor1.mul_(3)
            # Return a dictionary instead of a list
            return {"first": tensors[0], "second": tensors[1]}

        # Second callback: Takes Dict[str, Tensor] and returns Tuple[Tensor, float]
        # Uses Future.value() to modify tensor2
        def callback2(
            fut: Future[Dict[str, torch.Tensor]],
        ) -> Tuple[torch.Tensor, float]:
            data = fut.value()
            # Modify tensor2 in-place using the value from the future
            data["second"].add_(5)  # Should modify tensor2 in-place
            # Return a tuple instead of a dict
            return (data["second"], data["first"].item())

        # Third callback: Takes Tuple[Tensor, float] and returns a single Tensor
        def callback3(fut: Future[Tuple[torch.Tensor, float]]) -> torch.Tensor:
            tensor, value = fut.value()
            # Create a new tensor based on the tuple values
            return tensor * value

        # Chain the callbacks
        future = future.then(callback1)
        future = future.then(callback2)
        future = future.then(callback3)

        # Call wait() to trigger the callbacks
        managed_work.wait()

        # Verify tensor1 was modified in-place (via the enclosing scope)
        self.assertEqual(tensor1.item(), 3.0)  # 1 * 3 = 3
        self.assertEqual(tensor1.data_ptr(), tensor1_address)  # Same memory address

        # Verify tensor2 was modified in-place (using Future.value())
        self.assertEqual(tensor2.item(), 7.0)  # 2 + 5 = 7
        self.assertEqual(tensor2.data_ptr(), tensor2_address)  # Same memory address

        # Get the final result from the future
        final_result = future.wait()

        # The final result should be tensor2 * tensor1.item() = 7 * 3 = 21
        self.assertEqual(final_result.item(), 21.0)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
if __name__ == "__main__":
    # Run this test module directly through unittest's CLI entry point.
    unittest.main()
|