wandb 0.18.0rc1__py3-none-win_amd64.whl → 0.18.2__py3-none-win_amd64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (119) hide show
  1. wandb/__init__.py +4 -4
  2. wandb/__init__.pyi +67 -12
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public/api.py +128 -2
  5. wandb/apis/public/artifacts.py +11 -7
  6. wandb/apis/public/jobs.py +8 -0
  7. wandb/apis/public/runs.py +18 -5
  8. wandb/bin/wandb-core +0 -0
  9. wandb/cli/cli.py +0 -5
  10. wandb/data_types.py +9 -2019
  11. wandb/env.py +0 -5
  12. wandb/errors/__init__.py +11 -40
  13. wandb/errors/errors.py +37 -0
  14. wandb/errors/warnings.py +2 -0
  15. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  16. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  17. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  18. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  19. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  20. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  21. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  22. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  23. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  24. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  25. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  26. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  27. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  28. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  29. wandb/integration/tensorboard/log.py +1 -1
  30. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  31. wandb/old/core.py +2 -80
  32. wandb/plot/bar.py +7 -4
  33. wandb/plot/confusion_matrix.py +5 -4
  34. wandb/plot/histogram.py +7 -4
  35. wandb/plot/line.py +7 -4
  36. wandb/proto/v3/wandb_base_pb2.py +2 -1
  37. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  38. wandb/proto/v3/wandb_server_pb2.py +2 -1
  39. wandb/proto/v3/wandb_settings_pb2.py +3 -2
  40. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  41. wandb/proto/v4/wandb_base_pb2.py +2 -1
  42. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  43. wandb/proto/v4/wandb_server_pb2.py +2 -1
  44. wandb/proto/v4/wandb_settings_pb2.py +3 -2
  45. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  46. wandb/proto/v5/wandb_base_pb2.py +3 -2
  47. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  48. wandb/proto/v5/wandb_server_pb2.py +3 -2
  49. wandb/proto/v5/wandb_settings_pb2.py +4 -3
  50. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  51. wandb/sdk/artifacts/_validators.py +48 -3
  52. wandb/sdk/artifacts/artifact.py +157 -183
  53. wandb/sdk/artifacts/artifact_file_cache.py +13 -11
  54. wandb/sdk/artifacts/artifact_instance_cache.py +4 -2
  55. wandb/sdk/artifacts/artifact_manifest.py +13 -11
  56. wandb/sdk/artifacts/artifact_manifest_entry.py +24 -22
  57. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +9 -7
  58. wandb/sdk/artifacts/artifact_saver.py +27 -25
  59. wandb/sdk/artifacts/exceptions.py +26 -25
  60. wandb/sdk/artifacts/storage_handler.py +11 -9
  61. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -14
  62. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +15 -13
  63. wandb/sdk/artifacts/storage_handlers/http_handler.py +15 -14
  64. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -8
  65. wandb/sdk/artifacts/storage_handlers/multi_handler.py +14 -12
  66. wandb/sdk/artifacts/storage_handlers/s3_handler.py +19 -19
  67. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +10 -8
  68. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +12 -10
  69. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +9 -7
  70. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +31 -29
  71. wandb/sdk/artifacts/storage_policy.py +20 -20
  72. wandb/sdk/backend/backend.py +8 -26
  73. wandb/sdk/data_types/audio.py +165 -0
  74. wandb/sdk/data_types/base_types/wb_value.py +1 -3
  75. wandb/sdk/data_types/bokeh.py +70 -0
  76. wandb/sdk/data_types/graph.py +405 -0
  77. wandb/sdk/data_types/image.py +156 -0
  78. wandb/sdk/data_types/table.py +1204 -0
  79. wandb/sdk/data_types/trace_tree.py +2 -2
  80. wandb/sdk/data_types/utils.py +49 -0
  81. wandb/sdk/data_types/video.py +2 -2
  82. wandb/sdk/interface/interface.py +0 -24
  83. wandb/sdk/interface/interface_shared.py +0 -12
  84. wandb/sdk/internal/handler.py +0 -10
  85. wandb/sdk/internal/internal_api.py +71 -0
  86. wandb/sdk/internal/sender.py +0 -43
  87. wandb/sdk/internal/tb_watcher.py +1 -1
  88. wandb/sdk/lib/_settings_toposort_generated.py +1 -0
  89. wandb/sdk/lib/hashutil.py +34 -12
  90. wandb/sdk/lib/service_connection.py +216 -0
  91. wandb/sdk/lib/service_token.py +94 -0
  92. wandb/sdk/lib/sock_client.py +7 -3
  93. wandb/sdk/service/server.py +2 -5
  94. wandb/sdk/service/service.py +2 -31
  95. wandb/sdk/service/streams.py +0 -7
  96. wandb/sdk/wandb_init.py +42 -25
  97. wandb/sdk/wandb_run.py +18 -159
  98. wandb/sdk/wandb_settings.py +2 -0
  99. wandb/sdk/wandb_setup.py +25 -16
  100. wandb/sdk/wandb_sync.py +9 -3
  101. wandb/sdk/wandb_watch.py +31 -15
  102. wandb/sklearn.py +35 -0
  103. wandb/util.py +14 -3
  104. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/METADATA +6 -5
  105. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/RECORD +114 -110
  106. wandb/sdk/internal/update.py +0 -113
  107. wandb/sdk/lib/console.py +0 -39
  108. wandb/sdk/service/service_base.py +0 -50
  109. wandb/sdk/service/service_sock.py +0 -70
  110. wandb/sdk/wandb_manager.py +0 -232
  111. /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  112. /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  113. /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  114. /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  115. /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  116. /wandb/{sdk/lib → plot}/viz.py +0 -0
  117. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/WHEEL +0 -0
  118. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/entry_points.txt +0 -0
  119. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/licenses/LICENSE +0 -0
@@ -14,9 +14,9 @@ from enum import Enum
14
14
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
15
15
 
16
16
  import wandb
17
- import wandb.data_types
18
17
  from wandb.sdk.data_types import _dtypes
19
18
  from wandb.sdk.data_types.base_types.media import Media
19
+ from wandb.sdk.data_types.utils import _json_helper
20
20
 
21
21
  if TYPE_CHECKING: # pragma: no cover
22
22
  from wandb.sdk.artifacts.artifact import Artifact
@@ -142,7 +142,7 @@ def _fallback_serialize(obj: Any) -> str:
142
142
  def _safe_serialize(obj: dict) -> str:
143
143
  try:
144
144
  return json.dumps(
145
- wandb.data_types._json_helper(obj, None),
145
+ _json_helper(obj, None),
146
146
  skipkeys=True,
147
147
  default=_fallback_serialize,
148
148
  )
@@ -1,6 +1,8 @@
1
+ import datetime
1
2
  import logging
2
3
  import os
3
4
  import re
5
+ from decimal import Decimal
4
6
  from typing import TYPE_CHECKING, Optional, Sequence, Union, cast
5
7
 
6
8
  import wandb
@@ -178,3 +180,50 @@ def _prune_max_seq(seq: Sequence["BatchableMedia"]) -> Sequence["BatchableMedia"
178
180
  )
179
181
  items = seq[: seq[0].MAX_ITEMS]
180
182
  return items
183
+
184
+
185
+ def _json_helper(val, artifact):
186
+ if isinstance(val, WBValue):
187
+ return val.to_json(artifact)
188
+ elif val.__class__ is dict:
189
+ res = {}
190
+ for key in val:
191
+ res[key] = _json_helper(val[key], artifact)
192
+ return res
193
+
194
+ if hasattr(val, "tolist"):
195
+ py_val = val.tolist()
196
+ if val.__class__.__name__ == "datetime64" and isinstance(py_val, int):
197
+ # when numpy datetime64 .tolist() returns an int, it is nanoseconds.
198
+ # need to convert to milliseconds
199
+ return _json_helper(py_val / int(1e6), artifact)
200
+ return _json_helper(py_val, artifact)
201
+ elif hasattr(val, "item"):
202
+ return _json_helper(val.item(), artifact)
203
+
204
+ if isinstance(val, datetime.datetime):
205
+ if val.tzinfo is None:
206
+ val = datetime.datetime(
207
+ val.year,
208
+ val.month,
209
+ val.day,
210
+ val.hour,
211
+ val.minute,
212
+ val.second,
213
+ val.microsecond,
214
+ tzinfo=datetime.timezone.utc,
215
+ )
216
+ return int(val.timestamp() * 1000)
217
+ elif isinstance(val, datetime.date):
218
+ return int(
219
+ datetime.datetime(
220
+ val.year, val.month, val.day, tzinfo=datetime.timezone.utc
221
+ ).timestamp()
222
+ * 1000
223
+ )
224
+ elif isinstance(val, (list, tuple)):
225
+ return [_json_helper(i, artifact) for i in val]
226
+ elif isinstance(val, Decimal):
227
+ return float(val)
228
+ else:
229
+ return util.json_friendly(val)[0]
@@ -34,7 +34,7 @@ def write_gif_with_image_io(
34
34
  ) -> None:
35
35
  imageio = util.get_module(
36
36
  "imageio",
37
- required='wandb.Video requires imageio when passing raw data. Install with "pip install imageio"',
37
+ required='wandb.Video requires imageio when passing raw data. Install with "pip install wandb[media]"',
38
38
  )
39
39
 
40
40
  writer = imageio.save(filename, fps=clip.fps, quantizer=0, palettesize=256, loop=0)
@@ -130,7 +130,7 @@ class Video(BatchableMedia):
130
130
  def encode(self) -> None:
131
131
  mpy = util.get_module(
132
132
  "moviepy.editor",
133
- required='wandb.Video requires moviepy and imageio when passing raw data. Install with "pip install moviepy imageio"',
133
+ required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
134
134
  )
135
135
  tensor = self._prepare_video(self.data)
136
136
  _, self._height, self._width, self._channels = tensor.shape # type: ignore
@@ -891,20 +891,6 @@ class InterfaceBase:
891
891
  def _deliver_attach(self, status: pb.AttachRequest) -> MailboxHandle:
892
892
  raise NotImplementedError
893
893
 
894
- def deliver_check_version(
895
- self, current_version: Optional[str] = None
896
- ) -> MailboxHandle:
897
- check_version = pb.CheckVersionRequest()
898
- if current_version:
899
- check_version.current_version = current_version
900
- return self._deliver_check_version(check_version)
901
-
902
- @abstractmethod
903
- def _deliver_check_version(
904
- self, check_version: pb.CheckVersionRequest
905
- ) -> MailboxHandle:
906
- raise NotImplementedError
907
-
908
894
  def deliver_stop_status(self) -> MailboxHandle:
909
895
  status = pb.StopStatusRequest()
910
896
  return self._deliver_stop_status(status)
@@ -965,16 +951,6 @@ class InterfaceBase:
965
951
  def _deliver_poll_exit(self, poll_exit: pb.PollExitRequest) -> MailboxHandle:
966
952
  raise NotImplementedError
967
953
 
968
- def deliver_request_server_info(self) -> MailboxHandle:
969
- server_info = pb.ServerInfoRequest()
970
- return self._deliver_request_server_info(server_info)
971
-
972
- @abstractmethod
973
- def _deliver_request_server_info(
974
- self, server_info: pb.ServerInfoRequest
975
- ) -> MailboxHandle:
976
- raise NotImplementedError
977
-
978
954
  def deliver_request_sampled_history(self) -> MailboxHandle:
979
955
  sampled_history = pb.SampledHistoryRequest()
980
956
  return self._deliver_request_sampled_history(sampled_history)
@@ -490,12 +490,6 @@ class InterfaceShared(InterfaceBase):
490
490
  record = self._make_request(attach=attach)
491
491
  return self._deliver_record(record)
492
492
 
493
- def _deliver_check_version(
494
- self, check_version: pb.CheckVersionRequest
495
- ) -> MailboxHandle:
496
- record = self._make_request(check_version=check_version)
497
- return self._deliver_record(record)
498
-
499
493
  def _deliver_network_status(
500
494
  self, network_status: pb.NetworkStatusRequest
501
495
  ) -> MailboxHandle:
@@ -508,12 +502,6 @@ class InterfaceShared(InterfaceBase):
508
502
  record = self._make_request(internal_messages=internal_message)
509
503
  return self._deliver_record(record)
510
504
 
511
- def _deliver_request_server_info(
512
- self, server_info: pb.ServerInfoRequest
513
- ) -> MailboxHandle:
514
- record = self._make_request(server_info=server_info)
515
- return self._deliver_record(record)
516
-
517
505
  def _deliver_request_sampled_history(
518
506
  self, sampled_history: pb.SampledHistoryRequest
519
507
  ) -> MailboxHandle:
@@ -659,13 +659,6 @@ class HandleManager:
659
659
  def handle_footer(self, record: Record) -> None:
660
660
  self._dispatch_record(record)
661
661
 
662
- def handle_request_check_version(self, record: Record) -> None:
663
- if self._settings._offline:
664
- result = proto_util._result_from_record(record)
665
- self._respond_result(result)
666
- else:
667
- self._dispatch_record(record)
668
-
669
662
  def handle_request_attach(self, record: Record) -> None:
670
663
  result = proto_util._result_from_record(record)
671
664
  attach_id = record.request.attach.attach_id
@@ -862,9 +855,6 @@ class HandleManager:
862
855
  result.response.sampled_history_response.item.append(item)
863
856
  self._respond_result(result)
864
857
 
865
- def handle_request_server_info(self, record: Record) -> None:
866
- self._dispatch_record(record, always_send=True)
867
-
868
858
  def handle_request_keepalive(self, record: Record) -> None:
869
859
  """Handle a keepalive request.
870
860
 
@@ -53,6 +53,8 @@ from .progress import Progress
53
53
 
54
54
  logger = logging.getLogger(__name__)
55
55
 
56
+ LAUNCH_DEFAULT_PROJECT = "model-registry"
57
+
56
58
  if TYPE_CHECKING:
57
59
  if sys.version_info >= (3, 8):
58
60
  from typing import Literal, TypedDict
@@ -674,6 +676,11 @@ class Api:
674
676
  self.server_create_run_queue_supports_priority,
675
677
  )
676
678
 
679
+ @normalize_exceptions
680
+ def upsert_run_queue_introspection(self) -> bool:
681
+ _, _, mutations = self.server_info_introspection()
682
+ return "upsertRunQueue" in mutations
683
+
677
684
  @normalize_exceptions
678
685
  def push_to_run_queue_introspection(self) -> Tuple[bool, bool]:
679
686
  query_string = """
@@ -1580,6 +1587,70 @@ class Api:
1580
1587
  ]
1581
1588
  return result
1582
1589
 
1590
+ @normalize_exceptions
1591
+ def upsert_run_queue(
1592
+ self,
1593
+ queue_name: str,
1594
+ entity: str,
1595
+ resource_type: str,
1596
+ resource_config: dict,
1597
+ project: str = LAUNCH_DEFAULT_PROJECT,
1598
+ prioritization_mode: Optional[str] = None,
1599
+ template_variables: Optional[dict] = None,
1600
+ external_links: Optional[dict] = None,
1601
+ ) -> Optional[Dict[str, Any]]:
1602
+ if not self.upsert_run_queue_introspection():
1603
+ raise UnsupportedError(
1604
+ "upserting run queues is not supported by this version of "
1605
+ "wandb server. Consider updating to the latest version."
1606
+ )
1607
+ query = gql(
1608
+ """
1609
+ mutation upsertRunQueue(
1610
+ $entityName: String!
1611
+ $projectName: String!
1612
+ $queueName: String!
1613
+ $resourceType: String!
1614
+ $resourceConfig: JSONString!
1615
+ $templateVariables: JSONString
1616
+ $prioritizationMode: RunQueuePrioritizationMode
1617
+ $externalLinks: JSONString
1618
+ $clientMutationId: String
1619
+ ) {
1620
+ upsertRunQueue(
1621
+ input: {
1622
+ entityName: $entityName
1623
+ projectName: $projectName
1624
+ queueName: $queueName
1625
+ resourceType: $resourceType
1626
+ resourceConfig: $resourceConfig
1627
+ templateVariables: $templateVariables
1628
+ prioritizationMode: $prioritizationMode
1629
+ externalLinks: $externalLinks
1630
+ clientMutationId: $clientMutationId
1631
+ }
1632
+ ) {
1633
+ success
1634
+ configSchemaValidationErrors
1635
+ }
1636
+ }
1637
+ """
1638
+ )
1639
+ variable_values = {
1640
+ "entityName": entity,
1641
+ "projectName": project,
1642
+ "queueName": queue_name,
1643
+ "resourceType": resource_type,
1644
+ "resourceConfig": json.dumps(resource_config),
1645
+ "templateVariables": (
1646
+ json.dumps(template_variables) if template_variables else None
1647
+ ),
1648
+ "prioritizationMode": prioritization_mode,
1649
+ "externalLinks": json.dumps(external_links) if external_links else None,
1650
+ }
1651
+ result: Dict[str, Any] = self.gql(query, variable_values)
1652
+ return result["upsertRunQueue"]
1653
+
1583
1654
  @normalize_exceptions
1584
1655
  def push_to_run_queue_by_name(
1585
1656
  self,
@@ -42,7 +42,6 @@ from wandb.sdk.internal import (
42
42
  file_stream,
43
43
  internal_api,
44
44
  sender_config,
45
- update,
46
45
  )
47
46
  from wandb.sdk.internal.file_pusher import FilePusher
48
47
  from wandb.sdk.internal.job_builder import JobBuilder
@@ -51,7 +50,6 @@ from wandb.sdk.lib import (
51
50
  config_util,
52
51
  filenames,
53
52
  filesystem,
54
- printer,
55
53
  proto_util,
56
54
  redirect,
57
55
  telemetry,
@@ -483,25 +481,6 @@ class SendManager:
483
481
  # make sure that we always update writer for every sent read request
484
482
  self._maybe_report_status(always=True)
485
483
 
486
- def send_request_check_version(self, record: "Record") -> None:
487
- assert record.control.req_resp or record.control.mailbox_slot
488
- result = proto_util._result_from_record(record)
489
- current_version = (
490
- record.request.check_version.current_version or wandb.__version__
491
- )
492
- messages = update.check_available(current_version)
493
- if messages:
494
- upgrade_message = messages.get("upgrade_message")
495
- if upgrade_message:
496
- result.response.check_version_response.upgrade_message = upgrade_message
497
- yank_message = messages.get("yank_message")
498
- if yank_message:
499
- result.response.check_version_response.yank_message = yank_message
500
- delete_message = messages.get("delete_message")
501
- if delete_message:
502
- result.response.check_version_response.delete_message = delete_message
503
- self._respond_result(result)
504
-
505
484
  def send_request_stop_status(self, record: "Record") -> None:
506
485
  result = proto_util._result_from_record(record)
507
486
  status_resp = result.response.stop_status_response
@@ -724,28 +703,6 @@ class SendManager:
724
703
 
725
704
  self._respond_result(result)
726
705
 
727
- def send_request_server_info(self, record: "Record") -> None:
728
- assert record.control.req_resp or record.control.mailbox_slot
729
- result = proto_util._result_from_record(record)
730
-
731
- result.response.server_info_response.local_info.CopyFrom(self.get_local_info())
732
- for message in self._server_messages:
733
- # guard against the case the message level returns malformed from server
734
- message_level = str(message.get("messageLevel"))
735
- message_level_sanitized = int(
736
- printer.INFO if not message_level.isdigit() else message_level
737
- )
738
- result.response.server_info_response.server_messages.item.append(
739
- wandb_internal_pb2.ServerMessage(
740
- utf_text=message.get("utfText", ""),
741
- plain_text=message.get("plainText", ""),
742
- html_text=message.get("htmlText", ""),
743
- type=message.get("messageType", ""),
744
- level=message_level_sanitized,
745
- )
746
- )
747
- self._respond_result(result)
748
-
749
706
  def _setup_resume(
750
707
  self, run: "RunRecord"
751
708
  ) -> Optional["wandb_internal_pb2.ErrorInfo"]:
@@ -12,9 +12,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
12
12
 
13
13
  import wandb
14
14
  from wandb import util
15
+ from wandb.plot.viz import CustomChart
15
16
  from wandb.sdk.interface.interface import GlobStr
16
17
  from wandb.sdk.lib import filesystem
17
- from wandb.sdk.lib.viz import CustomChart
18
18
 
19
19
  from . import run as internal_run
20
20
 
@@ -27,6 +27,7 @@ _Setting = Literal[
27
27
  "_executable",
28
28
  "_extra_http_headers",
29
29
  "_file_stream_max_bytes",
30
+ "_file_stream_transmit_interval",
30
31
  "_file_stream_retry_max",
31
32
  "_file_stream_retry_wait_min_seconds",
32
33
  "_file_stream_retry_wait_max_seconds",
wandb/sdk/lib/hashutil.py CHANGED
@@ -1,19 +1,22 @@
1
+ from __future__ import annotations
2
+
1
3
  import base64
2
4
  import hashlib
3
5
  import mmap
4
- import os
5
6
  import sys
6
- from pathlib import Path
7
- from typing import NewType, Union
7
+ from typing import TYPE_CHECKING, NewType
8
8
 
9
9
  from wandb.sdk.lib.paths import StrPath
10
10
 
11
+ if TYPE_CHECKING:
12
+ import _hashlib # type: ignore[import-not-found]
13
+
11
14
  ETag = NewType("ETag", str)
12
15
  HexMD5 = NewType("HexMD5", str)
13
16
  B64MD5 = NewType("B64MD5", str)
14
17
 
15
18
 
16
- def _md5(data: bytes = b"") -> "hashlib._Hash":
19
+ def _md5(data: bytes = b"") -> _hashlib.HASH:
17
20
  """Allow FIPS-compliant md5 hash when supported."""
18
21
  if sys.version_info >= (3, 9):
19
22
  return hashlib.md5(data, usedforsecurity=False)
@@ -25,7 +28,7 @@ def md5_string(string: str) -> B64MD5:
25
28
  return _b64_from_hasher(_md5(string.encode("utf-8")))
26
29
 
27
30
 
28
- def _b64_from_hasher(hasher: "hashlib._Hash") -> B64MD5:
31
+ def _b64_from_hasher(hasher: _hashlib.HASH) -> B64MD5:
29
32
  return B64MD5(base64.b64encode(hasher.digest()).decode("ascii"))
30
33
 
31
34
 
@@ -33,7 +36,7 @@ def b64_to_hex_id(string: B64MD5) -> HexMD5:
33
36
  return HexMD5(base64.standard_b64decode(string).hex())
34
37
 
35
38
 
36
- def hex_to_b64_id(encoded_string: Union[str, bytes]) -> B64MD5:
39
+ def hex_to_b64_id(encoded_string: str | bytes) -> B64MD5:
37
40
  if isinstance(encoded_string, bytes):
38
41
  encoded_string = encoded_string.decode("utf-8")
39
42
  as_str = bytes.fromhex(encoded_string)
@@ -48,15 +51,34 @@ def md5_file_hex(*paths: StrPath) -> HexMD5:
48
51
  return HexMD5(_md5_file_hasher(*paths).hexdigest())
49
52
 
50
53
 
51
- def _md5_file_hasher(*paths: StrPath) -> "hashlib._Hash":
54
+ _KB: int = 1_024
55
+ _CHUNKSIZE: int = 128 * _KB
56
+ """Chunk size (in bytes) for iteratively reading from file, if needed."""
57
+
58
+
59
+ def _md5_file_hasher(*paths: StrPath) -> _hashlib.HASH:
52
60
  md5_hash = _md5()
53
61
 
54
- for path in sorted(Path(p) for p in paths):
55
- with path.open("rb") as f:
56
- if os.stat(f.fileno()).st_size <= 1024 * 1024:
57
- md5_hash.update(f.read())
58
- else:
62
+ # Note: We use str paths (instead of pathlib.Path objs) for minor perf improvements.
63
+ for path in sorted(map(str, paths)):
64
+ with open(path, "rb") as f:
65
+ try:
59
66
  with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mview:
60
67
  md5_hash.update(mview)
68
+ except OSError:
69
+ # This occurs if the mmap-ed file is on a different/mounted filesystem,
70
+ # so we'll fall back on a less performant implementation.
71
+
72
+ # Note: At the time of implementation, the walrus operator `:=`
73
+ # is avoided to maintain support for users on python 3.7.
74
+ # Consider revisiting once 3.7 support is no longer needed.
75
+ chunk = f.read(_CHUNKSIZE)
76
+ while chunk:
77
+ md5_hash.update(chunk)
78
+ chunk = f.read(_CHUNKSIZE)
79
+ except ValueError:
80
+ # This occurs when mmap-ing an empty file, which can be skipped.
81
+ # See: https://github.com/python/cpython/blob/986a4e1b6fcae7fe7a1d0a26aea446107dd58dd2/Modules/mmapmodule.c#L1589
82
+ pass
61
83
 
62
84
  return md5_hash
@@ -0,0 +1,216 @@
1
+ from __future__ import annotations
2
+
3
+ import atexit
4
+ import os
5
+ from typing import Callable
6
+
7
+ from wandb.proto import wandb_internal_pb2 as pb
8
+ from wandb.proto import wandb_server_pb2 as spb
9
+ from wandb.proto import wandb_settings_pb2
10
+ from wandb.sdk import wandb_settings
11
+ from wandb.sdk.interface.interface import InterfaceBase
12
+ from wandb.sdk.interface.interface_sock import InterfaceSock
13
+ from wandb.sdk.lib import service_token
14
+ from wandb.sdk.lib.exit_hooks import ExitHooks
15
+ from wandb.sdk.lib.mailbox import Mailbox
16
+ from wandb.sdk.lib.sock_client import SockClient, SockClientTimeoutError
17
+ from wandb.sdk.service import service
18
+
19
+
20
+ class WandbServiceNotOwnedError(Exception):
21
+ """Raised when the current process does not own the service process."""
22
+
23
+
24
+ class WandbServiceConnectionError(Exception):
25
+ """Raised on failure to connect to the service process."""
26
+
27
+
28
+ class WandbAttachFailedError(Exception):
29
+ """Raised if attaching to a run fails."""
30
+
31
+
32
+ def connect_to_service(
33
+ settings: wandb_settings.Settings,
34
+ ) -> ServiceConnection:
35
+ """Connects to the service process, starting one up if necessary."""
36
+ conn = _try_connect_to_existing_service()
37
+ if conn:
38
+ return conn
39
+
40
+ return _start_and_connect_service(settings)
41
+
42
+
43
+ def _try_connect_to_existing_service() -> ServiceConnection | None:
44
+ """Attempts to connect to an existing service process."""
45
+ token = service_token.get_service_token()
46
+ if not token:
47
+ return None
48
+
49
+ # Only localhost sockets are supported below.
50
+ assert token.host == "localhost"
51
+ client = SockClient()
52
+
53
+ try:
54
+ # TODO: This may block indefinitely if the service is unhealthy.
55
+ client.connect(token.port)
56
+
57
+ except Exception as e:
58
+ raise WandbServiceConnectionError(
59
+ "Failed to connect to internal service."
60
+ ) from e
61
+
62
+ return ServiceConnection(client=client, proc=None)
63
+
64
+
65
+ def _start_and_connect_service(
66
+ settings: wandb_settings.Settings,
67
+ ) -> ServiceConnection:
68
+ """Starts a service process and returns a connection to it.
69
+
70
+ An atexit hook is registered to tear down the service process and wait for
71
+ it to complete. The hook does not run in processes started using the
72
+ multiprocessing module.
73
+ """
74
+ proc = service._Service(settings)
75
+ proc.start()
76
+
77
+ port = proc.sock_port
78
+ assert port
79
+ client = SockClient()
80
+ client.connect(port)
81
+
82
+ service_token.set_service_token(
83
+ parent_pid=os.getpid(),
84
+ transport="tcp",
85
+ host="localhost",
86
+ port=port,
87
+ )
88
+
89
+ hooks = ExitHooks()
90
+ hooks.hook()
91
+
92
+ def teardown_atexit():
93
+ conn.teardown(hooks.exit_code)
94
+
95
+ conn = ServiceConnection(
96
+ client=client,
97
+ proc=proc,
98
+ cleanup=lambda: atexit.unregister(teardown_atexit),
99
+ )
100
+
101
+ atexit.register(teardown_atexit)
102
+
103
+ return conn
104
+
105
+
106
+ class ServiceConnection:
107
+ """A connection to the W&B internal service process."""
108
+
109
+ def __init__(
110
+ self,
111
+ client: SockClient,
112
+ proc: service._Service | None,
113
+ cleanup: Callable[[], None] | None = None,
114
+ ):
115
+ """Returns a new ServiceConnection.
116
+
117
+ Args:
118
+ client: A socket that's connected to the service.
119
+ proc: The service process if we own it, or None otherwise.
120
+ cleanup: A callback to run on teardown before doing anything.
121
+ """
122
+ self._client = client
123
+ self._proc = proc
124
+ self._torn_down = False
125
+ self._cleanup = cleanup
126
+
127
+ def make_interface(self, mailbox: Mailbox) -> InterfaceBase:
128
+ """Returns an interface for communicating with the service."""
129
+ return InterfaceSock(self._client, mailbox)
130
+
131
+ def send_record(self, record: pb.Record) -> None:
132
+ """Sends data to the service."""
133
+ self._client.send_record_publish(record)
134
+
135
+ def inform_init(
136
+ self,
137
+ settings: wandb_settings_pb2.Settings,
138
+ run_id: str,
139
+ ) -> None:
140
+ """Sends an init request to the service."""
141
+ request = spb.ServerInformInitRequest()
142
+ request.settings.CopyFrom(settings)
143
+ request._info.stream_id = run_id
144
+ self._client.send(inform_init=request)
145
+
146
+ def inform_finish(self, run_id: str) -> None:
147
+ """Sends a finish request to the service."""
148
+ request = spb.ServerInformFinishRequest()
149
+ request._info.stream_id = run_id
150
+ self._client.send(inform_finish=request)
151
+
152
+ def inform_attach(
153
+ self,
154
+ attach_id: str,
155
+ ) -> wandb_settings_pb2.Settings:
156
+ """Sends an attach request to the service.
157
+
158
+ Raises a WandbAttachFailedError if attaching is not possible.
159
+ """
160
+ request = spb.ServerInformAttachRequest()
161
+ request._info.stream_id = attach_id
162
+
163
+ try:
164
+ response = self._client.send_and_recv(inform_attach=request)
165
+ return response.inform_attach_response.settings
166
+ except SockClientTimeoutError:
167
+ raise WandbAttachFailedError(
168
+ "Could not attach because the run does not belong to"
169
+ " the current service process, or because the service"
170
+ " process is busy (unlikely)."
171
+ )
172
+
173
+ def inform_start(
174
+ self,
175
+ settings: wandb_settings_pb2.Settings,
176
+ run_id: str,
177
+ ) -> None:
178
+ """Sends a start request to the service."""
179
+ request = spb.ServerInformStartRequest()
180
+ request.settings.CopyFrom(settings)
181
+ request._info.stream_id = run_id
182
+ self._client.send(inform_start=request)
183
+
184
+ def teardown(self, exit_code: int) -> int:
185
+ """Shuts down the service process and returns its exit code.
186
+
187
+ This may only be called once.
188
+
189
+ Returns:
190
+ The exit code of the service process.
191
+
192
+ Raises:
193
+ WandbServiceNotOwnedError: If the current process did not start
194
+ the service process.
195
+ """
196
+ if not self._proc:
197
+ raise WandbServiceNotOwnedError(
198
+ "Cannot tear down service started by different process",
199
+ )
200
+
201
+ assert not self._torn_down
202
+ self._torn_down = True
203
+
204
+ if self._cleanup:
205
+ self._cleanup()
206
+
207
+ # Clear the service token to prevent new connections from being made.
208
+ service_token.clear_service_token()
209
+
210
+ self._client.send(
211
+ inform_teardown=spb.ServerInformTeardownRequest(
212
+ exit_code=exit_code,
213
+ )
214
+ )
215
+
216
+ return self._proc.join()