wandb 0.20.1__py3-none-win32.whl → 0.20.2rc20250616__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. wandb/__init__.py +3 -6
  2. wandb/__init__.pyi +1 -1
  3. wandb/analytics/sentry.py +2 -2
  4. wandb/apis/importers/internals/internal.py +0 -3
  5. wandb/apis/public/api.py +2 -2
  6. wandb/apis/public/registries/{utils.py → _utils.py} +12 -12
  7. wandb/apis/public/registries/registries_search.py +2 -2
  8. wandb/apis/public/registries/registry.py +19 -18
  9. wandb/bin/gpu_stats.exe +0 -0
  10. wandb/bin/wandb-core +0 -0
  11. wandb/cli/beta.py +1 -7
  12. wandb/cli/cli.py +0 -30
  13. wandb/env.py +0 -6
  14. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  15. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  16. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  17. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  18. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  19. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  20. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  21. wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
  22. wandb/sdk/artifacts/storage_handlers/s3_handler.py +42 -1
  23. wandb/sdk/backend/backend.py +1 -1
  24. wandb/sdk/internal/handler.py +1 -69
  25. wandb/sdk/lib/printer.py +6 -7
  26. wandb/sdk/lib/progress.py +1 -3
  27. wandb/sdk/lib/service/ipc_support.py +13 -0
  28. wandb/sdk/lib/{service_connection.py → service/service_connection.py} +20 -56
  29. wandb/sdk/lib/service/service_port_file.py +105 -0
  30. wandb/sdk/lib/service/service_process.py +111 -0
  31. wandb/sdk/lib/service/service_token.py +164 -0
  32. wandb/sdk/lib/sock_client.py +8 -12
  33. wandb/sdk/wandb_init.py +0 -3
  34. wandb/sdk/wandb_require.py +9 -20
  35. wandb/sdk/wandb_run.py +0 -24
  36. wandb/sdk/wandb_settings.py +0 -9
  37. wandb/sdk/wandb_setup.py +2 -13
  38. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/METADATA +1 -3
  39. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/RECORD +42 -68
  40. wandb/sdk/internal/flow_control.py +0 -263
  41. wandb/sdk/internal/internal.py +0 -401
  42. wandb/sdk/internal/internal_util.py +0 -97
  43. wandb/sdk/internal/system/__init__.py +0 -0
  44. wandb/sdk/internal/system/assets/__init__.py +0 -25
  45. wandb/sdk/internal/system/assets/aggregators.py +0 -31
  46. wandb/sdk/internal/system/assets/asset_registry.py +0 -20
  47. wandb/sdk/internal/system/assets/cpu.py +0 -163
  48. wandb/sdk/internal/system/assets/disk.py +0 -210
  49. wandb/sdk/internal/system/assets/gpu.py +0 -416
  50. wandb/sdk/internal/system/assets/gpu_amd.py +0 -233
  51. wandb/sdk/internal/system/assets/interfaces.py +0 -205
  52. wandb/sdk/internal/system/assets/ipu.py +0 -177
  53. wandb/sdk/internal/system/assets/memory.py +0 -166
  54. wandb/sdk/internal/system/assets/network.py +0 -125
  55. wandb/sdk/internal/system/assets/open_metrics.py +0 -293
  56. wandb/sdk/internal/system/assets/tpu.py +0 -154
  57. wandb/sdk/internal/system/assets/trainium.py +0 -393
  58. wandb/sdk/internal/system/env_probe_helpers.py +0 -13
  59. wandb/sdk/internal/system/system_info.py +0 -248
  60. wandb/sdk/internal/system/system_monitor.py +0 -224
  61. wandb/sdk/internal/writer.py +0 -204
  62. wandb/sdk/lib/service_token.py +0 -93
  63. wandb/sdk/service/__init__.py +0 -0
  64. wandb/sdk/service/_startup_debug.py +0 -22
  65. wandb/sdk/service/port_file.py +0 -53
  66. wandb/sdk/service/server.py +0 -107
  67. wandb/sdk/service/server_sock.py +0 -286
  68. wandb/sdk/service/service.py +0 -252
  69. wandb/sdk/service/streams.py +0 -425
  70. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/WHEEL +0 -0
  71. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/entry_points.txt +0 -0
  72. {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/internal/system/system_monitor.py
@@ -1,224 +0,0 @@
- import datetime
- import logging
- import queue
- import threading
- from collections import defaultdict, deque
- from typing import TYPE_CHECKING, Deque, Dict, List, Optional, Tuple
-
- from .assets.asset_registry import asset_registry
- from .assets.interfaces import Asset, Interface
- from .assets.open_metrics import OpenMetrics
- from .system_info import SystemInfo
-
- if TYPE_CHECKING:
-     from wandb.proto.wandb_telemetry_pb2 import TelemetryRecord
-     from wandb.sdk.interface.interface import FilesDict
-     from wandb.sdk.internal.settings_static import SettingsStatic
-
-
- logger = logging.getLogger(__name__)
-
-
- class AssetInterface:
-     def __init__(self) -> None:
-         self.metrics_queue: queue.Queue[dict] = queue.Queue()
-         self.telemetry_queue: queue.Queue[TelemetryRecord] = queue.Queue()
-
-     def publish_stats(self, stats: dict) -> None:
-         self.metrics_queue.put(stats)
-
-     def _publish_telemetry(self, telemetry: "TelemetryRecord") -> None:
-         self.telemetry_queue.put(telemetry)
-
-     def publish_files(self, files_dict: "FilesDict") -> None:
-         pass
-
-
- class SystemMonitor:
-     # SystemMonitor is responsible for managing system metrics data.
-
-     # if joining assets, wait for publishing_interval times this many seconds
-     PUBLISHING_INTERVAL_DELAY_FACTOR = 2
-
-     def __init__(
-         self,
-         settings: "SettingsStatic",
-         interface: "Interface",
-     ) -> None:
-         self._shutdown_event: threading.Event = threading.Event()
-         self._process: Optional[threading.Thread] = None
-
-         self.settings = settings
-
-         # settings._stats_join_assets controls whether we should join stats from different assets
-         # before publishing them to the backend. If set to False, we will publish stats from each
-         # asset separately, using the backend interface. If set to True, we will aggregate stats from
-         # all assets before publishing them, using an internal queue interface, and then publish
-         # them using the interface to the backend.
-         # This is done to improve compatibility with older versions of the backend as it used to
-         # collect the names of the metrics to be displayed in the UI from the first stats message.
-
-         # compute the global publishing interval if _stats_join_assets is requested
-         sampling_interval: float = float(
-             max(0.1, self.settings.x_stats_sampling_interval)
-         ) # seconds
-         samples_to_aggregate: int = 1
-         self.publishing_interval: float = sampling_interval * samples_to_aggregate
-         self.join_assets: bool = False
-
-         self.backend_interface = interface
-         self.asset_interface: Optional[AssetInterface] = (
-             AssetInterface() if self.join_assets else None
-         )
-
-         # hardware assets
-         self.assets: List[Asset] = self._get_assets()
-
-         # OpenMetrics/Prometheus-compatible endpoints
-         self.assets.extend(self._get_open_metrics_assets())
-
-         # static system info, both hardware and software
-         self.system_info: SystemInfo = SystemInfo(
-             settings=self.settings, interface=interface
-         )
-
-         self.buffer: Dict[str, Deque[Tuple[float, float]]] = defaultdict(
-             lambda: deque([], maxlen=self.settings.x_stats_buffer_size)
-         )
-
-     def _get_assets(self) -> List["Asset"]:
-         return [
-             asset_class(
-                 interface=self.asset_interface or self.backend_interface,
-                 settings=self.settings,
-                 shutdown_event=self._shutdown_event,
-             )
-             for asset_class in asset_registry
-         ]
-
-     def _get_open_metrics_assets(self) -> List["Asset"]:
-         open_metrics_endpoints = self.settings.x_stats_open_metrics_endpoints
-         if not open_metrics_endpoints:
-             return []
-
-         assets: List[Asset] = []
-         for name, endpoint in open_metrics_endpoints.items():
-             if not OpenMetrics.is_available(url=endpoint):
-                 continue
-             logger.debug(f"Monitoring OpenMetrics endpoint: {endpoint}")
-             open_metrics = OpenMetrics(
-                 interface=self.asset_interface or self.backend_interface,
-                 settings=self.settings,
-                 shutdown_event=self._shutdown_event,
-                 name=name,
-                 url=endpoint,
-             )
-             assets.append(open_metrics) # type: ignore
-
-         return assets
-
-     def aggregate_and_publish_asset_metrics(self) -> None:
-         if self.asset_interface is None:
-             return None
-         # only extract as many items as are available in the queue at the moment
-         size = self.asset_interface.metrics_queue.qsize()
-
-         aggregated_metrics = {}
-         for _ in range(size):
-             item = self.asset_interface.metrics_queue.get()
-             aggregated_metrics.update(item)
-
-         if aggregated_metrics:
-             # update buffer:
-             # todo: get it from publish_stats instead?
-             # either is not too accurate, just use wandb-core!
-             t = datetime.datetime.now().timestamp()
-             for k, v in aggregated_metrics.items():
-                 self.buffer[k].append((t, v))
-             # publish aggregated metrics
-             self.backend_interface.publish_stats(aggregated_metrics)
-
-     def publish_telemetry(self) -> None:
-         if self.asset_interface is None:
-             return None
-         # get everything from the self.asset_interface.telemetry_queue,
-         # merge into a single dictionary and publish on the backend_interface
-         while not self.asset_interface.telemetry_queue.empty():
-             telemetry_record = self.asset_interface.telemetry_queue.get()
-             self.backend_interface._publish_telemetry(telemetry_record)
-
-     def _start(self) -> None:
-         logger.info("Starting system asset monitoring threads")
-         for asset in self.assets:
-             asset.start()
-
-         # compatibility mode: join stats from different assets before publishing
-         if not (self.join_assets and self.asset_interface is not None):
-             return None
-
-         # give the assets a chance to accumulate and publish their first stats
-         # this will provide a constant offset for the following accumulation events below
-         self._shutdown_event.wait(
-             self.publishing_interval * self.PUBLISHING_INTERVAL_DELAY_FACTOR
-         )
-
-         logger.debug("Starting system metrics aggregation loop")
-
-         while not self._shutdown_event.is_set():
-             self.publish_telemetry()
-             self.aggregate_and_publish_asset_metrics()
-             self._shutdown_event.wait(self.publishing_interval)
-
-         logger.debug("Finished system metrics aggregation loop")
-
-         # try to publish the last batch of metrics + telemetry
-         try:
-             logger.debug("Publishing last batch of metrics")
-             # publish telemetry
-             self.publish_telemetry()
-             self.aggregate_and_publish_asset_metrics()
-         except Exception:
-             logger.exception("Error publishing last batch of metrics.")
-
-     def start(self) -> None:
-         self._shutdown_event.clear()
-         if self._process is not None:
-             return None
-         logger.info("Starting system monitor")
-         self._process = threading.Thread(
-             target=self._start, daemon=True, name="SystemMonitor"
-         )
-         self._process.start()
-
-     def finish(self) -> None:
-         if self._process is None:
-             return None
-         logger.info("Stopping system monitor")
-         self._shutdown_event.set()
-         for asset in self.assets:
-             asset.finish()
-         try:
-             self._process.join()
-         except Exception:
-             logger.exception("Error joining system monitor process.")
-         self._process = None
-
-     def probe(self, publish: bool = True) -> dict:
-         logger.info("Collecting system info")
-         # collect static info about the hardware from registered assets
-         hardware_info: dict = {
-             k: v for d in [asset.probe() for asset in self.assets] for k, v in d.items()
-         }
-         # collect static info about the software environment
-         software_info: dict = self.system_info.probe()
-         # merge the two dictionaries
-         system_info = {**software_info, **hardware_info}
-         logger.debug(system_info)
-         logger.info("Finished collecting system info")
-
-         if publish:
-             logger.info("Publishing system info")
-             self.system_info.publish(system_info)
-             logger.info("Finished publishing system info")
-
-         return system_info
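
Note: the removed SystemMonitor's compatibility mode merged stats from all assets before publishing. It sampled the queue size once, drained exactly that many items, and merged them into one payload, so a producer that enqueues mid-drain is picked up on the next tick instead of blocking this one. A minimal sketch of that pattern, assuming a plain queue.Queue of per-asset dicts (drain_and_merge is a hypothetical helper name, not a wandb API):

import queue

def drain_and_merge(metrics_queue: "queue.Queue[dict]") -> dict:
    # Sample qsize() once; items enqueued while draining are deliberately
    # left in the queue for the next aggregation tick.
    merged: dict = {}
    for _ in range(metrics_queue.qsize()):
        merged.update(metrics_queue.get())
    return merged

q: "queue.Queue[dict]" = queue.Queue()
q.put({"cpu": 12.5})
q.put({"gpu.0.memory": 41.0})
print(drain_and_merge(q))  # {'cpu': 12.5, 'gpu.0.memory': 41.0}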
wandb/sdk/internal/writer.py
@@ -1,204 +0,0 @@
- """Writer thread."""
-
- import logging
- from typing import TYPE_CHECKING, Callable, Optional
-
- from wandb.proto import wandb_internal_pb2 as pb
- from wandb.proto import wandb_telemetry_pb2 as tpb
-
- from ..interface.interface_queue import InterfaceQueue
- from ..lib import proto_util, telemetry
- from . import context, datastore, flow_control
- from .settings_static import SettingsStatic
-
- if TYPE_CHECKING:
-     from queue import Queue
-
-
- logger = logging.getLogger(__name__)
-
-
- class WriteManager:
-     _settings: SettingsStatic
-     _record_q: "Queue[pb.Record]"
-     _result_q: "Queue[pb.Result]"
-     _sender_q: "Queue[pb.Record]"
-     _interface: InterfaceQueue
-     _context_keeper: context.ContextKeeper
-
-     _ds: Optional[datastore.DataStore]
-     _flow_control: Optional[flow_control.FlowControl]
-     _status_report: Optional["pb.StatusReportRequest"]
-     _record_num: int
-     _telemetry_obj: tpb.TelemetryRecord
-     _telemetry_overflow: bool
-     _use_flow_control: bool
-
-     # TODO(cancel_paused): implement me
-     # _sender_cancel_set: Set[str]
-
-     def __init__(
-         self,
-         settings: SettingsStatic,
-         record_q: "Queue[pb.Record]",
-         result_q: "Queue[pb.Result]",
-         sender_q: "Queue[pb.Record]",
-         interface: InterfaceQueue,
-         context_keeper: context.ContextKeeper,
-     ):
-         self._settings = settings
-         self._record_q = record_q
-         self._result_q = result_q
-         self._sender_q = sender_q
-         self._interface = interface
-         self._context_keeper = context_keeper
-
-         # TODO(cancel_paused): implement me
-         # self._sender_cancel_set = set()
-
-         self._ds = None
-         self._flow_control = None
-         self._status_report = None
-         self._record_num = 0
-         self._telemetry_obj = tpb.TelemetryRecord()
-         self._telemetry_overflow = False
-         self._use_flow_control = not (
-             self._settings.x_flow_control_disabled or self._settings._offline
-         )
-
-     def open(self) -> None:
-         self._ds = datastore.DataStore()
-         self._ds.open_for_write(self._settings.sync_file)
-         self._flow_control = flow_control.FlowControl(
-             settings=self._settings,
-             write_record=self._write_record,
-             forward_record=self._forward_record,
-             pause_marker=self._pause_marker,
-             recover_records=self._recover_records,
-         )
-
-     def _forward_record(self, record: "pb.Record") -> None:
-         self._context_keeper.add_from_record(record)
-         self._sender_q.put(record)
-
-     def _send_mark(self) -> None:
-         sender_mark = pb.SenderMarkRequest()
-         record = self._interface._make_request(sender_mark=sender_mark)
-         self._forward_record(record)
-
-     def _maybe_send_telemetry(self) -> None:
-         if self._telemetry_overflow:
-             return
-         self._telemetry_overflow = True
-         with telemetry.context(obj=self._telemetry_obj) as tel:
-             tel.feature.flow_control_overflow = True
-         telemetry_record = pb.TelemetryRecordRequest(telemetry=self._telemetry_obj)
-         record = self._interface._make_request(telemetry_record=telemetry_record)
-         self._forward_record(record)
-
-     def _pause_marker(self) -> None:
-         self._maybe_send_telemetry()
-         self._send_mark()
-
-     def _write_record(self, record: "pb.Record") -> int:
-         assert self._ds
-
-         self._record_num += 1
-         proto_util._assign_record_num(record, self._record_num)
-         ret = self._ds.write(record)
-         assert ret is not None
-
-         _start_offset, end_offset, _flush_offset = ret
-         proto_util._assign_end_offset(record, end_offset)
-         return end_offset
-
-     def _ensure_flushed(self, offset: int) -> None:
-         if self._ds:
-             self._ds.ensure_flushed(offset)
-
-     def _recover_records(self, start: int, end: int) -> None:
-         sender_read = pb.SenderReadRequest(start_offset=start, final_offset=end)
-         # TODO(cancel_paused): implement me
-         # for cancel_id in self._sender_cancel_set:
-         #     sender_read.cancel_list.append(cancel_id)
-         record = self._interface._make_request(sender_read=sender_read)
-         self._ensure_flushed(end)
-         self._forward_record(record)
-
-     def _write(self, record: "pb.Record") -> None:
-         if not self._ds:
-             self.open()
-         assert self._flow_control
-
-         if not record.control.local:
-             self._write_record(record)
-
-         if self._use_flow_control:
-             self._flow_control.flow(record)
-         elif not self._settings._offline or record.control.always_send:
-             # when flow_control is disabled we pass through all records to
-             # the sender as long as we are online. The exception is there
-             # are special records that we always pass to the sender
-             # (namely the exit record so we can trigger the defer shutdown
-             # state machine)
-             self._forward_record(record)
-
-     def write(self, record: "pb.Record") -> None:
-         record_type = record.WhichOneof("record_type")
-         assert record_type
-         writer_str = "write_" + record_type
-         write_handler: Callable[[pb.Record], None] = getattr(
-             self, writer_str, self._write
-         )
-         write_handler(record)
-
-     def write_request(self, record: "pb.Record") -> None:
-         request_type = record.request.WhichOneof("request_type")
-         assert request_type
-         write_request_str = "write_request_" + request_type
-         write_request_handler: Optional[Callable[[pb.Record], None]] = getattr(
-             self, write_request_str, None
-         )
-         if write_request_handler:
-             return write_request_handler(record)
-         self._write(record)
-
-     def write_request_run_status(self, record: "pb.Record") -> None:
-         result = proto_util._result_from_record(record)
-         if self._status_report:
-             result.response.run_status_response.sync_time.CopyFrom(
-                 self._status_report.sync_time
-             )
-             send_record_num = self._status_report.record_num
-             result.response.run_status_response.sync_items_total = self._record_num
-             result.response.run_status_response.sync_items_pending = (
-                 self._record_num - send_record_num
-             )
-         self._respond_result(result)
-
-     def write_request_status_report(self, record: "pb.Record") -> None:
-         self._status_report = record.request.status_report
-         self._write(record)
-
-     def write_request_cancel(self, record: "pb.Record") -> None:
-         cancel_id = record.request.cancel.cancel_slot
-         self._context_keeper.cancel(cancel_id)
-
-         # TODO(cancel_paused): implement me
-         # cancelled = self._context_keeper.cancel(cancel_id)
-         # if not cancelled:
-         #     self._sender_cancel_set.add(cancel_id)
-
-     def _respond_result(self, result: "pb.Result") -> None:
-         self._result_q.put(result)
-
-     def finish(self) -> None:
-         if self._flow_control:
-             self._flow_control.flush()
-         if self._ds:
-             self._ds.close()
-         # TODO(debug_context) see context.py
-         # self._context_keeper._debug_print_orphans(print_to_stdout=self._settings._debug)
-
-     def debounce(self) -> None:
-         pass
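
Note: the removed WriteManager routed records with getattr-based dispatch: write() looked up a method named "write_" + record_type and fell back to the generic _write path when none existed, and write_request() repeated the trick at the request level. A short self-contained sketch of that dispatch idiom, with hypothetical handler names rather than the actual wandb record types:

class Dispatcher:
    def handle(self, record_type: str, record: dict) -> None:
        # Look up "write_<type>"; fall back to the generic writer,
        # mirroring the removed write()/write_request() pair.
        handler = getattr(self, "write_" + record_type, self._write_default)
        handler(record)

    def write_history(self, record: dict) -> None:
        print("history:", record)

    def _write_default(self, record: dict) -> None:
        print("default:", record)

d = Dispatcher()
d.handle("history", {"loss": 0.1})     # dispatches to write_history
d.handle("telemetry", {"flag": True})  # falls back to _write_default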
wandb/sdk/lib/service_token.py
@@ -1,93 +0,0 @@
- from __future__ import annotations
-
- import dataclasses
- import os
-
- from wandb import env
-
- _CURRENT_VERSION = "2"
- _SUPPORTED_TRANSPORTS = "tcp"
-
-
- def get_service_token() -> ServiceToken | None:
-     """Reads the token from environment variables.
-
-     Returns:
-         The token if the correct environment variable is set, or None.
-
-     Raises:
-         ValueError: If the environment variable is set but cannot be
-             parsed.
-     """
-     token = os.environ.get(env.SERVICE)
-     if not token:
-         return None
-
-     parts = token.split("-")
-     if len(parts) != 5:
-         raise ValueError(f"Invalid token: {token}")
-
-     version, pid_str, transport, host, port_str = parts
-
-     if version != _CURRENT_VERSION:
-         raise ValueError(
-             f"Expected version {_CURRENT_VERSION}, but got {version} (token={token})"
-         )
-     if transport not in _SUPPORTED_TRANSPORTS:
-         raise ValueError(
-             f"Unsupported transport: {transport} (token={token})",
-         )
-
-     try:
-         return ServiceToken(
-             version=version,
-             pid=int(pid_str),
-             transport=transport,
-             host=host,
-             port=int(port_str),
-         )
-     except ValueError as e:
-         raise ValueError(f"Invalid token: {token}") from e
-
-
- def set_service_token(parent_pid: int, transport: str, host: str, port: int) -> None:
-     """Stores a service token in an environment variable.
-
-     Args:
-         parent_pid: The process ID of the process that started the service.
-         transport: The transport used to communicate with the service.
-         host: The host part of the internet address on which the service
-             is listening (e.g. localhost).
-         port: The port the service is listening on.
-
-     Raises:
-         ValueError: If given an unsupported transport.
-     """
-     if transport not in _SUPPORTED_TRANSPORTS:
-         raise ValueError(f"Unsupported transport: {transport}")
-
-     os.environ[env.SERVICE] = "-".join(
-         (
-             _CURRENT_VERSION,
-             str(parent_pid),
-             transport,
-             host,
-             str(port),
-         )
-     )
-
-
- def clear_service_token() -> None:
-     """Clears the environment variable storing the service token."""
-     os.environ.pop(env.SERVICE, None)
-
-
- @dataclasses.dataclass(frozen=True)
- class ServiceToken:
-     """An identifier for a running service process."""
-
-     version: str
-     pid: int
-     transport: str
-     host: str
-     port: int
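
Note: the removed module serialized the service address into a single dash-joined string, "version-pid-transport-host-port", which is why get_service_token rejected anything that did not split into exactly five parts (a host that itself contains "-" would break the format). A round-trip sketch under those assumptions:

# Build a token the way set_service_token did, then parse it back.
token = "-".join(("2", str(4242), "tcp", "localhost", str(50051)))
assert token == "2-4242-tcp-localhost-50051"

version, pid_str, transport, host, port_str = token.split("-")
assert version == "2" and transport == "tcp"
assert (int(pid_str), host, int(port_str)) == (4242, "localhost", 50051)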
wandb/sdk/service/__init__.py
File without changes
wandb/sdk/service/_startup_debug.py
@@ -1,22 +0,0 @@
- """_startup_debug.
-
- Temporary helper to debug issues with wandb service startup
- """
-
- import os
- import time
-
-
- def is_enabled() -> bool:
-     # This is very temporary to help diagnose problems seen by some
-     # customers which we are having trouble reproducing. It should be
-     # replaced by something more permanent in the future when we have
-     # proper logging for wandb-service
-     if os.environ.get("_WANDB_STARTUP_DEBUG"):
-         return True
-     return False
-
-
- def print_message(message: str) -> None:
-     time_now = time.time()
-     print("WANDB_STARTUP_DEBUG", time_now, message) # noqa: T201
wandb/sdk/service/port_file.py
@@ -1,53 +0,0 @@
- """port_file: write/read file containing port info."""
-
- import os
- import tempfile
- from typing import Optional
-
-
- class PortFile:
-     _sock_port: Optional[int]
-     _valid: bool
-
-     SOCK_TOKEN = "sock="
-     EOF_TOKEN = "EOF"
-
-     def __init__(self, sock_port: Optional[int] = None) -> None:
-         self._sock_port = sock_port
-         self._valid = False
-
-     def write(self, fname: str) -> None:
-         dname, bname = os.path.split(fname)
-         f = tempfile.NamedTemporaryFile(prefix=bname, dir=dname, mode="w", delete=False)
-         try:
-             tmp_filename = f.name
-             with f:
-                 data = []
-                 if self._sock_port:
-                     data.append(f"{self.SOCK_TOKEN}{self._sock_port}")
-                 data.append(self.EOF_TOKEN)
-                 port_str = "\n".join(data)
-                 written = f.write(port_str)
-                 assert written == len(port_str)
-             os.rename(tmp_filename, fname)
-         except Exception:
-             os.unlink(tmp_filename)
-             raise
-
-     def read(self, fname: str) -> None:
-         with open(fname) as f:
-             lines = f.readlines()
-             if lines[-1] != self.EOF_TOKEN:
-                 return
-             for ln in lines:
-                 if ln.startswith(self.SOCK_TOKEN):
-                     self._sock_port = int(ln[len(self.SOCK_TOKEN) :])
-             self._valid = True
-
-     @property
-     def sock_port(self) -> Optional[int]:
-         return self._sock_port
-
-     @property
-     def is_valid(self) -> bool:
-         return self._valid
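
Note: PortFile.write used the classic atomic-publish pattern: write the full payload to a NamedTemporaryFile in the destination directory, close it, then rename it over the target, so a concurrent reader sees either the old file or the complete new one, never a partial write. A minimal sketch of the same pattern (atomic_write is a hypothetical helper, not part of wandb; it uses os.replace, which unlike the original's os.rename also overwrites an existing destination on Windows):

import os
import tempfile

def atomic_write(fname: str, contents: str) -> None:
    # Create the temp file next to the target so the final rename stays
    # on one filesystem, which is what makes it atomic.
    dname, bname = os.path.split(os.path.abspath(fname))
    f = tempfile.NamedTemporaryFile(prefix=bname, dir=dname, mode="w", delete=False)
    try:
        with f:
            f.write(contents)  # closed before the rename, as in PortFile.write
        os.replace(f.name, fname)  # atomic publish of the complete file
    except Exception:
        os.unlink(f.name)
        raise

atomic_write("port_file.txt", "sock=50051\nEOF")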