wandb 0.18.0rc1__py3-none-any.whl → 0.18.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +4 -4
- wandb/__init__.pyi +67 -12
- wandb/apis/internal.py +3 -0
- wandb/apis/public/api.py +128 -2
- wandb/apis/public/artifacts.py +11 -7
- wandb/apis/public/jobs.py +8 -0
- wandb/apis/public/runs.py +18 -5
- wandb/bin/nvidia_gpu_stats +0 -0
- wandb/cli/cli.py +0 -5
- wandb/data_types.py +9 -2019
- wandb/env.py +0 -5
- wandb/errors/__init__.py +11 -40
- wandb/errors/errors.py +37 -0
- wandb/errors/warnings.py +2 -0
- wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
- wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
- wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
- wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
- wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
- wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
- wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
- wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
- wandb/{sklearn → integration/sklearn}/utils.py +8 -8
- wandb/integration/tensorboard/log.py +1 -1
- wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
- wandb/old/core.py +2 -80
- wandb/plot/bar.py +7 -4
- wandb/plot/confusion_matrix.py +5 -4
- wandb/plot/histogram.py +7 -4
- wandb/plot/line.py +7 -4
- wandb/proto/v3/wandb_base_pb2.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +2 -1
- wandb/proto/v3/wandb_server_pb2.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +3 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v4/wandb_base_pb2.py +2 -1
- wandb/proto/v4/wandb_internal_pb2.py +2 -1
- wandb/proto/v4/wandb_server_pb2.py +2 -1
- wandb/proto/v4/wandb_settings_pb2.py +3 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v5/wandb_base_pb2.py +3 -2
- wandb/proto/v5/wandb_internal_pb2.py +3 -2
- wandb/proto/v5/wandb_server_pb2.py +3 -2
- wandb/proto/v5/wandb_settings_pb2.py +4 -3
- wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
- wandb/sdk/artifacts/_validators.py +48 -3
- wandb/sdk/artifacts/artifact.py +157 -183
- wandb/sdk/artifacts/artifact_file_cache.py +13 -11
- wandb/sdk/artifacts/artifact_instance_cache.py +4 -2
- wandb/sdk/artifacts/artifact_manifest.py +13 -11
- wandb/sdk/artifacts/artifact_manifest_entry.py +24 -22
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +9 -7
- wandb/sdk/artifacts/artifact_saver.py +27 -25
- wandb/sdk/artifacts/exceptions.py +26 -25
- wandb/sdk/artifacts/storage_handler.py +11 -9
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -14
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +15 -13
- wandb/sdk/artifacts/storage_handlers/http_handler.py +15 -14
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +14 -12
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +19 -19
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +12 -10
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +9 -7
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +31 -29
- wandb/sdk/artifacts/storage_policy.py +20 -20
- wandb/sdk/backend/backend.py +8 -26
- wandb/sdk/data_types/audio.py +165 -0
- wandb/sdk/data_types/base_types/wb_value.py +1 -3
- wandb/sdk/data_types/bokeh.py +70 -0
- wandb/sdk/data_types/graph.py +405 -0
- wandb/sdk/data_types/image.py +156 -0
- wandb/sdk/data_types/table.py +1204 -0
- wandb/sdk/data_types/trace_tree.py +2 -2
- wandb/sdk/data_types/utils.py +49 -0
- wandb/sdk/data_types/video.py +2 -2
- wandb/sdk/interface/interface.py +0 -24
- wandb/sdk/interface/interface_shared.py +0 -12
- wandb/sdk/internal/handler.py +0 -10
- wandb/sdk/internal/internal_api.py +71 -0
- wandb/sdk/internal/sender.py +0 -43
- wandb/sdk/internal/tb_watcher.py +1 -1
- wandb/sdk/lib/_settings_toposort_generated.py +1 -0
- wandb/sdk/lib/hashutil.py +34 -12
- wandb/sdk/lib/service_connection.py +216 -0
- wandb/sdk/lib/service_token.py +94 -0
- wandb/sdk/lib/sock_client.py +7 -3
- wandb/sdk/service/server.py +2 -5
- wandb/sdk/service/service.py +2 -31
- wandb/sdk/service/streams.py +0 -7
- wandb/sdk/wandb_init.py +42 -25
- wandb/sdk/wandb_run.py +18 -159
- wandb/sdk/wandb_settings.py +2 -0
- wandb/sdk/wandb_setup.py +25 -16
- wandb/sdk/wandb_sync.py +9 -3
- wandb/sdk/wandb_watch.py +31 -15
- wandb/sklearn.py +35 -0
- wandb/util.py +14 -3
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/METADATA +6 -5
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/RECORD +114 -110
- wandb/sdk/internal/update.py +0 -113
- wandb/sdk/lib/console.py +0 -39
- wandb/sdk/service/service_base.py +0 -50
- wandb/sdk/service/service_sock.py +0 -70
- wandb/sdk/wandb_manager.py +0 -232
- /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
- /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
- /wandb/{sdk/lib → plot}/viz.py +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/WHEEL +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/licenses/LICENSE +0 -0
@@ -14,9 +14,9 @@ from enum import Enum
|
|
14
14
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
15
15
|
|
16
16
|
import wandb
|
17
|
-
import wandb.data_types
|
18
17
|
from wandb.sdk.data_types import _dtypes
|
19
18
|
from wandb.sdk.data_types.base_types.media import Media
|
19
|
+
from wandb.sdk.data_types.utils import _json_helper
|
20
20
|
|
21
21
|
if TYPE_CHECKING: # pragma: no cover
|
22
22
|
from wandb.sdk.artifacts.artifact import Artifact
|
@@ -142,7 +142,7 @@ def _fallback_serialize(obj: Any) -> str:
|
|
142
142
|
def _safe_serialize(obj: dict) -> str:
|
143
143
|
try:
|
144
144
|
return json.dumps(
|
145
|
-
|
145
|
+
_json_helper(obj, None),
|
146
146
|
skipkeys=True,
|
147
147
|
default=_fallback_serialize,
|
148
148
|
)
|
wandb/sdk/data_types/utils.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
|
+
import datetime
|
1
2
|
import logging
|
2
3
|
import os
|
3
4
|
import re
|
5
|
+
from decimal import Decimal
|
4
6
|
from typing import TYPE_CHECKING, Optional, Sequence, Union, cast
|
5
7
|
|
6
8
|
import wandb
|
@@ -178,3 +180,50 @@ def _prune_max_seq(seq: Sequence["BatchableMedia"]) -> Sequence["BatchableMedia"
|
|
178
180
|
)
|
179
181
|
items = seq[: seq[0].MAX_ITEMS]
|
180
182
|
return items
|
183
|
+
|
184
|
+
|
185
|
+
def _json_helper(val, artifact):
|
186
|
+
if isinstance(val, WBValue):
|
187
|
+
return val.to_json(artifact)
|
188
|
+
elif val.__class__ is dict:
|
189
|
+
res = {}
|
190
|
+
for key in val:
|
191
|
+
res[key] = _json_helper(val[key], artifact)
|
192
|
+
return res
|
193
|
+
|
194
|
+
if hasattr(val, "tolist"):
|
195
|
+
py_val = val.tolist()
|
196
|
+
if val.__class__.__name__ == "datetime64" and isinstance(py_val, int):
|
197
|
+
# when numpy datetime64 .tolist() returns an int, it is nanoseconds.
|
198
|
+
# need to convert to milliseconds
|
199
|
+
return _json_helper(py_val / int(1e6), artifact)
|
200
|
+
return _json_helper(py_val, artifact)
|
201
|
+
elif hasattr(val, "item"):
|
202
|
+
return _json_helper(val.item(), artifact)
|
203
|
+
|
204
|
+
if isinstance(val, datetime.datetime):
|
205
|
+
if val.tzinfo is None:
|
206
|
+
val = datetime.datetime(
|
207
|
+
val.year,
|
208
|
+
val.month,
|
209
|
+
val.day,
|
210
|
+
val.hour,
|
211
|
+
val.minute,
|
212
|
+
val.second,
|
213
|
+
val.microsecond,
|
214
|
+
tzinfo=datetime.timezone.utc,
|
215
|
+
)
|
216
|
+
return int(val.timestamp() * 1000)
|
217
|
+
elif isinstance(val, datetime.date):
|
218
|
+
return int(
|
219
|
+
datetime.datetime(
|
220
|
+
val.year, val.month, val.day, tzinfo=datetime.timezone.utc
|
221
|
+
).timestamp()
|
222
|
+
* 1000
|
223
|
+
)
|
224
|
+
elif isinstance(val, (list, tuple)):
|
225
|
+
return [_json_helper(i, artifact) for i in val]
|
226
|
+
elif isinstance(val, Decimal):
|
227
|
+
return float(val)
|
228
|
+
else:
|
229
|
+
return util.json_friendly(val)[0]
|
wandb/sdk/data_types/video.py
CHANGED
@@ -34,7 +34,7 @@ def write_gif_with_image_io(
|
|
34
34
|
) -> None:
|
35
35
|
imageio = util.get_module(
|
36
36
|
"imageio",
|
37
|
-
required='wandb.Video requires imageio when passing raw data. Install with "pip install
|
37
|
+
required='wandb.Video requires imageio when passing raw data. Install with "pip install wandb[media]"',
|
38
38
|
)
|
39
39
|
|
40
40
|
writer = imageio.save(filename, fps=clip.fps, quantizer=0, palettesize=256, loop=0)
|
@@ -130,7 +130,7 @@ class Video(BatchableMedia):
|
|
130
130
|
def encode(self) -> None:
|
131
131
|
mpy = util.get_module(
|
132
132
|
"moviepy.editor",
|
133
|
-
required='wandb.Video requires moviepy
|
133
|
+
required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
|
134
134
|
)
|
135
135
|
tensor = self._prepare_video(self.data)
|
136
136
|
_, self._height, self._width, self._channels = tensor.shape # type: ignore
|
wandb/sdk/interface/interface.py
CHANGED
@@ -891,20 +891,6 @@ class InterfaceBase:
|
|
891
891
|
def _deliver_attach(self, status: pb.AttachRequest) -> MailboxHandle:
|
892
892
|
raise NotImplementedError
|
893
893
|
|
894
|
-
def deliver_check_version(
|
895
|
-
self, current_version: Optional[str] = None
|
896
|
-
) -> MailboxHandle:
|
897
|
-
check_version = pb.CheckVersionRequest()
|
898
|
-
if current_version:
|
899
|
-
check_version.current_version = current_version
|
900
|
-
return self._deliver_check_version(check_version)
|
901
|
-
|
902
|
-
@abstractmethod
|
903
|
-
def _deliver_check_version(
|
904
|
-
self, check_version: pb.CheckVersionRequest
|
905
|
-
) -> MailboxHandle:
|
906
|
-
raise NotImplementedError
|
907
|
-
|
908
894
|
def deliver_stop_status(self) -> MailboxHandle:
|
909
895
|
status = pb.StopStatusRequest()
|
910
896
|
return self._deliver_stop_status(status)
|
@@ -965,16 +951,6 @@ class InterfaceBase:
|
|
965
951
|
def _deliver_poll_exit(self, poll_exit: pb.PollExitRequest) -> MailboxHandle:
|
966
952
|
raise NotImplementedError
|
967
953
|
|
968
|
-
def deliver_request_server_info(self) -> MailboxHandle:
|
969
|
-
server_info = pb.ServerInfoRequest()
|
970
|
-
return self._deliver_request_server_info(server_info)
|
971
|
-
|
972
|
-
@abstractmethod
|
973
|
-
def _deliver_request_server_info(
|
974
|
-
self, server_info: pb.ServerInfoRequest
|
975
|
-
) -> MailboxHandle:
|
976
|
-
raise NotImplementedError
|
977
|
-
|
978
954
|
def deliver_request_sampled_history(self) -> MailboxHandle:
|
979
955
|
sampled_history = pb.SampledHistoryRequest()
|
980
956
|
return self._deliver_request_sampled_history(sampled_history)
|
@@ -490,12 +490,6 @@ class InterfaceShared(InterfaceBase):
|
|
490
490
|
record = self._make_request(attach=attach)
|
491
491
|
return self._deliver_record(record)
|
492
492
|
|
493
|
-
def _deliver_check_version(
|
494
|
-
self, check_version: pb.CheckVersionRequest
|
495
|
-
) -> MailboxHandle:
|
496
|
-
record = self._make_request(check_version=check_version)
|
497
|
-
return self._deliver_record(record)
|
498
|
-
|
499
493
|
def _deliver_network_status(
|
500
494
|
self, network_status: pb.NetworkStatusRequest
|
501
495
|
) -> MailboxHandle:
|
@@ -508,12 +502,6 @@ class InterfaceShared(InterfaceBase):
|
|
508
502
|
record = self._make_request(internal_messages=internal_message)
|
509
503
|
return self._deliver_record(record)
|
510
504
|
|
511
|
-
def _deliver_request_server_info(
|
512
|
-
self, server_info: pb.ServerInfoRequest
|
513
|
-
) -> MailboxHandle:
|
514
|
-
record = self._make_request(server_info=server_info)
|
515
|
-
return self._deliver_record(record)
|
516
|
-
|
517
505
|
def _deliver_request_sampled_history(
|
518
506
|
self, sampled_history: pb.SampledHistoryRequest
|
519
507
|
) -> MailboxHandle:
|
wandb/sdk/internal/handler.py
CHANGED
@@ -659,13 +659,6 @@ class HandleManager:
|
|
659
659
|
def handle_footer(self, record: Record) -> None:
|
660
660
|
self._dispatch_record(record)
|
661
661
|
|
662
|
-
def handle_request_check_version(self, record: Record) -> None:
|
663
|
-
if self._settings._offline:
|
664
|
-
result = proto_util._result_from_record(record)
|
665
|
-
self._respond_result(result)
|
666
|
-
else:
|
667
|
-
self._dispatch_record(record)
|
668
|
-
|
669
662
|
def handle_request_attach(self, record: Record) -> None:
|
670
663
|
result = proto_util._result_from_record(record)
|
671
664
|
attach_id = record.request.attach.attach_id
|
@@ -862,9 +855,6 @@ class HandleManager:
|
|
862
855
|
result.response.sampled_history_response.item.append(item)
|
863
856
|
self._respond_result(result)
|
864
857
|
|
865
|
-
def handle_request_server_info(self, record: Record) -> None:
|
866
|
-
self._dispatch_record(record, always_send=True)
|
867
|
-
|
868
858
|
def handle_request_keepalive(self, record: Record) -> None:
|
869
859
|
"""Handle a keepalive request.
|
870
860
|
|
@@ -53,6 +53,8 @@ from .progress import Progress
|
|
53
53
|
|
54
54
|
logger = logging.getLogger(__name__)
|
55
55
|
|
56
|
+
LAUNCH_DEFAULT_PROJECT = "model-registry"
|
57
|
+
|
56
58
|
if TYPE_CHECKING:
|
57
59
|
if sys.version_info >= (3, 8):
|
58
60
|
from typing import Literal, TypedDict
|
@@ -674,6 +676,11 @@ class Api:
|
|
674
676
|
self.server_create_run_queue_supports_priority,
|
675
677
|
)
|
676
678
|
|
679
|
+
@normalize_exceptions
|
680
|
+
def upsert_run_queue_introspection(self) -> bool:
|
681
|
+
_, _, mutations = self.server_info_introspection()
|
682
|
+
return "upsertRunQueue" in mutations
|
683
|
+
|
677
684
|
@normalize_exceptions
|
678
685
|
def push_to_run_queue_introspection(self) -> Tuple[bool, bool]:
|
679
686
|
query_string = """
|
@@ -1580,6 +1587,70 @@ class Api:
|
|
1580
1587
|
]
|
1581
1588
|
return result
|
1582
1589
|
|
1590
|
+
@normalize_exceptions
|
1591
|
+
def upsert_run_queue(
|
1592
|
+
self,
|
1593
|
+
queue_name: str,
|
1594
|
+
entity: str,
|
1595
|
+
resource_type: str,
|
1596
|
+
resource_config: dict,
|
1597
|
+
project: str = LAUNCH_DEFAULT_PROJECT,
|
1598
|
+
prioritization_mode: Optional[str] = None,
|
1599
|
+
template_variables: Optional[dict] = None,
|
1600
|
+
external_links: Optional[dict] = None,
|
1601
|
+
) -> Optional[Dict[str, Any]]:
|
1602
|
+
if not self.upsert_run_queue_introspection():
|
1603
|
+
raise UnsupportedError(
|
1604
|
+
"upserting run queues is not supported by this version of "
|
1605
|
+
"wandb server. Consider updating to the latest version."
|
1606
|
+
)
|
1607
|
+
query = gql(
|
1608
|
+
"""
|
1609
|
+
mutation upsertRunQueue(
|
1610
|
+
$entityName: String!
|
1611
|
+
$projectName: String!
|
1612
|
+
$queueName: String!
|
1613
|
+
$resourceType: String!
|
1614
|
+
$resourceConfig: JSONString!
|
1615
|
+
$templateVariables: JSONString
|
1616
|
+
$prioritizationMode: RunQueuePrioritizationMode
|
1617
|
+
$externalLinks: JSONString
|
1618
|
+
$clientMutationId: String
|
1619
|
+
) {
|
1620
|
+
upsertRunQueue(
|
1621
|
+
input: {
|
1622
|
+
entityName: $entityName
|
1623
|
+
projectName: $projectName
|
1624
|
+
queueName: $queueName
|
1625
|
+
resourceType: $resourceType
|
1626
|
+
resourceConfig: $resourceConfig
|
1627
|
+
templateVariables: $templateVariables
|
1628
|
+
prioritizationMode: $prioritizationMode
|
1629
|
+
externalLinks: $externalLinks
|
1630
|
+
clientMutationId: $clientMutationId
|
1631
|
+
}
|
1632
|
+
) {
|
1633
|
+
success
|
1634
|
+
configSchemaValidationErrors
|
1635
|
+
}
|
1636
|
+
}
|
1637
|
+
"""
|
1638
|
+
)
|
1639
|
+
variable_values = {
|
1640
|
+
"entityName": entity,
|
1641
|
+
"projectName": project,
|
1642
|
+
"queueName": queue_name,
|
1643
|
+
"resourceType": resource_type,
|
1644
|
+
"resourceConfig": json.dumps(resource_config),
|
1645
|
+
"templateVariables": (
|
1646
|
+
json.dumps(template_variables) if template_variables else None
|
1647
|
+
),
|
1648
|
+
"prioritizationMode": prioritization_mode,
|
1649
|
+
"externalLinks": json.dumps(external_links) if external_links else None,
|
1650
|
+
}
|
1651
|
+
result: Dict[str, Any] = self.gql(query, variable_values)
|
1652
|
+
return result["upsertRunQueue"]
|
1653
|
+
|
1583
1654
|
@normalize_exceptions
|
1584
1655
|
def push_to_run_queue_by_name(
|
1585
1656
|
self,
|
wandb/sdk/internal/sender.py
CHANGED
@@ -42,7 +42,6 @@ from wandb.sdk.internal import (
|
|
42
42
|
file_stream,
|
43
43
|
internal_api,
|
44
44
|
sender_config,
|
45
|
-
update,
|
46
45
|
)
|
47
46
|
from wandb.sdk.internal.file_pusher import FilePusher
|
48
47
|
from wandb.sdk.internal.job_builder import JobBuilder
|
@@ -51,7 +50,6 @@ from wandb.sdk.lib import (
|
|
51
50
|
config_util,
|
52
51
|
filenames,
|
53
52
|
filesystem,
|
54
|
-
printer,
|
55
53
|
proto_util,
|
56
54
|
redirect,
|
57
55
|
telemetry,
|
@@ -483,25 +481,6 @@ class SendManager:
|
|
483
481
|
# make sure that we always update writer for every sended read request
|
484
482
|
self._maybe_report_status(always=True)
|
485
483
|
|
486
|
-
def send_request_check_version(self, record: "Record") -> None:
|
487
|
-
assert record.control.req_resp or record.control.mailbox_slot
|
488
|
-
result = proto_util._result_from_record(record)
|
489
|
-
current_version = (
|
490
|
-
record.request.check_version.current_version or wandb.__version__
|
491
|
-
)
|
492
|
-
messages = update.check_available(current_version)
|
493
|
-
if messages:
|
494
|
-
upgrade_message = messages.get("upgrade_message")
|
495
|
-
if upgrade_message:
|
496
|
-
result.response.check_version_response.upgrade_message = upgrade_message
|
497
|
-
yank_message = messages.get("yank_message")
|
498
|
-
if yank_message:
|
499
|
-
result.response.check_version_response.yank_message = yank_message
|
500
|
-
delete_message = messages.get("delete_message")
|
501
|
-
if delete_message:
|
502
|
-
result.response.check_version_response.delete_message = delete_message
|
503
|
-
self._respond_result(result)
|
504
|
-
|
505
484
|
def send_request_stop_status(self, record: "Record") -> None:
|
506
485
|
result = proto_util._result_from_record(record)
|
507
486
|
status_resp = result.response.stop_status_response
|
@@ -724,28 +703,6 @@ class SendManager:
|
|
724
703
|
|
725
704
|
self._respond_result(result)
|
726
705
|
|
727
|
-
def send_request_server_info(self, record: "Record") -> None:
|
728
|
-
assert record.control.req_resp or record.control.mailbox_slot
|
729
|
-
result = proto_util._result_from_record(record)
|
730
|
-
|
731
|
-
result.response.server_info_response.local_info.CopyFrom(self.get_local_info())
|
732
|
-
for message in self._server_messages:
|
733
|
-
# guard against the case the message level returns malformed from server
|
734
|
-
message_level = str(message.get("messageLevel"))
|
735
|
-
message_level_sanitized = int(
|
736
|
-
printer.INFO if not message_level.isdigit() else message_level
|
737
|
-
)
|
738
|
-
result.response.server_info_response.server_messages.item.append(
|
739
|
-
wandb_internal_pb2.ServerMessage(
|
740
|
-
utf_text=message.get("utfText", ""),
|
741
|
-
plain_text=message.get("plainText", ""),
|
742
|
-
html_text=message.get("htmlText", ""),
|
743
|
-
type=message.get("messageType", ""),
|
744
|
-
level=message_level_sanitized,
|
745
|
-
)
|
746
|
-
)
|
747
|
-
self._respond_result(result)
|
748
|
-
|
749
706
|
def _setup_resume(
|
750
707
|
self, run: "RunRecord"
|
751
708
|
) -> Optional["wandb_internal_pb2.ErrorInfo"]:
|
wandb/sdk/internal/tb_watcher.py
CHANGED
@@ -12,9 +12,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
12
12
|
|
13
13
|
import wandb
|
14
14
|
from wandb import util
|
15
|
+
from wandb.plot.viz import CustomChart
|
15
16
|
from wandb.sdk.interface.interface import GlobStr
|
16
17
|
from wandb.sdk.lib import filesystem
|
17
|
-
from wandb.sdk.lib.viz import CustomChart
|
18
18
|
|
19
19
|
from . import run as internal_run
|
20
20
|
|
wandb/sdk/lib/hashutil.py
CHANGED
@@ -1,19 +1,22 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import base64
|
2
4
|
import hashlib
|
3
5
|
import mmap
|
4
|
-
import os
|
5
6
|
import sys
|
6
|
-
from
|
7
|
-
from typing import NewType, Union
|
7
|
+
from typing import TYPE_CHECKING, NewType
|
8
8
|
|
9
9
|
from wandb.sdk.lib.paths import StrPath
|
10
10
|
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
import _hashlib # type: ignore[import-not-found]
|
13
|
+
|
11
14
|
ETag = NewType("ETag", str)
|
12
15
|
HexMD5 = NewType("HexMD5", str)
|
13
16
|
B64MD5 = NewType("B64MD5", str)
|
14
17
|
|
15
18
|
|
16
|
-
def _md5(data: bytes = b"") ->
|
19
|
+
def _md5(data: bytes = b"") -> _hashlib.HASH:
|
17
20
|
"""Allow FIPS-compliant md5 hash when supported."""
|
18
21
|
if sys.version_info >= (3, 9):
|
19
22
|
return hashlib.md5(data, usedforsecurity=False)
|
@@ -25,7 +28,7 @@ def md5_string(string: str) -> B64MD5:
|
|
25
28
|
return _b64_from_hasher(_md5(string.encode("utf-8")))
|
26
29
|
|
27
30
|
|
28
|
-
def _b64_from_hasher(hasher:
|
31
|
+
def _b64_from_hasher(hasher: _hashlib.HASH) -> B64MD5:
|
29
32
|
return B64MD5(base64.b64encode(hasher.digest()).decode("ascii"))
|
30
33
|
|
31
34
|
|
@@ -33,7 +36,7 @@ def b64_to_hex_id(string: B64MD5) -> HexMD5:
|
|
33
36
|
return HexMD5(base64.standard_b64decode(string).hex())
|
34
37
|
|
35
38
|
|
36
|
-
def hex_to_b64_id(encoded_string:
|
39
|
+
def hex_to_b64_id(encoded_string: str | bytes) -> B64MD5:
|
37
40
|
if isinstance(encoded_string, bytes):
|
38
41
|
encoded_string = encoded_string.decode("utf-8")
|
39
42
|
as_str = bytes.fromhex(encoded_string)
|
@@ -48,15 +51,34 @@ def md5_file_hex(*paths: StrPath) -> HexMD5:
|
|
48
51
|
return HexMD5(_md5_file_hasher(*paths).hexdigest())
|
49
52
|
|
50
53
|
|
51
|
-
|
54
|
+
_KB: int = 1_024
|
55
|
+
_CHUNKSIZE: int = 128 * _KB
|
56
|
+
"""Chunk size (in bytes) for iteratively reading from file, if needed."""
|
57
|
+
|
58
|
+
|
59
|
+
def _md5_file_hasher(*paths: StrPath) -> _hashlib.HASH:
|
52
60
|
md5_hash = _md5()
|
53
61
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
else:
|
62
|
+
# Note: We use str paths (instead of pathlib.Path objs) for minor perf improvements.
|
63
|
+
for path in sorted(map(str, paths)):
|
64
|
+
with open(path, "rb") as f:
|
65
|
+
try:
|
59
66
|
with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mview:
|
60
67
|
md5_hash.update(mview)
|
68
|
+
except OSError:
|
69
|
+
# This occurs if the mmap-ed file is on a different/mounted filesystem,
|
70
|
+
# so we'll fall back on a less performant implementation.
|
71
|
+
|
72
|
+
# Note: At the time of implementation, the walrus operator `:=`
|
73
|
+
# is avoided to maintain support for users on python 3.7.
|
74
|
+
# Consider revisiting once 3.7 support is no longer needed.
|
75
|
+
chunk = f.read(_CHUNKSIZE)
|
76
|
+
while chunk:
|
77
|
+
md5_hash.update(chunk)
|
78
|
+
chunk = f.read(_CHUNKSIZE)
|
79
|
+
except ValueError:
|
80
|
+
# This occurs when mmap-ing an empty file, which can be skipped.
|
81
|
+
# See: https://github.com/python/cpython/blob/986a4e1b6fcae7fe7a1d0a26aea446107dd58dd2/Modules/mmapmodule.c#L1589
|
82
|
+
pass
|
61
83
|
|
62
84
|
return md5_hash
|
@@ -0,0 +1,216 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import atexit
|
4
|
+
import os
|
5
|
+
from typing import Callable
|
6
|
+
|
7
|
+
from wandb.proto import wandb_internal_pb2 as pb
|
8
|
+
from wandb.proto import wandb_server_pb2 as spb
|
9
|
+
from wandb.proto import wandb_settings_pb2
|
10
|
+
from wandb.sdk import wandb_settings
|
11
|
+
from wandb.sdk.interface.interface import InterfaceBase
|
12
|
+
from wandb.sdk.interface.interface_sock import InterfaceSock
|
13
|
+
from wandb.sdk.lib import service_token
|
14
|
+
from wandb.sdk.lib.exit_hooks import ExitHooks
|
15
|
+
from wandb.sdk.lib.mailbox import Mailbox
|
16
|
+
from wandb.sdk.lib.sock_client import SockClient, SockClientTimeoutError
|
17
|
+
from wandb.sdk.service import service
|
18
|
+
|
19
|
+
|
20
|
+
class WandbServiceNotOwnedError(Exception):
|
21
|
+
"""Raised when the current process does not own the service process."""
|
22
|
+
|
23
|
+
|
24
|
+
class WandbServiceConnectionError(Exception):
|
25
|
+
"""Raised on failure to connect to the service process."""
|
26
|
+
|
27
|
+
|
28
|
+
class WandbAttachFailedError(Exception):
|
29
|
+
"""Raised if attaching to a run fails."""
|
30
|
+
|
31
|
+
|
32
|
+
def connect_to_service(
|
33
|
+
settings: wandb_settings.Settings,
|
34
|
+
) -> ServiceConnection:
|
35
|
+
"""Connects to the service process, starting one up if necessary."""
|
36
|
+
conn = _try_connect_to_existing_service()
|
37
|
+
if conn:
|
38
|
+
return conn
|
39
|
+
|
40
|
+
return _start_and_connect_service(settings)
|
41
|
+
|
42
|
+
|
43
|
+
def _try_connect_to_existing_service() -> ServiceConnection | None:
|
44
|
+
"""Attemps to connect to an existing service process."""
|
45
|
+
token = service_token.get_service_token()
|
46
|
+
if not token:
|
47
|
+
return None
|
48
|
+
|
49
|
+
# Only localhost sockets are supported below.
|
50
|
+
assert token.host == "localhost"
|
51
|
+
client = SockClient()
|
52
|
+
|
53
|
+
try:
|
54
|
+
# TODO: This may block indefinitely if the service is unhealthy.
|
55
|
+
client.connect(token.port)
|
56
|
+
|
57
|
+
except Exception as e:
|
58
|
+
raise WandbServiceConnectionError(
|
59
|
+
"Failed to connect to internal service."
|
60
|
+
) from e
|
61
|
+
|
62
|
+
return ServiceConnection(client=client, proc=None)
|
63
|
+
|
64
|
+
|
65
|
+
def _start_and_connect_service(
|
66
|
+
settings: wandb_settings.Settings,
|
67
|
+
) -> ServiceConnection:
|
68
|
+
"""Starts a service process and returns a connection to it.
|
69
|
+
|
70
|
+
An atexit hook is registered to tear down the service process and wait for
|
71
|
+
it to complete. The hook does not run in processes started using the
|
72
|
+
multiprocessing module.
|
73
|
+
"""
|
74
|
+
proc = service._Service(settings)
|
75
|
+
proc.start()
|
76
|
+
|
77
|
+
port = proc.sock_port
|
78
|
+
assert port
|
79
|
+
client = SockClient()
|
80
|
+
client.connect(port)
|
81
|
+
|
82
|
+
service_token.set_service_token(
|
83
|
+
parent_pid=os.getpid(),
|
84
|
+
transport="tcp",
|
85
|
+
host="localhost",
|
86
|
+
port=port,
|
87
|
+
)
|
88
|
+
|
89
|
+
hooks = ExitHooks()
|
90
|
+
hooks.hook()
|
91
|
+
|
92
|
+
def teardown_atexit():
|
93
|
+
conn.teardown(hooks.exit_code)
|
94
|
+
|
95
|
+
conn = ServiceConnection(
|
96
|
+
client=client,
|
97
|
+
proc=proc,
|
98
|
+
cleanup=lambda: atexit.unregister(teardown_atexit),
|
99
|
+
)
|
100
|
+
|
101
|
+
atexit.register(teardown_atexit)
|
102
|
+
|
103
|
+
return conn
|
104
|
+
|
105
|
+
|
106
|
+
class ServiceConnection:
|
107
|
+
"""A connection to the W&B internal service process."""
|
108
|
+
|
109
|
+
def __init__(
|
110
|
+
self,
|
111
|
+
client: SockClient,
|
112
|
+
proc: service._Service | None,
|
113
|
+
cleanup: Callable[[], None] | None = None,
|
114
|
+
):
|
115
|
+
"""Returns a new ServiceConnection.
|
116
|
+
|
117
|
+
Args:
|
118
|
+
client: A socket that's connected to the service.
|
119
|
+
proc: The service process if we own it, or None otherwise.
|
120
|
+
cleanup: A callback to run on teardown before doing anything.
|
121
|
+
"""
|
122
|
+
self._client = client
|
123
|
+
self._proc = proc
|
124
|
+
self._torn_down = False
|
125
|
+
self._cleanup = cleanup
|
126
|
+
|
127
|
+
def make_interface(self, mailbox: Mailbox) -> InterfaceBase:
|
128
|
+
"""Returns an interface for communicating with the service."""
|
129
|
+
return InterfaceSock(self._client, mailbox)
|
130
|
+
|
131
|
+
def send_record(self, record: pb.Record) -> None:
|
132
|
+
"""Sends data to the service."""
|
133
|
+
self._client.send_record_publish(record)
|
134
|
+
|
135
|
+
def inform_init(
|
136
|
+
self,
|
137
|
+
settings: wandb_settings_pb2.Settings,
|
138
|
+
run_id: str,
|
139
|
+
) -> None:
|
140
|
+
"""Sends an init request to the service."""
|
141
|
+
request = spb.ServerInformInitRequest()
|
142
|
+
request.settings.CopyFrom(settings)
|
143
|
+
request._info.stream_id = run_id
|
144
|
+
self._client.send(inform_init=request)
|
145
|
+
|
146
|
+
def inform_finish(self, run_id: str) -> None:
|
147
|
+
"""Sends an finish request to the service."""
|
148
|
+
request = spb.ServerInformFinishRequest()
|
149
|
+
request._info.stream_id = run_id
|
150
|
+
self._client.send(inform_finish=request)
|
151
|
+
|
152
|
+
def inform_attach(
|
153
|
+
self,
|
154
|
+
attach_id: str,
|
155
|
+
) -> wandb_settings_pb2.Settings:
|
156
|
+
"""Sends an attach request to the service.
|
157
|
+
|
158
|
+
Raises a WandbAttachFailedError if attaching is not possible.
|
159
|
+
"""
|
160
|
+
request = spb.ServerInformAttachRequest()
|
161
|
+
request._info.stream_id = attach_id
|
162
|
+
|
163
|
+
try:
|
164
|
+
response = self._client.send_and_recv(inform_attach=request)
|
165
|
+
return response.inform_attach_response.settings
|
166
|
+
except SockClientTimeoutError:
|
167
|
+
raise WandbAttachFailedError(
|
168
|
+
"Could not attach because the run does not belong to"
|
169
|
+
" the current service process, or because the service"
|
170
|
+
" process is busy (unlikely)."
|
171
|
+
)
|
172
|
+
|
173
|
+
def inform_start(
|
174
|
+
self,
|
175
|
+
settings: wandb_settings_pb2.Settings,
|
176
|
+
run_id: str,
|
177
|
+
) -> None:
|
178
|
+
"""Sends a start request to the service."""
|
179
|
+
request = spb.ServerInformStartRequest()
|
180
|
+
request.settings.CopyFrom(settings)
|
181
|
+
request._info.stream_id = run_id
|
182
|
+
self._client.send(inform_start=request)
|
183
|
+
|
184
|
+
def teardown(self, exit_code: int) -> int:
|
185
|
+
"""Shuts down the service process and returns its exit code.
|
186
|
+
|
187
|
+
This may only be called once.
|
188
|
+
|
189
|
+
Returns:
|
190
|
+
The exit code of the service process.
|
191
|
+
|
192
|
+
Raises:
|
193
|
+
WandbServiceNotOwnedError: If the current process did not start
|
194
|
+
the service process.
|
195
|
+
"""
|
196
|
+
if not self._proc:
|
197
|
+
raise WandbServiceNotOwnedError(
|
198
|
+
"Cannot tear down service started by different process",
|
199
|
+
)
|
200
|
+
|
201
|
+
assert not self._torn_down
|
202
|
+
self._torn_down = True
|
203
|
+
|
204
|
+
if self._cleanup:
|
205
|
+
self._cleanup()
|
206
|
+
|
207
|
+
# Clear the service token to prevent new connections from being made.
|
208
|
+
service_token.clear_service_token()
|
209
|
+
|
210
|
+
self._client.send(
|
211
|
+
inform_teardown=spb.ServerInformTeardownRequest(
|
212
|
+
exit_code=exit_code,
|
213
|
+
)
|
214
|
+
)
|
215
|
+
|
216
|
+
return self._proc.join()
|