torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import queue
8
+ import time
9
+ from datetime import timedelta
10
+ from multiprocessing.connection import Connection
11
+ from typing import Union
12
+
13
+ import torch.multiprocessing as mp
14
+
15
+
16
+ class _MonitoredPipe:
17
+ def __init__(self, pipe: "Connection[object, object]") -> None: # type: ignore
18
+ self._pipe = pipe
19
+
20
+ def send(self, obj: object) -> None:
21
+ self._pipe.send(obj)
22
+
23
+ def recv(self, timeout: Union[float, timedelta]) -> object:
24
+ if isinstance(timeout, timedelta):
25
+ timeout = timeout.total_seconds()
26
+ if self._pipe.poll(timeout):
27
+ out = self._pipe.recv()
28
+ if isinstance(out, Exception):
29
+ raise out
30
+ return out
31
+ else:
32
+ raise TimeoutError(f"pipe.recv() timed out after {timeout} seconds")
33
+
34
+ def close(self) -> None:
35
+ self._pipe.close()
36
+
37
+ def closed(self) -> bool:
38
+ return self._pipe.closed
@@ -0,0 +1,135 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Multiprocessing Dummy Context
9
+ =============================
10
+
11
+ This module provides a context-like interface for multiprocessing.dummy,
12
+ which is a wrapper around the threading module that provides a multiprocessing-like
13
+ interface but uses threads instead of processes.
14
+
15
+ This allows code that uses multiprocessing.get_context() to work with
16
+ multiprocessing.dummy by providing a compatible interface.
17
+ """
18
+
19
+ import multiprocessing.dummy as mp
20
+ import threading
21
+ from typing import Callable, Iterable, Mapping
22
+
23
+
24
+ class DummyContext:
25
+ """
26
+ A context-like class for multiprocessing.dummy that mimics the interface
27
+ of a context returned by multiprocessing.get_context().
28
+ """
29
+
30
+ def __init__(self, method: object = None) -> None:
31
+ """
32
+ Initialize the dummy context.
33
+
34
+ Args:
35
+ method: Ignored, only for compatibility with multiprocessing.get_context()
36
+ """
37
+ pass
38
+
39
+ def Process(
40
+ self,
41
+ group: object = None,
42
+ target: Callable[..., object] | None = None,
43
+ name: str | None = None,
44
+ args: Iterable[object] = (),
45
+ kwargs: Mapping[str, object] = {},
46
+ daemon: bool | None = None,
47
+ ) -> mp.DummyProcess:
48
+ """
49
+ Create a Process using multiprocessing.dummy.Process.
50
+ """
51
+ return mp.Process(
52
+ group=group, target=target, name=name, args=args, kwargs=kwargs
53
+ )
54
+
55
+ def Pipe(
56
+ self, duplex: bool = True
57
+ ) -> tuple[mp.connection.Connection, mp.connection.Connection]:
58
+ """
59
+ Create a Pipe using multiprocessing.dummy.Pipe.
60
+ """
61
+ return mp.Pipe(duplex)
62
+
63
+ def Queue(self, maxsize: int = 0) -> mp.Queue:
64
+ """
65
+ Create a Queue using multiprocessing.dummy.Queue.
66
+ """
67
+ return mp.Queue(maxsize)
68
+
69
+ def Event(self) -> threading.Event:
70
+ """
71
+ Create an Event using multiprocessing.dummy.Event.
72
+ """
73
+ return mp.Event()
74
+
75
+ def Lock(self) -> threading.Lock:
76
+ """
77
+ Create a Lock using multiprocessing.dummy.Lock.
78
+ """
79
+ return mp.Lock()
80
+
81
+ def RLock(self) -> threading.RLock:
82
+ """
83
+ Create an RLock using multiprocessing.dummy.RLock.
84
+ """
85
+ return mp.RLock()
86
+
87
+ def Semaphore(self, value: int = 1) -> threading.Semaphore:
88
+ """
89
+ Create a Semaphore using multiprocessing.dummy.Semaphore.
90
+ """
91
+ return mp.Semaphore(value)
92
+
93
+ def BoundedSemaphore(self, value: int = 1) -> threading.BoundedSemaphore:
94
+ """
95
+ Create a BoundedSemaphore using multiprocessing.dummy.BoundedSemaphore.
96
+ """
97
+ return mp.BoundedSemaphore(value)
98
+
99
+ def Condition(
100
+ self, lock: threading.Lock | threading.RLock | None = None
101
+ ) -> threading.Condition:
102
+ """
103
+ Create a Condition using multiprocessing.dummy.Condition.
104
+ """
105
+ return mp.Condition(lock)
106
+
107
+ def Manager(self) -> object:
108
+ """
109
+ Create a Manager using multiprocessing.dummy.Manager.
110
+ """
111
+ return mp.Manager()
112
+
113
+
114
def get_context(method: object = None) -> DummyContext:
    """
    Build the multiprocessing.dummy stand-in for multiprocessing.get_context().

    The returned DummyContext can be installed on multiprocessing.dummy so
    that code calling ``get_context()`` keeps working with threads::

        import multiprocessing.dummy as mp
        from torchft.multiprocessing_dummy_context import get_context
        mp.get_context = get_context

    Args:
        method: Ignored, only for compatibility with multiprocessing.get_context()

    Returns:
        A DummyContext instance
    """
    context = DummyContext(method)
    return context
@@ -0,0 +1,58 @@
1
+ from multiprocessing.connection import Connection
2
+ from unittest import TestCase
3
+
4
+ import torch.multiprocessing as mp
5
+
6
+ from torchft.multiprocessing import _MonitoredPipe
7
+
8
+
9
def pipe_get(q: "Connection[object, object]") -> None:  # type: ignore
    """Child-process target: receive a single object from ``q`` and drop it."""
    _ = q.recv()
11
+
12
+
13
def pipe_put(q: "Connection[object, object]") -> None:  # type: ignore
    """Child-process target: wait for one object on ``q``, then reply with 1."""
    _ = q.recv()
    reply = 1
    q.send(reply)
16
+
17
+
18
class MultiprocessingTest(TestCase):
    """Exercises _MonitoredPipe against a forked child on the other pipe end."""

    def test_monitored_queue_put(self) -> None:
        """Sending keeps working until the peer goes away, then raises."""
        ctx = mp.get_context("fork")
        local, remote = ctx.Pipe()
        p = ctx.Process(target=pipe_get, args=(remote,), daemon=True)
        p.start()
        # Drop the parent's reference to the child's end of the pipe.
        del remote

        mq = _MonitoredPipe(local)
        mq.send(1)
        with self.assertRaisesRegex(
            (ConnectionResetError, BrokenPipeError),
            "(Connection reset by peer|Broken pipe)",
        ):
            # The child reads a single message and returns; keep sending
            # until the broken connection surfaces as an error.
            while True:
                mq.send(1)

        mq.close()
        # unittest assertion rather than a bare assert, which is stripped
        # when Python runs with -O.
        self.assertTrue(mq.closed())
        # Reap the child; bounded so a stuck child cannot hang the suite.
        p.join(timeout=10)

    def test_monitored_queue_get(self) -> None:
        """recv times out when idle, returns data, then reports EOF."""
        ctx = mp.get_context("fork")
        local, remote = ctx.Pipe()
        p = ctx.Process(target=pipe_put, args=(remote,), daemon=True)
        p.start()
        # Drop the parent's reference to the child's end of the pipe.
        del remote

        mq = _MonitoredPipe(local)

        with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
            mq.recv(timeout=0.0)

        # continue: unblock the child, which replies with 1 and returns
        mq.send(1)

        self.assertEqual(mq.recv(timeout=10), 1)
        with self.assertRaises(EOFError):
            mq.recv(timeout=10)

        mq.close()
        self.assertTrue(mq.closed())
        # Reap the child; bounded so a stuck child cannot hang the suite.
        p.join(timeout=10)
torchft/optim.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Optimizers
9
+ ============
10
+
11
+ This module implements an optimizer wrapper that works with the Manager to provide fault tolerance.
12
+
13
+ """
14
+
15
+ from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING
16
+
17
+ import torch
18
+ from torch.optim import Optimizer
19
+
20
+ if TYPE_CHECKING:
21
+ from torchft.manager import Manager
22
+
23
+
24
class OptimizerWrapper(Optimizer):
    """
    Fault tolerant facade over an existing optimizer, driven by a Manager.

    Call zero_grad() at the start of the forwards pass and step() at the end
    of the backwards pass. zero_grad() starts a quorum on the manager, and
    step() only applies the wrapped optimizer's update when the manager
    reports that the step should be committed; otherwise the gradients are
    ignored.

    All other operations are delegated to the wrapped optimizer.
    """

    def __init__(self, manager: "Manager", optim: Optimizer) -> None:
        # NOTE: Optimizer.__init__ is not invoked here; state and parameter
        # groups are served from the wrapped optimizer via the properties below.
        self.manager = manager
        self.optim = optim

    @property
    def param_groups(self) -> List[Dict[str, Any]]:
        """Parameter groups of the wrapped optimizer."""
        return self.optim.param_groups

    @property
    def state(self) -> Mapping[torch.Tensor, object]:
        """Per-parameter state of the wrapped optimizer."""
        return self.optim.state

    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        """Add ``param_group`` to the wrapped optimizer."""
        self.optim.add_param_group(param_group)

    def state_dict(self) -> Dict[str, Any]:
        """Return the wrapped optimizer's state dict."""
        return self.optim.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load ``state_dict`` into the wrapped optimizer."""
        self.optim.load_state_dict(state_dict)

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Start a quorum on the manager, then clear the gradients."""
        self.manager.start_quorum()
        self.optim.zero_grad(set_to_none)

    def step(self, closure: Optional[object] = None) -> None:
        """
        Apply the wrapped optimizer's step if the manager commits it.

        Args:
            closure: must be None; closures are not supported.
        """
        assert closure is None, "optimizers that use closures are not supported"
        if not self.manager.should_commit():
            return
        self.optim.step()
torchft/optim_test.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from unittest import TestCase
8
+ from unittest.mock import create_autospec, MagicMock
9
+
10
+ import torch
11
+ from torch.nn import Linear
12
+ from torch.optim import AdamW
13
+
14
+ from torchft.manager import Manager
15
+ from torchft.optim import OptimizerWrapper
16
+
17
+
18
class TestOptim(TestCase):
    """Unit test for OptimizerWrapper against an autospecced Manager."""

    def test_optimizer_wrapper(self) -> None:
        """Delegation, quorum start and commit gating all behave as expected."""
        manager = create_autospec(Manager)

        model = Linear(3, 4)
        wrapped = OptimizerWrapper(manager, AdamW(model.parameters()))

        extra_group = {
            "params": [],
            "lr": 1e-4,
        }
        wrapped.add_param_group(extra_group)

        # the state dict must round-trip through the wrapper
        snapshot = wrapped.state_dict()
        wrapped.load_state_dict(snapshot)

        wrapped.zero_grad()
        self.assertEqual(manager.start_quorum.call_count, 1)

        inputs = torch.rand(3)
        loss = model(inputs).sum()
        loss.backward()

        # one committed step followed by one rejected step
        manager.should_commit.return_value = True
        wrapped.step()
        manager.should_commit.return_value = False
        wrapped.step()

        groups = wrapped.param_groups
        self.assertEqual(len(groups), 2)
        self.assertEqual(groups[1]["lr"], 1e-4)
        self.assertEqual(groups[1]["params"], [])
        self.assertEqual(len(wrapped.state), len(list(model.parameters())))

        self.assertEqual(manager.should_commit.call_count, 2)
torchft/otel.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ import time
11
+ from typing import Any, List, Sequence, TYPE_CHECKING
12
+
13
+ from opentelemetry._logs import set_logger_provider
14
+
15
+ from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
16
+ from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
17
+ from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
18
+ from opentelemetry.sdk.resources import Resource
19
+
20
+ # These types are available in opentelemetry-sdk but Pyre's type stubs
21
+ # don't include them. We import them at runtime and provide type aliases for
22
+ # static type checking.
23
+ if TYPE_CHECKING:
24
+ # pyre-fixme[33]: Aliasing to Any is prohibited. opentelemetry-sdk lacks type stubs.
25
+ ReadableLogRecord = Any
26
+ # pyre-fixme[33]: Aliasing to Any is prohibited. opentelemetry-sdk lacks type stubs.
27
+ LogRecordExporter = Any
28
+ # pyre-fixme[33]: Aliasing to Any is prohibited. opentelemetry-sdk lacks type stubs.
29
+ LogRecordExportResult = Any
30
+ # pyre-fixme[33]: Aliasing to Any is prohibited. opentelemetry-sdk lacks type stubs.
31
+ ConsoleLogRecordExporter = Any
32
+ else:
33
+ from opentelemetry.sdk._logs import ReadableLogRecord
34
+ from opentelemetry.sdk._logs.export import (
35
+ ConsoleLogRecordExporter,
36
+ LogRecordExporter,
37
+ LogRecordExportResult,
38
+ )
39
+
40
+ _LOGGER_PROVIDER: dict[str, LoggerProvider] = {}
41
+ # Path to the file containing OTEL resource attributes
42
+ TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON = "TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON"
43
+
44
+
45
class TeeLogExporter(LogRecordExporter):
    """Exporter that writes to multiple exporters."""

    def __init__(
        self,
        exporters: List[LogRecordExporter],
    ) -> None:
        """
        Args:
            exporters: the exporters every batch is forwarded to, in order.
        """
        self._exporters = exporters

    def export(self, batch: Sequence[ReadableLogRecord]) -> LogRecordExportResult:
        """
        Forward ``batch`` to every wrapped exporter.

        Every exporter is attempted even when an earlier one reports a
        failure.

        Returns:
            FAILURE if any wrapped exporter explicitly reported FAILURE,
            otherwise SUCCESS.
        """
        result = LogRecordExportResult.SUCCESS
        for exporter in self._exporters:
            # Previously the per-exporter result was discarded, so a failed
            # export was always reported upstream as a success.
            if exporter.export(batch) == LogRecordExportResult.FAILURE:
                result = LogRecordExportResult.FAILURE
        return result

    def shutdown(self) -> None:
        """Shut down every wrapped exporter, in order."""
        for exporter in self._exporters:
            exporter.shutdown()
62
+
63
+
64
def setup_logger(name: str) -> None:
    """
    Attach an OpenTelemetry logging handler to the logger called ``name``.

    Does nothing when the TORCHFT_USE_OTEL environment variable is unset or
    equal to "false", or when ``name`` has already been set up. When the
    TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON environment variable is set, it is
    read as a path to a JSON file and the entry keyed by ``name`` supplies the
    resource attributes.

    Args:
        name: name of the logger to instrument.
    """
    otel_disabled = os.environ.get("TORCHFT_USE_OTEL", "false") == "false"
    if otel_disabled or name in _LOGGER_PROVIDER:
        return

    attributes_path = os.environ.get(TORCHFT_OTEL_RESOURCE_ATTRIBUTES_JSON)
    if attributes_path is None:
        resource = Resource.create()
    else:
        with open(attributes_path) as f:
            all_attributes = json.load(f)
        resource = Resource.create(attributes=all_attributes[name])

    provider = LoggerProvider(resource=resource)
    set_logger_provider(provider)

    # Every record goes both to the console and to the OTLP endpoint.
    console_exporter = ConsoleLogRecordExporter()
    otlp_exporter = OTLPLogExporter(timeout=5)
    tee = TeeLogExporter(exporters=[console_exporter, otlp_exporter])
    provider.add_log_record_processor(BatchLogRecordProcessor(tee))

    # Attach OTLP handler to otel logger
    otel_handler = LoggingHandler(level=logging.NOTSET, logger_provider=provider)
    logging.getLogger(name).addHandler(otel_handler)

    _LOGGER_PROVIDER[name] = provider
100
+
101
+
102
def shutdown() -> None:
    """Call shutdown() on every logger provider stored in _LOGGER_PROVIDER."""
    for provider in _LOGGER_PROVIDER.values():
        provider.shutdown()
105
+
106
+
107
+ # Example usage of the logger
108
def main() -> None:
    """
    Example: emit one INFO and one DEBUG record per second on three loggers.

    Runs until interrupted with Ctrl-C, then shuts down the logger providers
    and prints a closing message. Previously the loop had no exit, so the
    closing message was unreachable and shutdown() was never called.
    """
    logging.getLogger().setLevel(logging.INFO)
    setup_logger("torchft_test")

    # logging.getLogger returns the same object for a given name, so the
    # loggers can be looked up once instead of on every iteration.
    loggers = [
        logging.getLogger("torchft_test"),
        logging.getLogger("myapp.area1"),
        logging.getLogger(),
    ]

    try:
        while True:
            time.sleep(1)

            for i, logger in enumerate(loggers):
                # only this should be picked up by OTEL when using otel logger
                logger.info(
                    "Quick zephyrs blow, vexing daft Jim.",
                    extra={
                        "test_attr": f"value{i}",
                    },
                )
                logger.debug("Jackdaws love my big sphinx of quartz.")
    except KeyboardInterrupt:
        # Ctrl-C is the intended way to stop the example.
        pass
    finally:
        shutdown()

    print("Example done; exiting...")
131
+
132
+
133
+ if __name__ == "__main__":
134
+ main()
@@ -0,0 +1,195 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Parameter Servers
9
+ ==================
10
+
11
+ This module provides a prototype implementation of a fault tolerant parameter server built on the reconfigurable ProcessGroups.
12
+ """
13
+
14
+ import json
15
+ import logging
16
+ import socket
17
+ import threading
18
+ import urllib.request
19
+ import uuid
20
+ from abc import ABC, abstractmethod
21
+ from http.server import BaseHTTPRequestHandler
22
+
23
+ from torch.distributed import TCPStore
24
+
25
+ from torchft.http import _IPv6HTTPServer
26
+ from torchft.process_group import ProcessGroup
27
+
28
+ logger: logging.Logger = logging.getLogger(__name__)
29
+
30
+
31
class ParameterServer(ABC):
    """
    This implements a threaded parameter server using the torchft reconfigurable
    ProcessGroups.

    A client requests ``/new_session`` over HTTP and receives a JSON document
    with a fresh session id and a store address. Both sides then configure a
    two-rank ProcessGroup on that address (server rank 0, client rank 1) and
    the server runs ``forward`` for the session.
    """

    def __init__(self, port: int, store_port: int = 0) -> None:
        """
        Create a new ParameterServer.

        Args:
            port: the port to bind the HTTP server to.
            store_port: the port to bind the TCPStore server to.
        """
        self.store = TCPStore(
            host_name="0.0.0.0",
            port=store_port,
            is_master=True,
            wait_for_workers=False,
        )

        # Inside RequestHandler ``self`` is the handler, so keep a separate
        # name for the parameter server.
        ps = self

        class RequestHandler(BaseHTTPRequestHandler):
            def do_GET(self) -> None:
                if self.path != "/new_session":
                    self.send_response(400)
                    self.send_header("Content-type", "text/plain")
                    self.end_headers()
                    # BaseHTTPRequestHandler has no err() helper, so write
                    # the error body to the response stream directly.
                    self.wfile.write(f"invalid path, got {self.path}".encode())
                    return

                try:
                    self.send_response(200)
                    self.send_header(
                        "Content-type", "application/json"
                    )  # TODO: correct mime type
                    self.end_headers()

                    session_id = str(uuid.uuid4())

                    store_addr = (
                        f"{socket.gethostname()}:{ps.store.port}/session/{session_id}"
                    )

                    logger.info(f"creating new session {session_id}")

                    data = (
                        json.dumps(
                            {
                                "session_id": session_id,
                                "store_addr": store_addr,
                            }
                        )
                        + "\n"
                    )
                    data = data.encode()

                    self.wfile.write(data)

                    # close the connection up front so client will know json is
                    # complete
                    self.finish()
                    self.connection.close()

                    # hijack thread for the session
                    ps._handle_session(session_id, store_addr)
                except Exception:
                    logger.exception(
                        f"got exception in request handler for {self.path}"
                    )
                    raise

        server_address = ("", port)
        self._server = _IPv6HTTPServer(server_address, RequestHandler)
        self._server.daemon_threads = True
        logger.info(f"Started ParameterServer on {self.address()}...")

        self._thread = threading.Thread(
            target=self._serve,
            args=(),
            daemon=True,
        )
        self._thread.start()

    def address(self) -> str:
        """
        Returns the HTTP address to create a new session on this server.

        Format: http://host:port/new_session

        Returns:
            an HTTP address
        """
        # Read the port from the bound socket rather than the constructor
        # argument.
        port = self._server.socket.getsockname()[1]
        return f"http://{socket.gethostname()}:{port}/new_session"

    def _serve(self) -> None:
        """Run the HTTP server loop; log instead of propagating any exception."""
        try:
            self._server.serve_forever()
        except Exception:
            logger.exception("got exception in parameter server")

    @classmethod
    @abstractmethod
    def new_process_group(cls) -> ProcessGroup:
        """
        Create a new non-configured ProcessGroup for the ParameterServer to
        configure when setting up server and client connections.

        Must be implemented by subclasses.

        Returns:
            a new ProcessGroup
        """
        ...

    @classmethod
    def new_session(cls, address: str) -> ProcessGroup:
        """
        Creates a new session on the parameter server and returns a ProcessGroup
        configured for that server.

        Client is rank 1, server is rank 0.

        Args:
            address: the HTTP address of the server, as returned by ``address()``.

        Returns:
            a ProcessGroup configured as rank 1 of the new session.
        """
        with urllib.request.urlopen(address) as f:
            data = json.load(f)

        session_id = data["session_id"]
        store_addr = data["store_addr"]

        logger.info(f"connecting to session {session_id} at {store_addr}")

        pg = cls.new_process_group()
        # client is always rank 1
        pg.configure(store_addr, replica_id="0", rank=1, world_size=2)

        return pg

    def _handle_session(self, session_id: str, store_addr: str) -> None:
        """
        Configure the server side of a session and run ``forward`` on it.

        Args:
            session_id: the id generated for this session.
            store_addr: the store address that was sent to the client.
        """
        pg = self.new_process_group()
        # parameter server is always rank 0
        pg.configure(store_addr, replica_id="0", rank=0, world_size=2)

        self.forward(session_id, pg)

    @abstractmethod
    def forward(self, session_id: str, pg: ProcessGroup) -> None:
        """
        This method will be called once per session in a dedicated thread. To
        support multiple operations on a single session you should put a
        for-loop in your forward implementation.

        If an error occurs, the process group will be freed and the client will
        have to create a new session.

        The server rank is 0 and the client rank is 1.

        Must be implemented by subclasses.

        Args:
            session_id: a unique uuid for this session
            pg: the ProcessGroup that's configured for the client.
        """
        ...