wandb 0.20.1__py3-none-win32.whl → 0.20.2rc20250616__py3-none-win32.whl
This diff shows the changes between two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- wandb/__init__.py +3 -6
- wandb/__init__.pyi +1 -1
- wandb/analytics/sentry.py +2 -2
- wandb/apis/importers/internals/internal.py +0 -3
- wandb/apis/public/api.py +2 -2
- wandb/apis/public/registries/{utils.py → _utils.py} +12 -12
- wandb/apis/public/registries/registries_search.py +2 -2
- wandb/apis/public/registries/registry.py +19 -18
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +1 -7
- wandb/cli/cli.py +0 -30
- wandb/env.py +0 -6
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +42 -1
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/internal/handler.py +1 -69
- wandb/sdk/lib/printer.py +6 -7
- wandb/sdk/lib/progress.py +1 -3
- wandb/sdk/lib/service/ipc_support.py +13 -0
- wandb/sdk/lib/{service_connection.py → service/service_connection.py} +20 -56
- wandb/sdk/lib/service/service_port_file.py +105 -0
- wandb/sdk/lib/service/service_process.py +111 -0
- wandb/sdk/lib/service/service_token.py +164 -0
- wandb/sdk/lib/sock_client.py +8 -12
- wandb/sdk/wandb_init.py +0 -3
- wandb/sdk/wandb_require.py +9 -20
- wandb/sdk/wandb_run.py +0 -24
- wandb/sdk/wandb_settings.py +0 -9
- wandb/sdk/wandb_setup.py +2 -13
- {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/METADATA +1 -3
- {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/RECORD +42 -68
- wandb/sdk/internal/flow_control.py +0 -263
- wandb/sdk/internal/internal.py +0 -401
- wandb/sdk/internal/internal_util.py +0 -97
- wandb/sdk/internal/system/__init__.py +0 -0
- wandb/sdk/internal/system/assets/__init__.py +0 -25
- wandb/sdk/internal/system/assets/aggregators.py +0 -31
- wandb/sdk/internal/system/assets/asset_registry.py +0 -20
- wandb/sdk/internal/system/assets/cpu.py +0 -163
- wandb/sdk/internal/system/assets/disk.py +0 -210
- wandb/sdk/internal/system/assets/gpu.py +0 -416
- wandb/sdk/internal/system/assets/gpu_amd.py +0 -233
- wandb/sdk/internal/system/assets/interfaces.py +0 -205
- wandb/sdk/internal/system/assets/ipu.py +0 -177
- wandb/sdk/internal/system/assets/memory.py +0 -166
- wandb/sdk/internal/system/assets/network.py +0 -125
- wandb/sdk/internal/system/assets/open_metrics.py +0 -293
- wandb/sdk/internal/system/assets/tpu.py +0 -154
- wandb/sdk/internal/system/assets/trainium.py +0 -393
- wandb/sdk/internal/system/env_probe_helpers.py +0 -13
- wandb/sdk/internal/system/system_info.py +0 -248
- wandb/sdk/internal/system/system_monitor.py +0 -224
- wandb/sdk/internal/writer.py +0 -204
- wandb/sdk/lib/service_token.py +0 -93
- wandb/sdk/service/__init__.py +0 -0
- wandb/sdk/service/_startup_debug.py +0 -22
- wandb/sdk/service/port_file.py +0 -53
- wandb/sdk/service/server.py +0 -107
- wandb/sdk/service/server_sock.py +0 -286
- wandb/sdk/service/service.py +0 -252
- wandb/sdk/service/streams.py +0 -425
- {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/WHEEL +0 -0
- {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/entry_points.txt +0 -0
- {wandb-0.20.1.dist-info → wandb-0.20.2rc20250616.dist-info}/licenses/LICENSE +0 -0
--- a/wandb/sdk/internal/system/assets/open_metrics.py
+++ /dev/null
@@ -1,293 +0,0 @@
-import logging
-import re
-import threading
-from collections import defaultdict, deque
-from functools import lru_cache
-from types import ModuleType
-from typing import TYPE_CHECKING, Dict, Final, List, Mapping, Sequence, Tuple, Union
-
-import requests
-import requests.adapters
-import urllib3
-
-import wandb
-from wandb.sdk.lib import hashutil, telemetry
-
-from .aggregators import aggregate_last, aggregate_mean
-from .interfaces import Interface, Metric, MetricsMonitor
-
-if TYPE_CHECKING:
-    from typing import Deque, Optional
-
-    from wandb.sdk.internal.settings_static import SettingsStatic
-
-
-_PREFIX: Final[str] = "openmetrics"
-
-_REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
-    backoff_factor=1,
-    total=3,
-    status_forcelist=(408, 409, 429, 500, 502, 503, 504),
-)
-_REQUEST_POOL_CONNECTIONS = 4
-_REQUEST_POOL_MAXSIZE = 4
-_REQUEST_TIMEOUT = 3
-
-
-logger = logging.getLogger(__name__)
-
-
-prometheus_client_parser: "Optional[ModuleType]" = None
-try:
-    import prometheus_client.parser  # type: ignore
-
-    prometheus_client_parser = prometheus_client.parser
-except ImportError:
-    pass
-
-
-def _setup_requests_session() -> requests.Session:
-    session = requests.Session()
-    adapter = requests.adapters.HTTPAdapter(
-        max_retries=_REQUEST_RETRY_STRATEGY,
-        pool_connections=_REQUEST_POOL_CONNECTIONS,
-        pool_maxsize=_REQUEST_POOL_MAXSIZE,
-    )
-    session.mount("http://", adapter)
-    session.mount("https://", adapter)
-    return session
-
-
-def _nested_dict_to_tuple(
-    nested_dict: Mapping[str, Mapping[str, str]],
-) -> Tuple[Tuple[str, Tuple[str, str]], ...]:
-    return tuple((k, *v.items()) for k, v in nested_dict.items())  # type: ignore
-
-
-def _tuple_to_nested_dict(
-    nested_tuple: Tuple[Tuple[str, Tuple[str, str]], ...],
-) -> Dict[str, Dict[str, str]]:
-    return {k: dict(v) for k, *v in nested_tuple}
-
-
-@lru_cache(maxsize=128)
-def _should_capture_metric(
-    endpoint_name: str,
-    metric_name: str,
-    metric_labels: Tuple[str, ...],
-    filters: Tuple[Tuple[str, Tuple[str, str]], ...],
-) -> bool:
-    # we use tuples to make the function arguments hashable => usable with lru_cache
-    should_capture = False
-
-    if not filters:
-        return should_capture
-
-    # self.filters keys are regexes, check the name against them
-    # and for the first match, check the labels against the label filters.
-    # assume that if at least one label filter doesn't match, the metric
-    # should not be captured.
-    # it's up to the user to make sure that the filters are not conflicting etc.
-    metric_labels_dict = {t[0]: t[1] for t in metric_labels}
-    filters_dict = _tuple_to_nested_dict(filters)
-    for metric_name_regex, label_filters in filters_dict.items():
-        if not re.match(metric_name_regex, f"{endpoint_name}.{metric_name}"):
-            continue
-
-        should_capture = True
-
-        for label, label_filter in label_filters.items():
-            if not re.match(label_filter, metric_labels_dict.get(label, "")):
-                should_capture = False
-                break
-        break
-
-    return should_capture
-
-
-class OpenMetricsMetric:
-    """Container for all the COUNTER and GAUGE metrics extracted from an OpenMetrics endpoint."""
-
-    def __init__(
-        self,
-        name: str,
-        url: str,
-        filters: Union[Mapping[str, Mapping[str, str]], Sequence[str], None],
-    ) -> None:
-        self.name = name  # user-defined name for the endpoint
-        self.url = url  # full URL
-
-        # - filters can be a dict {"<metric regex>": {"<label>": "<filter regex>"}}
-        #   or a sequence of metric regexes. we convert the latter to a dict
-        #   to make it easier to work with.
-        # - the metric regexes are matched against the full metric name,
-        #   i.e. "<endpoint name>.<metric name>".
-        # - by default, all metrics are captured.
-        self.filters = (
-            filters
-            if isinstance(filters, Mapping)
-            else {k: {} for k in filters or [".*"]}
-        )
-        self.filters_tuple = _nested_dict_to_tuple(self.filters) if self.filters else ()
-
-        self._session: Optional[requests.Session] = None
-        self.samples: Deque[dict] = deque([])
-        # {"<metric name>": {"<labels hash>": <index>}}
-        self.label_map: Dict[str, Dict[str, int]] = defaultdict(dict)
-        # {"<labels hash>": <labels>}
-        self.label_hashes: Dict[str, dict] = {}
-
-    def setup(self) -> None:
-        if self._session is not None:
-            return
-
-        self._session = _setup_requests_session()
-
-    def teardown(self) -> None:
-        if self._session is None:
-            return
-
-        self._session.close()
-        self._session = None
-
-    def parse_open_metrics_endpoint(self) -> Dict[str, Union[str, int, float]]:
-        assert prometheus_client_parser is not None
-        assert self._session is not None
-
-        response = self._session.get(self.url, timeout=_REQUEST_TIMEOUT)
-        response.raise_for_status()
-
-        text = response.text
-        measurement = {}
-        for family in prometheus_client_parser.text_string_to_metric_families(text):
-            if family.type not in ("counter", "gauge"):
-                # todo: add support for other metric types?
-                # todo: log warning about that?
-                continue
-            for sample in family.samples:
-                name, labels, value = sample.name, sample.labels, sample.value
-
-                if not _should_capture_metric(
-                    self.name,
-                    name,
-                    tuple(labels.items()),
-                    self.filters_tuple,
-                ):
-                    continue
-
-                # md5 hash of the labels
-                label_hash = hashutil._md5(str(labels).encode("utf-8")).hexdigest()
-                if label_hash not in self.label_map[name]:
-                    # store the index of the label hash in the label map
-                    self.label_map[name][label_hash] = len(self.label_map[name])
-                    # store the labels themselves
-                    self.label_hashes[label_hash] = labels
-                index = self.label_map[name][label_hash]
-                measurement[f"{name}.{index}"] = value
-
-        return measurement
-
-    def sample(self) -> None:
-        s = self.parse_open_metrics_endpoint()
-        self.samples.append(s)
-
-    def clear(self) -> None:
-        self.samples.clear()
-
-    def aggregate(self) -> dict:
-        if not self.samples:
-            return {}
-
-        prefix = f"{_PREFIX}.{self.name}."
-
-        stats = {}
-        for key in self.samples[0].keys():
-            samples = [s[key] for s in self.samples if key in s]
-            if samples and all(isinstance(s, (int, float)) for s in samples):
-                stats[f"{prefix}{key}"] = aggregate_mean(samples)
-            else:
-                stats[f"{prefix}{key}"] = aggregate_last(samples)
-        return stats
-
-
-class OpenMetrics:
-    # Poll an OpenMetrics endpoint, parse the response and return a dict of metrics
-    # Implements the same Protocol interface as Asset
-
-    def __init__(
-        self,
-        interface: "Interface",
-        settings: "SettingsStatic",
-        shutdown_event: threading.Event,
-        name: str,
-        url: str,
-    ) -> None:
-        self.name = name
-        self.url = url
-        self.interface = interface
-        self.settings = settings
-        self.shutdown_event = shutdown_event
-
-        self.metrics: List[Metric] = [
-            OpenMetricsMetric(name, url, settings.x_stats_open_metrics_filters)
-        ]
-
-        self.metrics_monitor: MetricsMonitor = MetricsMonitor(
-            asset_name=self.name,
-            metrics=self.metrics,
-            interface=interface,
-            settings=settings,
-            shutdown_event=shutdown_event,
-        )
-
-        telemetry_record = telemetry.TelemetryRecord()
-        telemetry_record.feature.open_metrics = True
-        interface._publish_telemetry(telemetry_record)
-
-    @classmethod
-    def is_available(cls, url: str) -> bool:
-        _is_available: bool = False
-
-        ret = prometheus_client_parser is not None
-        if not ret:
-            wandb.termwarn(
-                "Monitoring OpenMetrics endpoints requires the `prometheus_client` package. "
-                "To install it, run `pip install prometheus_client`.",
-                repeat=False,
-            )
-            return _is_available
-        # check if the endpoint is available and is a valid OpenMetrics endpoint
-        _session: Optional[requests.Session] = None
-        try:
-            assert prometheus_client_parser is not None
-            _session = _setup_requests_session()
-            response = _session.get(url, timeout=_REQUEST_TIMEOUT)
-            response.raise_for_status()
-
-            # check if the response is a valid OpenMetrics response
-            # text_string_to_metric_families returns a generator
-            if list(
-                prometheus_client_parser.text_string_to_metric_families(response.text)
-            ):
-                _is_available = True
-        except Exception as e:
-            logger.debug(
-                f"OpenMetrics endpoint {url} is not available: {e}", exc_info=True
-            )

-        if _session is not None:
-            try:
-                _session.close()
-            except Exception:
-                pass
-        return _is_available
-
-    def start(self) -> None:
-        self.metrics_monitor.start()
-
-    def finish(self) -> None:
-        self.metrics_monitor.finish()
-
-    def probe(self) -> dict:
-        # todo: also return self.label_hashes
-        return {self.name: self.url}
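The core of the deleted module is the filter match in `_should_capture_metric`: the first regex key that matches `"<endpoint name>.<metric name>"` decides the outcome, and every label filter under that key must then also match for the sample to be kept. A minimal, self-contained sketch of that behavior (illustrative only; `should_capture`, the example filter dict, and the DCGM metric name below are assumptions for demonstration, not wandb API):

```python
import re
from typing import Dict, Mapping


def should_capture(
    endpoint: str,
    metric: str,
    labels: Mapping[str, str],
    filters: Dict[str, Dict[str, str]],
) -> bool:
    """Uncached mirror of the deleted _should_capture_metric logic."""
    for name_regex, label_filters in filters.items():
        if not re.match(name_regex, f"{endpoint}.{metric}"):
            continue
        # the first matching metric regex decides; all of its label
        # filters must then match the sample's labels
        return all(
            re.match(expr, labels.get(label, ""))
            for label, expr in label_filters.items()
        )
    return False  # no filters configured, or no regex matched


# e.g. keep GPU utilization from a DCGM exporter only for GPU 0
filters = {r"node\.DCGM_FI_DEV_GPU_UTIL": {"gpu": "0"}}
assert should_capture("node", "DCGM_FI_DEV_GPU_UTIL", {"gpu": "0"}, filters)
assert not should_capture("node", "DCGM_FI_DEV_GPU_UTIL", {"gpu": "1"}, filters)
```

Captured values were then keyed as `openmetrics.<endpoint>.<metric>.<label-set index>` (the index assigned per unique label hash), with numeric series aggregated by mean and everything else by last value.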
--- a/wandb/sdk/internal/system/assets/tpu.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import logging
-import os
-import threading
-from collections import deque
-from typing import TYPE_CHECKING, List, Optional
-
-from .aggregators import aggregate_mean
-from .asset_registry import asset_registry
-from .interfaces import Interface, Metric, MetricsMonitor
-
-if TYPE_CHECKING:
-    from typing import Deque
-
-    from wandb.sdk.internal.settings_static import SettingsStatic
-
-logger = logging.getLogger(__name__)
-
-
-class TPUUtilization:
-    """Google Cloud TPU utilization in percent."""
-
-    name = "tpu"
-    samples: "Deque[float]"
-
-    def __init__(
-        self,
-        service_addr: str,
-        duration_ms: int = 100,
-    ) -> None:
-        self.samples = deque([])
-
-        self.duration_ms = duration_ms
-        self.service_addr = service_addr
-
-        try:
-            from tensorflow.python.profiler import profiler_client  # type: ignore
-
-            self._profiler_client = profiler_client
-        except ImportError:
-            logger.warning(
-                "Unable to import `tensorflow.python.profiler.profiler_client`. "
-                "TPU metrics will not be reported."
-            )
-            self._profiler_client = None
-
-    def sample(self) -> None:
-        result = self._profiler_client.monitor(
-            self.service_addr, duration_ms=self.duration_ms, level=2
-        )
-
-        self.samples.append(
-            float(result.split("Utilization ")[1].split(": ")[1].split("%")[0])
-        )
-
-    def clear(self) -> None:
-        self.samples.clear()
-
-    def aggregate(self) -> dict:
-        if not self.samples:
-            return {}
-        aggregate = aggregate_mean(self.samples)
-        return {self.name: aggregate}
-
-
-@asset_registry.register
-class TPU:
-    def __init__(
-        self,
-        interface: "Interface",
-        settings: "SettingsStatic",
-        shutdown_event: threading.Event,
-    ) -> None:
-        self.name = self.__class__.__name__.lower()
-        self.service_addr = self.get_service_addr()
-        self.metrics: List[Metric] = [TPUUtilization(self.service_addr)]
-
-        self.metrics_monitor = MetricsMonitor(
-            self.name,
-            self.metrics,
-            interface,
-            settings,
-            shutdown_event,
-        )
-
-    @staticmethod
-    def get_service_addr(
-        service_addr: Optional[str] = None,
-        tpu_name: Optional[str] = None,
-        compute_zone: Optional[str] = None,
-        core_project: Optional[str] = None,
-    ) -> str:
-        if service_addr is not None:
-            if tpu_name is not None:
-                logger.warning(
-                    "Both service_addr and tpu_name arguments provided. "
-                    "Ignoring tpu_name and using service_addr."
-                )
-        else:
-            tpu_name = tpu_name or os.environ.get("TPU_NAME")
-            if tpu_name is None:
-                raise Exception("Required environment variable TPU_NAME.")
-            compute_zone = compute_zone or os.environ.get("CLOUDSDK_COMPUTE_ZONE")
-            core_project = core_project or os.environ.get("CLOUDSDK_CORE_PROJECT")
-            try:
-                from tensorflow.python.distribute.cluster_resolver import (  # type: ignore
-                    tpu_cluster_resolver,
-                )
-
-                service_addr = tpu_cluster_resolver.TPUClusterResolver(
-                    [tpu_name], zone=compute_zone, project=core_project
-                ).get_master()
-            except (ValueError, TypeError):
-                raise ValueError(
-                    "Failed to find TPU. Try specifying TPU zone "
-                    "(via CLOUDSDK_COMPUTE_ZONE environment variable)"
-                    " and GCP project (via CLOUDSDK_CORE_PROJECT "
-                    "environment variable)."
-                )
-        service_addr = service_addr.replace("grpc://", "").replace(":8470", ":8466")
-        return service_addr
-
-    def start(self) -> None:
-        if self.metrics:
-            self.metrics_monitor.start()
-
-    def finish(self) -> None:
-        self.metrics_monitor.finish()
-
-    @classmethod
-    def is_available(cls) -> bool:
-        if os.environ.get("TPU_NAME", False) is False:
-            return False
-
-        try:
-            from tensorflow.python.distribute.cluster_resolver import (  # noqa: F401
-                tpu_cluster_resolver,
-            )
-            from tensorflow.python.profiler import profiler_client  # noqa: F401
-
-            cls.get_service_addr()
-        except (
-            ImportError,
-            TypeError,
-            AttributeError,
-            ValueError,
-        ):  # Saw type error when iterating paths on colab...
-            # TODO: Saw error in sentry where module 'tensorflow.python.pywrap_tensorflow'
-            # has no attribute 'TFE_DEVICE_PLACEMENT_EXPLICIT'
-            return False
-
-        return True
-
-    def probe(self) -> dict:
-        return {self.name: {"service_address": self.service_addr}}
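One non-obvious step in the deleted `get_service_addr`: `TPUClusterResolver.get_master()` returns the TPU worker endpoint (e.g. `grpc://<ip>:8470`), while `profiler_client.monitor` talks to the profiler's monitoring service on port 8466, so the code strips the gRPC scheme and rewrites the port. A tiny standalone sketch of that normalization (the function name here is illustrative, not part of the wandb API):

```python
def monitoring_addr(master: str) -> str:
    # Strip the gRPC scheme and redirect the TPU worker port (8470)
    # to the profiler monitoring port (8466), as the deleted code did.
    return master.replace("grpc://", "").replace(":8470", ":8466")


assert monitoring_addr("grpc://10.0.0.2:8470") == "10.0.0.2:8466"
```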