torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import queue
|
|
8
|
+
import time
|
|
9
|
+
from datetime import timedelta
|
|
10
|
+
from multiprocessing.connection import Connection
|
|
11
|
+
from typing import Union
|
|
12
|
+
|
|
13
|
+
import torch.multiprocessing as mp
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class _MonitoredPipe:
|
|
17
|
+
def __init__(self, pipe: "Connection[object, object]") -> None: # type: ignore
|
|
18
|
+
self._pipe = pipe
|
|
19
|
+
|
|
20
|
+
def send(self, obj: object) -> None:
|
|
21
|
+
self._pipe.send(obj)
|
|
22
|
+
|
|
23
|
+
def recv(self, timeout: Union[float, timedelta]) -> object:
|
|
24
|
+
if isinstance(timeout, timedelta):
|
|
25
|
+
timeout = timeout.total_seconds()
|
|
26
|
+
if self._pipe.poll(timeout):
|
|
27
|
+
out = self._pipe.recv()
|
|
28
|
+
if isinstance(out, Exception):
|
|
29
|
+
raise out
|
|
30
|
+
return out
|
|
31
|
+
else:
|
|
32
|
+
raise TimeoutError(f"pipe.recv() timed out after {timeout} seconds")
|
|
33
|
+
|
|
34
|
+
def close(self) -> None:
|
|
35
|
+
self._pipe.close()
|
|
36
|
+
|
|
37
|
+
def closed(self) -> bool:
|
|
38
|
+
return self._pipe.closed
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Multiprocessing Dummy Context
|
|
9
|
+
=============================
|
|
10
|
+
|
|
11
|
+
This module provides a context-like interface for multiprocessing.dummy,
|
|
12
|
+
which is a wrapper around the threading module that provides a multiprocessing-like
|
|
13
|
+
interface but uses threads instead of processes.
|
|
14
|
+
|
|
15
|
+
This allows code that uses multiprocessing.get_context() to work with
|
|
16
|
+
multiprocessing.dummy by providing a compatible interface.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import multiprocessing.dummy as mp
|
|
20
|
+
import threading
|
|
21
|
+
from typing import Callable, Iterable, Mapping
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class DummyContext:
|
|
25
|
+
"""
|
|
26
|
+
A context-like class for multiprocessing.dummy that mimics the interface
|
|
27
|
+
of a context returned by multiprocessing.get_context().
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def __init__(self, method: object = None) -> None:
|
|
31
|
+
"""
|
|
32
|
+
Initialize the dummy context.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
method: Ignored, only for compatibility with multiprocessing.get_context()
|
|
36
|
+
"""
|
|
37
|
+
pass
|
|
38
|
+
|
|
39
|
+
def Process(
|
|
40
|
+
self,
|
|
41
|
+
group: object = None,
|
|
42
|
+
target: Callable[..., object] | None = None,
|
|
43
|
+
name: str | None = None,
|
|
44
|
+
args: Iterable[object] = (),
|
|
45
|
+
kwargs: Mapping[str, object] = {},
|
|
46
|
+
daemon: bool | None = None,
|
|
47
|
+
) -> mp.DummyProcess:
|
|
48
|
+
"""
|
|
49
|
+
Create a Process using multiprocessing.dummy.Process.
|
|
50
|
+
"""
|
|
51
|
+
return mp.Process(
|
|
52
|
+
group=group, target=target, name=name, args=args, kwargs=kwargs
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
def Pipe(
|
|
56
|
+
self, duplex: bool = True
|
|
57
|
+
) -> tuple[mp.connection.Connection, mp.connection.Connection]:
|
|
58
|
+
"""
|
|
59
|
+
Create a Pipe using multiprocessing.dummy.Pipe.
|
|
60
|
+
"""
|
|
61
|
+
return mp.Pipe(duplex)
|
|
62
|
+
|
|
63
|
+
def Queue(self, maxsize: int = 0) -> mp.Queue:
|
|
64
|
+
"""
|
|
65
|
+
Create a Queue using multiprocessing.dummy.Queue.
|
|
66
|
+
"""
|
|
67
|
+
return mp.Queue(maxsize)
|
|
68
|
+
|
|
69
|
+
def Event(self) -> threading.Event:
|
|
70
|
+
"""
|
|
71
|
+
Create an Event using multiprocessing.dummy.Event.
|
|
72
|
+
"""
|
|
73
|
+
return mp.Event()
|
|
74
|
+
|
|
75
|
+
def Lock(self) -> threading.Lock:
|
|
76
|
+
"""
|
|
77
|
+
Create a Lock using multiprocessing.dummy.Lock.
|
|
78
|
+
"""
|
|
79
|
+
return mp.Lock()
|
|
80
|
+
|
|
81
|
+
def RLock(self) -> threading.RLock:
|
|
82
|
+
"""
|
|
83
|
+
Create an RLock using multiprocessing.dummy.RLock.
|
|
84
|
+
"""
|
|
85
|
+
return mp.RLock()
|
|
86
|
+
|
|
87
|
+
def Semaphore(self, value: int = 1) -> threading.Semaphore:
|
|
88
|
+
"""
|
|
89
|
+
Create a Semaphore using multiprocessing.dummy.Semaphore.
|
|
90
|
+
"""
|
|
91
|
+
return mp.Semaphore(value)
|
|
92
|
+
|
|
93
|
+
def BoundedSemaphore(self, value: int = 1) -> threading.BoundedSemaphore:
|
|
94
|
+
"""
|
|
95
|
+
Create a BoundedSemaphore using multiprocessing.dummy.BoundedSemaphore.
|
|
96
|
+
"""
|
|
97
|
+
return mp.BoundedSemaphore(value)
|
|
98
|
+
|
|
99
|
+
def Condition(
|
|
100
|
+
self, lock: threading.Lock | threading.RLock | None = None
|
|
101
|
+
) -> threading.Condition:
|
|
102
|
+
"""
|
|
103
|
+
Create a Condition using multiprocessing.dummy.Condition.
|
|
104
|
+
"""
|
|
105
|
+
return mp.Condition(lock)
|
|
106
|
+
|
|
107
|
+
def Manager(self) -> object:
|
|
108
|
+
"""
|
|
109
|
+
Create a Manager using multiprocessing.dummy.Manager.
|
|
110
|
+
"""
|
|
111
|
+
return mp.Manager()
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def get_context(method: object = None) -> DummyContext:
    """
    Return a :class:`DummyContext`, a thread-backed stand-in for the object
    returned by ``multiprocessing.get_context()``.

    Assigning this function over ``multiprocessing.dummy.get_context`` lets code
    written against the context API run on threads::

        import multiprocessing.dummy as mp
        from torchft.multiprocessing_dummy_context import get_context
        mp.get_context = get_context

    Args:
        method: accepted only for signature compatibility with
            ``multiprocessing.get_context()``; it is forwarded to DummyContext,
            which ignores it.

    Returns:
        A new DummyContext instance.
    """
    ctx = DummyContext(method)
    return ctx
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from multiprocessing.connection import Connection
|
|
2
|
+
from unittest import TestCase
|
|
3
|
+
|
|
4
|
+
import torch.multiprocessing as mp
|
|
5
|
+
|
|
6
|
+
from torchft.multiprocessing import _MonitoredPipe
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def pipe_get(q: "Connection[object, object]") -> None:  # type: ignore
    """Child-process target: block until a single message arrives, then return."""
    _ = q.recv()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def pipe_put(q: "Connection[object, object]") -> None:  # type: ignore
    """Child-process target: wait for one message from the parent, then reply with ``1``."""
    q.recv()  # handshake: wait until the parent signals
    reply = 1
    q.send(reply)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class MultiprocessingTest(TestCase):
    def test_monitored_queue_put(self) -> None:
        ctx = mp.get_context("fork")
        local, remote = ctx.Pipe()
        p = ctx.Process(target=pipe_get, args=(remote,), daemon=True)
        p.start()
        # Drop the parent's handle on the child's end so that once the child
        # exits there is no reader left and sends start failing.
        del remote

        mq = _MonitoredPipe(local)
        mq.send(1)
        with self.assertRaisesRegex(
            (ConnectionResetError, BrokenPipeError),
            "(Connection reset by peer|Broken pipe)",
        ):
            while True:
                mq.send(1)

        mq.close()
        assert mq.closed()
        # Reap the child so the test doesn't leave a zombie process behind.
        p.join(timeout=10)

    def test_monitored_queue_get(self) -> None:
        ctx = mp.get_context("fork")
        local, remote = ctx.Pipe()
        p = ctx.Process(target=pipe_put, args=(remote,), daemon=True)
        p.start()
        # Only the child holds the remote end, so its exit produces EOF here.
        del remote

        mq = _MonitoredPipe(local)

        with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
            mq.recv(timeout=0.0)

        # continue
        mq.send(1)

        self.assertEqual(mq.recv(timeout=10), 1)
        with self.assertRaises(EOFError):
            mq.recv(timeout=10)

        mq.close()
        assert mq.closed()
        # Reap the child so the test doesn't leave a zombie process behind.
        p.join(timeout=10)
|
torchft/optim.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Optimizers
|
|
9
|
+
============
|
|
10
|
+
|
|
11
|
+
This module implements an optimizer wrapper that works with the Manager to provide fault tolerance.
|
|
12
|
+
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch.optim import Optimizer
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from torchft.manager import Manager
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OptimizerWrapper(Optimizer):
    """
    Fault-tolerant wrapper around any :class:`torch.optim.Optimizer`.

    Call :meth:`zero_grad` at the start of the forward pass (it starts a quorum
    on the manager) and :meth:`step` at the end of the backward pass. ``step``
    only steps the wrapped optimizer when ``manager.should_commit()`` is truthy;
    otherwise the update for this iteration is skipped.

    ``Optimizer.__init__`` is not called; state and parameter groups are read
    from the wrapped optimizer through the delegating members below.
    """

    def __init__(self, manager: "Manager", optim: Optimizer) -> None:
        self.manager = manager
        self.optim = optim

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        """Add ``param_group`` to the wrapped optimizer."""
        self.optim.add_param_group(param_group)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load ``state_dict`` into the wrapped optimizer."""
        self.optim.load_state_dict(state_dict)

    def state_dict(self) -> Dict[str, Any]:
        """Return the wrapped optimizer's state dict."""
        return self.optim.state_dict()

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Start a new quorum on the manager, then clear gradients."""
        self.manager.start_quorum()
        self.optim.zero_grad(set_to_none)

    def step(self, closure: Optional[object] = None) -> None:
        """
        Step the wrapped optimizer if the manager agrees to commit.

        Raises:
            AssertionError: if ``closure`` is given (closures are unsupported).
        """
        assert closure is None, "optimizers that use closures are not supported"
        if not self.manager.should_commit():
            return
        self.optim.step()

    @property
    def param_groups(self) -> List[Dict[str, Any]]:
        """The wrapped optimizer's ``param_groups`` list (same object)."""
        return self.optim.param_groups

    @property
    def state(self) -> Mapping[torch.Tensor, object]:
        """The wrapped optimizer's per-parameter ``state`` (same object)."""
        return self.optim.state
|
torchft/optim_test.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from unittest import TestCase
|
|
8
|
+
from unittest.mock import create_autospec, MagicMock
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from torch.nn import Linear
|
|
12
|
+
from torch.optim import AdamW
|
|
13
|
+
|
|
14
|
+
from torchft.manager import Manager
|
|
15
|
+
from torchft.optim import OptimizerWrapper
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TestOptim(TestCase):
    def test_optimizer_wrapper(self) -> None:
        """OptimizerWrapper delegates to the base optimizer and gates step() on the manager."""
        manager = create_autospec(Manager)

        model = Linear(3, 4)
        optim = OptimizerWrapper(manager, AdamW(model.parameters()))
        extra_group = {"params": [], "lr": 1e-4}
        optim.add_param_group(extra_group)

        # state_dict round-trip goes through the wrapper
        optim.load_state_dict(optim.state_dict())

        optim.zero_grad()
        self.assertEqual(manager.start_quorum.call_count, 1)

        model(torch.rand(3)).sum().backward()

        for commit in (True, False):
            manager.should_commit.return_value = commit
            optim.step()

        groups = optim.param_groups
        self.assertEqual(len(groups), 2)
        self.assertEqual(groups[1]["lr"], 1e-4)
        self.assertEqual(groups[1]["params"], [])
        self.assertEqual(len(optim.state), len(list(model.parameters())))

        self.assertEqual(manager.should_commit.call_count, 2)
|
torchft/otel.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
import time
|
|
11
|
+
from typing import Any, List, Sequence, TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
from opentelemetry._logs import set_logger_provider
|
|
14
|
+
|
|
15
|
+
from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
|
|
16
|
+
from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
|
|
17
|
+
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
|
|
18
|
+
from opentelemetry.sdk.resources import Resource
|
|
19
|
+
|
|
20
|
+
# These types are available in opentelemetry-sdk but Pyre's type stubs
# don't include them. We import them at runtime and provide type aliases for
# static type checking.
if TYPE_CHECKING:
    # pyre-fixme[33]: Aliasing to Any is prohibited. opentelemetry-sdk lacks type stubs.
    ReadableLogRecord = Any
    # pyre-fixme[33]: Aliasing to Any is prohibited. opentelemetry-sdk lacks type stubs.
    LogRecordExporter = Any
    # pyre-fixme[33]: Aliasing to Any is prohibited. opentelemetry-sdk lacks type stubs.
    LogRecordExportResult = Any
    # pyre-fixme[33]: Aliasing to Any is prohibited. opentelemetry-sdk lacks type stubs.
    ConsoleLogRecordExporter = Any
else:
    # Real runtime imports; the aliases above are only seen by the type checker.
    from opentelemetry.sdk._logs import ReadableLogRecord
    from opentelemetry.sdk._logs.export import (
        ConsoleLogRecordExporter,
        LogRecordExporter,
        LogRecordExportResult,
    )

# Logger name -> LoggerProvider created for it by setup_logger(). setup_logger()
# uses it to make repeated calls for the same name no-ops; shutdown() iterates
# it to shut every provider down.
_LOGGER_PROVIDER: dict[str, LoggerProvider] = {}
# Name of the environment variable whose value is the path to a JSON file that
# maps logger names to OTEL resource attributes (read by setup_logger()).
TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON = "TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class TeeLogExporter(LogRecordExporter):
    """
    Exporter that writes every batch to multiple exporters.

    Every child exporter is always invoked, even if an earlier one fails. The
    combined export result is ``SUCCESS`` only when every child reports
    ``SUCCESS``. If a child raises, the remaining children still run and the
    first exception is re-raised afterwards.
    """

    def __init__(
        self,
        exporters: List[LogRecordExporter],
    ) -> None:
        self._exporters = exporters

    def export(self, batch: Sequence[ReadableLogRecord]) -> LogRecordExportResult:
        """
        Forward ``batch`` to every child exporter.

        Returns:
            ``LogRecordExportResult.SUCCESS`` if every child succeeded, otherwise
            ``LogRecordExportResult.FAILURE``.

        Raises:
            Exception: the first exception raised by a child, re-raised after all
                children have been given the batch.
        """
        result = LogRecordExportResult.SUCCESS
        first_error: Exception | None = None
        for exporter in self._exporters:
            try:
                if exporter.export(batch) != LogRecordExportResult.SUCCESS:
                    result = LogRecordExportResult.FAILURE
            except Exception as err:
                # Keep going so one broken sink doesn't starve the others.
                result = LogRecordExportResult.FAILURE
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error
        return result

    def shutdown(self) -> None:
        """
        Shut down every child exporter.

        Raises:
            Exception: the first exception raised by a child, re-raised after all
                children have been shut down.
        """
        first_error: Exception | None = None
        for exporter in self._exporters:
            try:
                exporter.shutdown()
            except Exception as err:
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def setup_logger(name: str) -> None:
    """
    Attach an OpenTelemetry logging handler to the logger called ``name``.

    Does nothing unless the ``TORCHFT_USE_OTEL`` environment variable is set to
    something other than a "false" spelling ("", "0", "false", "no", "off",
    case-insensitive). Repeated calls for the same ``name`` are no-ops.

    If the env var named by ``TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON`` is set, its
    value is read as a path to a JSON object keyed by logger name; the entry for
    ``name`` becomes the OTEL ``Resource`` attributes.

    Records are exported to the console and via OTLP/HTTP (5 second timeout)
    through a ``BatchLogRecordProcessor``.

    Args:
        name: logger name passed to ``logging.getLogger``.

    Raises:
        KeyError: if the resource attributes file has no entry for ``name``.
    """
    disabled_values = ("", "0", "false", "no", "off")
    use_otel = os.environ.get("TORCHFT_USE_OTEL", "false")
    # Previously only the exact string "false" disabled OTEL, so e.g. "False"
    # or "0" silently enabled it.
    if use_otel.strip().lower() in disabled_values:
        return

    if name in _LOGGER_PROVIDER:
        return

    attributes_path = os.environ.get(TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON)

    if attributes_path is not None:
        # JSON is UTF-8 by spec; don't depend on the locale's default encoding.
        with open(attributes_path, encoding="utf-8") as f:
            all_attributes = json.load(f)
        try:
            attributes = all_attributes[name]
        except KeyError:
            raise KeyError(
                f"no OTEL resource attributes for logger {name!r} in {attributes_path}"
            ) from None
        resource = Resource.create(attributes=attributes)
    else:
        resource = Resource.create()

    logger_provider = LoggerProvider(resource=resource)
    # NOTE(review): opentelemetry may only honour the first global
    # set_logger_provider() call; the handler below is given this provider
    # explicitly, so later names still export through their own provider.
    set_logger_provider(logger_provider)

    exporter = TeeLogExporter(
        exporters=[
            ConsoleLogRecordExporter(),
            OTLPLogExporter(
                timeout=5,
            ),
        ],
    )
    logger_provider.add_log_record_processor(BatchLogRecordProcessor(exporter))
    handler = LoggingHandler(level=logging.NOTSET, logger_provider=logger_provider)

    # Attach OTLP handler to otel logger
    logging.getLogger(name).addHandler(handler)

    _LOGGER_PROVIDER[name] = logger_provider
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def shutdown() -> None:
    """Call ``shutdown()`` on every LoggerProvider registered by setup_logger()."""
    providers = _LOGGER_PROVIDER.values()
    for provider in providers:
        provider.shutdown()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# Example usage of the logger
|
|
108
|
+
def main() -> None:
    """
    Example usage: once per second, log a pair of lines from three loggers until
    interrupted with Ctrl-C, then shut the OTEL providers down and exit.

    When OTEL is enabled, setup_logger() attaches the OTEL handler only to the
    ``torchft_test`` logger, so only its records are exported.
    """
    logging.getLogger().setLevel(logging.INFO)
    setup_logger("torchft_test")

    # getLogger returns the same object for a given name, so build the list once.
    loggers = [
        logging.getLogger("torchft_test"),
        logging.getLogger("myapp.area1"),
        logging.getLogger(),
    ]

    try:
        while True:
            time.sleep(1)
            for i, logger in enumerate(loggers):
                # INFO is emitted; the DEBUG line below is filtered out by the
                # INFO level set on the root logger above.
                logger.info(
                    "Quick zephyrs blow, vexing daft Jim.",
                    extra={
                        "test_attr": f"value{i}",
                    },
                )
                logger.debug("Jackdaws love my big sphinx of quartz.")
    except KeyboardInterrupt:
        # The loop only ends on Ctrl-C; previously the message below was
        # unreachable and batched records were never flushed.
        pass
    finally:
        shutdown()

    print("Example done; exiting...")
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
if __name__ == "__main__":
    # Run the logging example when this module is executed as a script.
    main()
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Parameter Servers
|
|
9
|
+
==================
|
|
10
|
+
|
|
11
|
+
This module provides a prototype implementation of a fault-tolerant parameter server built on the reconfigurable ProcessGroups.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import json
|
|
15
|
+
import logging
|
|
16
|
+
import socket
|
|
17
|
+
import threading
|
|
18
|
+
import urllib.request
|
|
19
|
+
import uuid
|
|
20
|
+
from abc import ABC, abstractmethod
|
|
21
|
+
from http.server import BaseHTTPRequestHandler
|
|
22
|
+
|
|
23
|
+
from torch.distributed import TCPStore
|
|
24
|
+
|
|
25
|
+
from torchft.http import _IPv6HTTPServer
|
|
26
|
+
from torchft.process_group import ProcessGroup
|
|
27
|
+
|
|
28
|
+
logger: logging.Logger = logging.getLogger(__name__)


class ParameterServer(ABC):
    """
    This implements a threaded parameter server using the torchft reconfigurable
    ProcessGroups.

    A client calls :meth:`new_session` with the URL from :meth:`address`. The
    HTTP handler replies with a new session id and a TCPStore address scoped to
    that session, then keeps its handler thread to run :meth:`forward` for the
    session. The server is rank 0 and the client is rank 1.
    """

    def __init__(self, port: int, store_port: int = 0) -> None:
        """
        Create a new ParameterServer.

        Starts a master TCPStore and an HTTP server, with the HTTP server's
        ``serve_forever`` loop running in a daemon thread.

        Args:
            port: the port to bind the HTTP server to.
            store_port: the port to bind the TCPStore server to.
        """
        self.store = TCPStore(
            host_name="0.0.0.0",
            port=store_port,
            is_master=True,
            wait_for_workers=False,
        )

        # Captured by the nested handler, where ``self`` is the request handler.
        ps = self

        class RequestHandler(BaseHTTPRequestHandler):
            def do_GET(self):
                if self.path != "/new_session":
                    self.send_response(400)
                    self.send_header("Content-type", "text/plain")
                    self.end_headers()
                    # BaseHTTPRequestHandler has no ``err`` method; send the
                    # message as the plain-text response body.
                    self.wfile.write(f"invalid path, got {self.path}".encode())
                    return

                try:
                    self.send_response(200)
                    self.send_header(
                        "Content-type", "application/json"
                    )  # TODO: correct mime type
                    self.end_headers()

                    session_id = str(uuid.uuid4())

                    store_addr = (
                        f"{socket.gethostname()}:{ps.store.port}/session/{session_id}"
                    )

                    logger.info(f"creating new session {session_id}")

                    data = (
                        json.dumps(
                            {
                                "session_id": session_id,
                                "store_addr": store_addr,
                            }
                        )
                        + "\n"
                    )
                    data = data.encode()

                    self.wfile.write(data)

                    # close the connection up front so client will know json is
                    # complete
                    self.finish()
                    self.connection.close()

                    # hijack thread for the session
                    ps._handle_session(session_id, store_addr)
                except Exception:
                    logger.exception(
                        f"got exception in request handler for {self.path}"
                    )
                    raise

        server_address = ("", port)
        self._server = _IPv6HTTPServer(server_address, RequestHandler)
        self._server.daemon_threads = True
        logger.info(f"Started ParameterServer on {self.address()}...")

        self._thread = threading.Thread(
            target=self._serve,
            args=(),
            daemon=True,
        )
        self._thread.start()

    def address(self) -> str:
        """
        Returns the HTTP address to create a new session on this server.

        Format: http://host:port/new_session

        Returns:
            an HTTP address
        """
        port = self._server.socket.getsockname()[1]
        return f"http://{socket.gethostname()}:{port}/new_session"

    def _serve(self) -> None:
        """Run the HTTP server loop, logging (not re-raising) any failure."""
        try:
            self._server.serve_forever()
        except Exception:
            logger.exception("got exception in parameter server")

    @classmethod
    @abstractmethod
    def new_process_group(cls) -> ProcessGroup:
        """
        Create a new non-configured ProcessGroup for the ParameterServer to
        configure when setting up server and client connections.

        Must be implemented by subclasses.

        Returns:
            a new ProcessGroup
        """
        ...

    @classmethod
    def new_session(cls, address: str) -> ProcessGroup:
        """
        Creates a new session on the parameter server and returns a ProcessGroup
        configured for that server.

        Client is rank 1, server is rank 0.

        Args:
            address: the URL returned by :meth:`address` on the server.
        """
        with urllib.request.urlopen(address) as f:
            data = json.load(f)

        session_id = data["session_id"]
        store_addr = data["store_addr"]

        logger.info(f"connecting to session {session_id} at {store_addr}")

        pg = cls.new_process_group()
        # client is always rank 1
        pg.configure(store_addr, replica_id="0", rank=1, world_size=2)

        return pg

    def _handle_session(self, session_id: str, store_addr: str) -> None:
        """Configure this session's server-side ProcessGroup and run ``forward``."""
        pg = self.new_process_group()
        # parameter server is always rank 0
        pg.configure(store_addr, replica_id="0", rank=0, world_size=2)

        self.forward(session_id, pg)

    @abstractmethod
    def forward(self, session_id: str, pg: ProcessGroup) -> None:
        """
        This method will be called once per session in a dedicated thread. To
        support multiple operations on a single session you should put a
        for-loop in your forward implementation.

        If an error occurs, the process group will be freed and the client will
        have to create a new session.

        The server rank is 0 and the client rank is 1.

        Must be implemented by subclasses.

        Args:
            session_id: a unique uuid for this session
            pg: the ProcessGroup that's configured for the client.
        """
        ...
|