torchft-nightly 2026.1.3 (torchft_nightly-2026.1.3-cp310-cp310-manylinux_2_24_x86_64.whl)

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/__init__.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from torchft.data import DistributedSampler
+ from torchft.ddp import DistributedDataParallel
+ from torchft.manager import Manager
+ from torchft.optim import OptimizerWrapper as Optimizer
+ from torchft.otel import setup_logger
+ from torchft.process_group import (
+     ProcessGroupBabyNCCL,
+     ProcessGroupBabyXCCL,
+     ProcessGroupGloo,
+     ProcessGroupNCCL,
+     ProcessGroupXCCL,
+ )
+
+ setup_logger("torchft_quorums")
+ setup_logger("torchft_commits")
+ setup_logger("torchft_errors")
+
+ __all__ = (
+     "DistributedDataParallel",
+     "DistributedSampler",
+     "Manager",
+     "Optimizer",
+     "ProcessGroupNCCL",
+     "ProcessGroupXCCL",
+     "ProcessGroupBabyNCCL",
+     "ProcessGroupBabyXCCL",
+     "ProcessGroupGloo",
+ )
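
For orientation, here is a minimal sketch of how the names re-exported above are typically wired together. It assumes only what is visible elsewhere in this diff (the Manager keyword arguments pg, min_replica_size, load_state_dict, and state_dict appear in _test/diloco_trainer.py below, and Optimizer is the OptimizerWrapper alias from this file); the toy model, the optimizer choice, and the argument values are illustrative, and a real deployment typically also supplies replica and store settings via arguments or environment.

    import torch
    from torchft import Manager, Optimizer, ProcessGroupGloo

    # Toy module; any torch.nn.Module works the same way.
    model = torch.nn.Linear(3, 4)

    manager = Manager(
        pg=ProcessGroupGloo(),
        min_replica_size=1,
        load_state_dict=model.load_state_dict,
        state_dict=model.state_dict,
    )

    # Wrap the inner optimizer so each step is coordinated through the manager.
    optim = Optimizer(manager, torch.optim.AdamW(model.parameters()))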
torchft/_test/diloco_trainer.py ADDED
@@ -0,0 +1,287 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import copy
+ import logging
+ import os
+ from datetime import timedelta
+ from typing import Any, Dict
+
+ import torch
+ from torch import nn
+ from torch.distributed.tensor import DeviceMesh, DTensor
+
+ from torchft.local_sgd import DiLoCo
+ from torchft.manager import Manager
+ from torchft.manager_integ_test import MyModel, Runner
+ from torchft.process_group import (
+     FakeProcessGroupWrapper,
+     ProcessGroupBabyNCCL,
+     ProcessGroupGloo,
+ )
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ class MultiModel(torch.nn.Module):
+     def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
+         super().__init__()
+         self.layers = torch.nn.ModuleList()
+
+     def get_rand_inputs(
+         self, batch_size: int, device: torch.device = torch.device("cpu")
+     ) -> torch.Tensor:
+         raise NotImplementedError
+
+     def get_rand_labels(
+         self, batch_size: int, device: torch.device = torch.device("cpu")
+     ) -> torch.Tensor:
+         raise NotImplementedError
+
+
+ class MultiMyModel(MultiModel):
+     def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
+         super().__init__()
+         self.in_dim = in_dim
+
+         for _ in range(n_layers):
+             self.layers.append(MyModel(in_dim, out_dim))
+             in_dim, out_dim = out_dim, in_dim
+
+         self.out_dim = in_dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for layer in self.layers:
+             x = layer(x)
+         return x
+
+     def get_rand_inputs(
+         self, batch_size: int, device: torch.device = torch.device("cpu")
+     ) -> torch.Tensor:
+         return torch.rand(batch_size, self.in_dim, device=device)
+
+     def get_rand_labels(
+         self, batch_size: int, device: torch.device = torch.device("cpu")
+     ) -> torch.Tensor:
+         return torch.randint(self.out_dim, (batch_size,), device=device)
+
+
+ class DiLoCoTrainer:
+     """
+     A class that encapsulates the DiLoCo training process.
+     """
+
+     def __init__(
+         self,
+         rank: int,
+         store_port: int,
+         device: torch.device,
+         runner: Runner,
+         model_state_dict: dict[str, Any],
+         n_fragments: int,
+         diloco_args: dict[str, Any],
+     ) -> None:
+         """
+         Initialize the DiLoCoTrainer.
+
+         Args:
+             rank: The rank of the current process.
+             store_port: The port for the store.
+             device: The device to use for training.
+             runner: The runner instance.
+             model_state_dict: The state dict used to initialize the model.
+             n_fragments: The number of model fragments (layers) to train.
+             diloco_args: Additional arguments for DiLoCo.
+         """
+         self.rank: int = rank
+         self.store_port: int = store_port
+         self.device: torch.device = device
+         self.runner: Runner = runner
+
+         self.model_state_dict: Dict[str, Any] = model_state_dict
+         self.n_fragments: int = n_fragments
+         self.diloco_args: dict[str, Any] = diloco_args
+
+         # Initialize components
+         self.model: MultiModel = self.setup_model()
+         self.inner_optimizer: torch.optim.Optimizer = self.setup_inner_optimizer()
+         self.outer_optimizers: list[torch.optim.Optimizer] = (
+             self.setup_outer_optimizers()
+         )
+
+         self.pg: FakeProcessGroupWrapper = self.setup_pg()
+         # Set up the process group for the event injector
+         self.runner.event_injector.set_pg(self.pg)
+
+         self.manager: Manager = self.setup_manager()
+
+         self.device_mesh: None | DeviceMesh = None
+         self.setup_distributed()
+
+         self.criterion: nn.CrossEntropyLoss = nn.CrossEntropyLoss()
+
+         self.diloco: DiLoCo | None = None
+
+     def setup_model(self) -> MultiModel:
+         """Set up the model and move it to the device."""
+         model = MultiMyModel(2, 3, self.n_fragments)
+         model.load_state_dict(self.model_state_dict)
+         model.to(self.device)
+         return model
+
+     def setup_inner_optimizer(self) -> torch.optim.Optimizer:
+         """Set up the inner optimizer."""
+         return torch.optim.AdamW(
+             self.model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+         )
+
+     def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
+         """Set up outer optimizers."""
+         # Create one outer optimizer per fragment
+         outer_optimizers = []
+         for layers in self.model.layers:
+             outer_optimizers.append(
+                 torch.optim.SGD(
+                     layers.parameters(), lr=0.7, momentum=0.9, nesterov=True
+                 )
+             )
+         return outer_optimizers
+
+     def setup_pg(self) -> FakeProcessGroupWrapper:
+         if self.device.type == "cuda":
+             return FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
+         else:
+             return FakeProcessGroupWrapper(
+                 ProcessGroupGloo(timeout=timedelta(seconds=10))
+             )
+
+     def setup_manager(self) -> Manager:
+         """Set up the process group and manager."""
+         print(
+             f"worker {self.runner.replica_id=} {self.rank=} {self.runner.world_size=} starting"
+         )
+
+         # Create manager with all arguments passed directly
+         return Manager(
+             pg=self.pg,
+             min_replica_size=2,
+             use_async_quorum=False,
+             load_state_dict=self.load_state_dict,
+             state_dict=self.state_dict,
+             replica_id=str(self.runner.replica_id),
+             store_addr="localhost",
+             store_port=self.store_port,
+             rank=self.rank,
+             world_size=self.runner.world_size,
+             lighthouse_addr=self.runner.lighthouse_address,
+             port=19530 + self.runner.replica_id,
+             connect_timeout=timedelta(seconds=10),
+             quorum_timeout=timedelta(seconds=10),
+             timeout=timedelta(seconds=10),
+             **self.runner.manager_args,  # type: ignore
+         )
+
+     def setup_distributed(self) -> None:
+         """Set up distributed training."""
+         # Initialize default group for device mesh to work
+         if not torch.distributed.is_initialized():
+             # TODO: remove this try-except once pytorch is updated to 2.8.0 and can use localhost:0
+             try:
+                 torch.distributed.init_process_group(
+                     init_method="tcp://localhost:0",
+                     rank=self.rank,
+                     world_size=self.runner.world_size,
+                 )
+             except ValueError:
+                 os.environ["MASTER_ADDR"] = "localhost"
+                 os.environ["MASTER_PORT"] = "0"
+                 os.environ["WORLD_SIZE"] = str(self.runner.world_size)
+                 os.environ["RANK"] = str(self.rank)
+
+         self.device_mesh = DeviceMesh(
+             self.device.type,
+             torch.arange(self.runner.world_size),
+         )
+
+         # Convert model parameters to DTensor
+         for layer in self.model.layers:
+             if isinstance(layer, nn.Linear):
+                 for param in layer.parameters():
+                     param = DTensor.from_local(
+                         param,
+                         device_mesh=self.device_mesh,
+                     )
+
+     def load_state_dict(self, state_dict: Dict[str, Dict[str, object]]) -> None:
+         """
+         Load the state dictionary.
+
+         Args:
+             state_dict: The state dictionary to load.
+         """
+         assert self.diloco is not None
+
+         self.model.load_state_dict(state_dict["model"])
+         self.model.to(self.device)
+
+         self.inner_optimizer.load_state_dict(state_dict["inner_optim"])
+
+     def state_dict(self) -> Dict[str, Dict[str, object]]:
+         """
+         Get the state dictionary.
+
+         Returns:
+             The state dictionary.
+         """
+         assert self.diloco is not None
+
+         return {
+             "model": self.model.state_dict(),
+             "inner_optim": self.inner_optimizer.state_dict(),
+         }
+
+     def train_loop(self) -> dict[str, Any]:
+         """Run the training loop."""
+         all_state_dicts = {}
+
+         # Ensure sync_every is set in diloco_args
+         if "sync_every" not in self.diloco_args:
+             self.diloco_args["sync_every"] = 2
+
+         with DiLoCo(
+             self.manager,
+             [layer for layer in self.model.layers],
+             self.inner_optimizer,
+             self.outer_optimizers,
+             backup_device=self.device,
+             **self.diloco_args,
+         ) as self.diloco:
+             while True:
+                 self.runner.event_injector.check(self.rank, self.manager.current_step())
+
+                 manager_curr_step = self.manager.current_step()
+                 if manager_curr_step not in all_state_dicts:
+                     # Store a deep copy of the manager state dict for this step
+                     all_state_dicts[manager_curr_step] = copy.deepcopy(
+                         self.manager._manager_state_dict()
+                     )
+
+                 batch_size = 1
+                 inputs = self.model.get_rand_inputs(batch_size, device=self.device)
+                 labels = self.model.get_rand_labels(batch_size, device=self.device)
+
+                 out = self.model(inputs)
+                 loss = self.criterion(out, labels)
+
+                 self.inner_optimizer.zero_grad()
+                 loss.backward()
+                 self.inner_optimizer.step()
+
+                 # Break after four model updates
+                 if self.manager.current_step() >= 4:
+                     break
+
+         return all_state_dicts
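
To make the class above concrete, here is a hedged sketch of how a test harness might drive DiLoCoTrainer, inferred purely from the constructor signature and train_loop in this file. The Runner object (from torchft.manager_integ_test) and the lighthouse and store plumbing it carries are assumed to be set up elsewhere; the remaining argument values are illustrative.

    import torch
    from torchft._test.diloco_trainer import DiLoCoTrainer, MultiMyModel

    def run_replica(rank: int, store_port: int, runner) -> dict:
        # Shared initial weights; MultiMyModel(2, 3, ...) matches setup_model above.
        init_state = MultiMyModel(2, 3, n_layers=2).state_dict()
        trainer = DiLoCoTrainer(
            rank=rank,
            store_port=store_port,
            device=torch.device("cpu"),
            runner=runner,  # torchft.manager_integ_test.Runner, assumed configured
            model_state_dict=init_state,
            n_fragments=2,  # one outer SGD optimizer per fragment
            diloco_args={"sync_every": 2},
        )
        # Runs until manager.current_step() >= 4 and returns the per-step
        # manager state dicts captured along the way.
        return trainer.train_loop()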
torchft/_test/managed_work_test.py ADDED
@@ -0,0 +1,320 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import types
+ import unittest
+ from datetime import timedelta
+ from typing import cast, Dict, List, Optional, Tuple, TypeVar
+
+ import parameterized
+ import torch
+ from torch.distributed.distributed_c10d import Work
+ from torch.futures import Future
+
+ from torchft.manager import _ManagedWork, Manager
+
+ # Type variable for the Future's value type
+ T = TypeVar("T")
+
+
+ class SimpleWork(Work):
+     """A simple implementation of torch.distributed.Work for testing."""
+
+     def __init__(self, tensors: List[torch.Tensor]) -> None:
+         super().__init__()
+         self._tensors = tensors
+         self._future: Future[List[torch.Tensor]] = torch.futures.Future()
+         self._is_completed: bool = False
+
+     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+         self._is_completed = True
+         self._future.set_result(self._tensors)
+         return True
+
+     def get_future(self) -> Future[List[torch.Tensor]]:
+         return self._future
+
+
+ class TestManagedWork(unittest.TestCase):
+     @parameterized.parameterized.expand(
+         [
+             ("cpu", torch.device("cpu")),
+             ("cuda", torch.device("cuda:0")),
+         ]
+     )
+     def test_callbacks_execute_after_wait(
+         self, name: str, device: torch.device
+     ) -> None:
+         """Test that callbacks are only executed after wait() is called."""
+         # Skip if CUDA is requested but not available
+         if device.type == "cuda" and not torch.cuda.is_available():
+             self.skipTest("CUDA not available")
+
+         # Create a tensor to work with
+         tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
+
+         # Create a simple work object
+         work = SimpleWork([tensor])
+
+         # Create a minimal manager object with just the wrap_future method
+         manager = Manager.__new__(Manager)  # Create instance without calling __init__
+         # We're using types.MethodType to attach a method to the manager instance
+         # This is just for testing purposes
+         manager.wrap_future = types.MethodType(  # type: ignore
+             lambda self, fut, default, timeout=None: fut, manager
+         )
+
+         # Create the managed work
+         managed_work = _ManagedWork(manager, work, [tensor])
+
+         # Track callback execution
+         callback_executed: bool = False
+
+         def callback(fut: Future[object]) -> List[torch.Tensor]:
+             nonlocal callback_executed, tensor
+             callback_executed = True
+             # Multiply tensor by 2 to verify the callback ran
+             tensor.mul_(2)
+             return [tensor]
+
+         # Add the callback
+         fut = managed_work.get_future()
+         fut = fut.then(callback)
+
+         # Verify callback hasn't executed yet
+         self.assertFalse(callback_executed)
+         self.assertEqual(tensor.item(), 1.0)
+
+         # Call wait() which should trigger the callback
+         managed_work.wait()
+
+         # Verify callback has executed
+         self.assertTrue(callback_executed)
+         self.assertEqual(tensor.item(), 2.0)
+
+     @parameterized.parameterized.expand(
+         [
+             ("cpu", torch.device("cpu")),
+             ("cuda", torch.device("cuda:0")),
+         ]
+     )
+     def test_multiple_callbacks_execute_in_order(
+         self, name: str, device: torch.device
+     ) -> None:
+         """Test that multiple callbacks are executed in the order they were added."""
+         # Skip if CUDA is requested but not available
+         if device.type == "cuda" and not torch.cuda.is_available():
+             self.skipTest("CUDA not available")
+
+         # Create a tensor to work with
+         tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
+
+         # Create a simple work object
+         work = SimpleWork([tensor])
+
+         # Create a minimal manager object with just the wrap_future method
+         manager = Manager.__new__(Manager)  # Create instance without calling __init__
+         manager.wrap_future = types.MethodType(  # type: ignore
+             lambda self, fut, default, timeout=None: fut, manager
+         )
+
+         # Create the managed work
+         managed_work = _ManagedWork(manager, work, [tensor])
+
+         # Track execution order
+         execution_order: List[int] = []
+
+         def callback1(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
+             nonlocal tensor
+             execution_order.append(1)
+             tensor.add_(1)
+             return [tensor]
+
+         def callback2(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
+             nonlocal tensor
+             execution_order.append(2)
+             tensor.add_(2)
+             return [tensor]
+
+         def callback3(fut: Future[list[torch.Tensor]]) -> List[torch.Tensor]:
+             nonlocal tensor
+             execution_order.append(3)
+             tensor.add_(3)
+             return [tensor]
+
+         # Add callbacks
+         fut = managed_work.get_future()
+         fut = cast(Future[list[torch.Tensor]], fut)
+         fut = fut.then(callback1)
+         fut = fut.then(callback2)
+         fut = fut.then(callback3)
+
+         # Verify no callbacks have executed yet
+         self.assertEqual(len(execution_order), 0)
+         self.assertEqual(tensor.item(), 1.0)
+
+         # Call wait() which should trigger the callbacks
+         managed_work.wait()
+
+         # Verify callbacks executed in order
+         self.assertEqual(execution_order, [1, 2, 3])
+
+         # Each callback adds to the tensor, so the final value should be 1 + 1 + 2 + 3 = 7
+         self.assertEqual(tensor.item(), 7.0)
+
+     @parameterized.parameterized.expand(
+         [
+             ("cpu", torch.device("cpu")),
+             ("cuda", torch.device("cuda:0")),
+         ]
+     )
+     def test_future_then_api(self, name: str, device: torch.device) -> None:
+         """Test that the future's then API works correctly with ManagedWork."""
+         # Skip if CUDA is requested but not available
+         if device.type == "cuda" and not torch.cuda.is_available():
+             self.skipTest("CUDA not available")
+
+         # Create a tensor to work with
+         tensor: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
+
+         # Create a simple work object
+         work = SimpleWork([tensor])
+
+         # Create a minimal manager object with just the wrap_future method
+         manager = Manager.__new__(Manager)  # Create instance without calling __init__
+         manager.wrap_future = types.MethodType(  # type: ignore
+             lambda self, fut, default, timeout=None: fut, manager
+         )
+
+         # Create the managed work
+         managed_work = _ManagedWork(manager, work, [tensor])
+
+         # Get the future
+         future = managed_work.get_future()
+
+         # Track callback execution
+         callback_executed: bool = False
+
+         def callback(fut: Future[object]) -> List[torch.Tensor]:
+             nonlocal callback_executed, tensor
+             callback_executed = True
+             # Multiply tensor by 3 to verify the callback ran
+             tensor.mul_(3)
+             return [tensor]
+
+         # Use the then API
+         future = future.then(callback)
+
+         # Verify callback hasn't executed yet
+         self.assertFalse(callback_executed)
+         self.assertEqual(tensor.item(), 1.0)
+
+         # Call wait() on the managed_work first to set up the future properly
+         managed_work.wait()
+
+         # Verify callback has executed
+         self.assertTrue(callback_executed)
+         self.assertEqual(tensor.item(), 3.0)
+
+     @parameterized.parameterized.expand(
+         [
+             ("cpu", torch.device("cpu")),
+             ("cuda", torch.device("cuda:0")),
+         ]
+     )
+     def test_callbacks_changing_return_types(
+         self, name: str, device: torch.device
+     ) -> None:
+         """
+         Test that callbacks can change return types and that tensors are modified in-place.
+         This test demonstrates:
+         1. Callbacks changing return types (List[Tensor] -> Dict -> Tuple)
+         2. Using Future.value() instead of nonlocal
+         3. Verifying tensors are modified in-place for both approaches
+         """
+         # Skip if CUDA is requested but not available
+         if device.type == "cuda" and not torch.cuda.is_available():
+             self.skipTest("CUDA not available")
+
+         # Create tensors to work with
+         tensor1: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device)
+         tensor2: torch.Tensor = torch.ones(1, dtype=torch.float32, device=device) * 2
+
+         # Store original tensor memory addresses to verify in-place modification
+         tensor1_address = tensor1.data_ptr()
+         tensor2_address = tensor2.data_ptr()
+
+         # Create a simple work object
+         work = SimpleWork([tensor1, tensor2])
+
+         # Create a minimal manager object with just the wrap_future method
+         manager = Manager.__new__(Manager)  # Create instance without calling __init__
+         manager.wrap_future = types.MethodType(  # type: ignore
+             lambda self, fut, default, timeout=None: fut, manager
+         )
+
+         # Create the managed work
+         managed_work = _ManagedWork(manager, work, [tensor1, tensor2])
+
+         # Get the future
+         future = managed_work.get_future()
+         future = cast(Future[List[torch.Tensor]], future)
+
+         # First callback: Takes List[Tensor] and returns Dict[str, Tensor]
+         # Uses nonlocal to modify tensor1
+         def callback1(fut: Future[List[torch.Tensor]]) -> Dict[str, torch.Tensor]:
+             tensors = fut.value()
+             nonlocal tensor1
+             # Modify tensor1 in-place using nonlocal
+             tensor1.mul_(3)
+             # Return a dictionary instead of a list
+             return {"first": tensors[0], "second": tensors[1]}
+
+         # Second callback: Takes Dict[str, Tensor] and returns Tuple[Tensor, float]
+         # Uses Future.value() to modify tensor2
+         def callback2(
+             fut: Future[Dict[str, torch.Tensor]],
+         ) -> Tuple[torch.Tensor, float]:
+             data = fut.value()
+             # Modify tensor2 in-place using the value from the future
+             data["second"].add_(5)  # Should modify tensor2 in-place
+             # Return a tuple instead of a dict
+             return (data["second"], data["first"].item())
+
+         # Third callback: Takes Tuple[Tensor, float] and returns a single Tensor
+         def callback3(fut: Future[Tuple[torch.Tensor, float]]) -> torch.Tensor:
+             tensor, value = fut.value()
+             # Create a new tensor based on the tuple values
+             result = tensor * value
+             return result
+
+         # Chain the callbacks
+         future = future.then(callback1)
+         future = future.then(callback2)
+         future = future.then(callback3)
+
+         # Call wait() to trigger the callbacks
+         managed_work.wait()
+
+         # Verify tensor1 was modified in-place (using nonlocal)
+         self.assertEqual(tensor1.item(), 3.0)  # 1 * 3 = 3
+         self.assertEqual(tensor1.data_ptr(), tensor1_address)  # Same memory address
+
+         # Verify tensor2 was modified in-place (using Future.value())
+         self.assertEqual(tensor2.item(), 7.0)  # 2 + 5 = 7
+         self.assertEqual(tensor2.data_ptr(), tensor2_address)  # Same memory address
+
+         # Get the final result from the future
+         final_result = future.wait()
+
+         # The final result should be tensor2 * tensor1.item() = 7 * 3 = 21
+         self.assertEqual(final_result.item(), 21.0)
+
+
+ if __name__ == "__main__":
+     unittest.main()
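
The chaining behavior these tests exercise is the standard torch.futures contract; as a minimal, torchft-free sketch, the snippet below shows that then() callbacks stay queued until the upstream future completes and that each callback may change the value type along the chain:

    import torch
    from torch.futures import Future

    # An unfulfilled future; callbacks added via then() are queued until set_result().
    fut: Future = Future()

    # Each then() returns a new Future carrying its callback's return value,
    # so the chain may change types: List[Tensor] -> Tensor -> float.
    chained = fut.then(lambda f: f.value()[0].sum()).then(lambda f: f.value().item())

    fut.set_result([torch.ones(3)])  # completing the future fires the whole chain
    assert chained.wait() == 3.0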