wandb 0.22.0__py3-none-win_amd64.whl → 0.22.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +3 -2
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +261 -57
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats.exe +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta_sync.py +9 -11
  35. wandb/errors/errors.py +3 -3
  36. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  37. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  38. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  39. wandb/proto/v6/wandb_sync_pb2.py +10 -6
  40. wandb/sdk/artifacts/_factories.py +7 -2
  41. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  42. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  43. wandb/sdk/artifacts/_generated/operations.py +52 -22
  44. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  45. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  46. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  47. wandb/sdk/artifacts/_gqlutils.py +47 -0
  48. wandb/sdk/artifacts/_models/__init__.py +4 -0
  49. wandb/sdk/artifacts/_models/base_model.py +20 -0
  50. wandb/sdk/artifacts/_validators.py +40 -12
  51. wandb/sdk/artifacts/artifact.py +69 -88
  52. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  53. wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
  54. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +10 -0
  55. wandb/sdk/data_types/bokeh.py +5 -1
  56. wandb/sdk/data_types/image.py +17 -6
  57. wandb/sdk/interface/interface.py +31 -4
  58. wandb/sdk/interface/interface_queue.py +10 -0
  59. wandb/sdk/interface/interface_shared.py +0 -7
  60. wandb/sdk/interface/interface_sock.py +9 -3
  61. wandb/sdk/internal/_generated/__init__.py +2 -12
  62. wandb/sdk/internal/sender.py +1 -1
  63. wandb/sdk/internal/settings_static.py +2 -82
  64. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  65. wandb/sdk/launch/utils.py +82 -1
  66. wandb/sdk/lib/progress.py +7 -4
  67. wandb/sdk/lib/service/service_client.py +5 -9
  68. wandb/sdk/lib/service/service_connection.py +39 -23
  69. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  70. wandb/sdk/projects/_generated/__init__.py +12 -33
  71. wandb/sdk/wandb_init.py +22 -2
  72. wandb/sdk/wandb_login.py +53 -27
  73. wandb/sdk/wandb_run.py +5 -3
  74. wandb/sdk/wandb_settings.py +50 -13
  75. wandb/sync/sync.py +7 -2
  76. wandb/util.py +1 -1
  77. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
  78. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/RECORD +81 -78
  79. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  80. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
  81. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
  82. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any
6
6
  from typing_extensions import override
7
7
 
8
8
  from wandb.proto import wandb_server_pb2 as spb
9
+ from wandb.sdk.lib import asyncio_manager
9
10
 
10
11
  from .interface_shared import InterfaceShared
11
12
 
@@ -21,10 +22,12 @@ logger = logging.getLogger("wandb")
21
22
  class InterfaceSock(InterfaceShared):
22
23
  def __init__(
23
24
  self,
25
+ asyncer: asyncio_manager.AsyncioManager,
24
26
  client: ServiceClient,
25
27
  stream_id: str,
26
28
  ) -> None:
27
29
  super().__init__()
30
+ self._asyncer = asyncer
28
31
  self._client = client
29
32
  self._stream_id = stream_id
30
33
 
@@ -37,13 +40,16 @@ class InterfaceSock(InterfaceShared):
37
40
  self._assign(record)
38
41
  request = spb.ServerRequest()
39
42
  request.record_publish.CopyFrom(record)
40
- self._client.publish(request)
43
+ self._asyncer.run(lambda: self._client.publish(request))
41
44
 
42
- @override
43
45
  def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]:
46
+ return self._asyncer.run(lambda: self.deliver_async(record))
47
+
48
+ @override
49
+ async def deliver_async(self, record: pb.Record) -> MailboxHandle[pb.Result]:
44
50
  self._assign(record)
45
51
  request = spb.ServerRequest()
46
52
  request.record_publish.CopyFrom(record)
47
53
 
48
- handle = self._client.deliver(request)
54
+ handle = await self._client.deliver(request)
49
55
  return handle.map(lambda response: response.result_communicate)
@@ -1,15 +1,5 @@
1
1
  # Generated by ariadne-codegen
2
2
 
3
+ __all__ = ["SERVER_FEATURES_QUERY_GQL", "ServerFeaturesQuery"]
3
4
  from .operations import SERVER_FEATURES_QUERY_GQL
4
- from .server_features_query import (
5
- ServerFeaturesQuery,
6
- ServerFeaturesQueryServerInfo,
7
- ServerFeaturesQueryServerInfoFeatures,
8
- )
9
-
10
- __all__ = [
11
- "SERVER_FEATURES_QUERY_GQL",
12
- "ServerFeaturesQuery",
13
- "ServerFeaturesQueryServerInfo",
14
- "ServerFeaturesQueryServerInfoFeatures",
15
- ]
5
+ from .server_features_query import ServerFeaturesQuery
@@ -343,7 +343,7 @@ class SendManager:
343
343
  publish_interface = InterfaceQueue(record_q=record_q)
344
344
  context_keeper = context.ContextKeeper()
345
345
  return SendManager(
346
- settings=SettingsStatic(settings.to_proto()),
346
+ settings=SettingsStatic(dict(settings)),
347
347
  record_q=record_q,
348
348
  result_q=result_q,
349
349
  interface=publish_interface,
@@ -2,9 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Any, Iterable
4
4
 
5
- from wandb.proto import wandb_settings_pb2
6
- from wandb.sdk.lib import RunMoment
7
- from wandb.sdk.wandb_settings import CLIENT_ONLY_SETTINGS, Settings
5
+ from wandb.sdk.wandb_settings import Settings
8
6
 
9
7
 
10
8
  class SettingsStatic(Settings):
@@ -14,87 +12,9 @@ class SettingsStatic(Settings):
14
12
  attributes or items.
15
13
  """
16
14
 
17
- def __init__(self, proto: wandb_settings_pb2.Settings) -> None:
18
- data = self._proto_to_dict(proto)
15
+ def __init__(self, data: dict[str, Any]) -> None:
19
16
  super().__init__(**data)
20
17
 
21
- def _proto_to_dict(self, proto: wandb_settings_pb2.Settings) -> dict:
22
- data = {}
23
-
24
- exclude_fields = {
25
- "model_config",
26
- "model_fields",
27
- "model_fields_set",
28
- "__fields__",
29
- "__model_fields_set",
30
- "__pydantic_self__",
31
- "__pydantic_initialised__",
32
- }
33
-
34
- fields = (
35
- Settings.model_fields
36
- if hasattr(Settings, "model_fields")
37
- else Settings.__fields__
38
- ) # type: ignore [attr-defined]
39
-
40
- fields = {k: v for k, v in fields.items() if k not in exclude_fields} # type: ignore [union-attr]
41
-
42
- forks_specified: list[str] = []
43
- for key in fields:
44
- if key in CLIENT_ONLY_SETTINGS:
45
- continue
46
-
47
- value: Any = None
48
-
49
- field_info = fields[key]
50
- annotation = str(field_info.annotation)
51
-
52
- if key == "_stats_open_metrics_filters":
53
- # todo: it's an underscored field, refactor into
54
- # something more elegant?
55
- # I'm really about this. It's ugly, but it works.
56
- # Do not try to repeat this at home.
57
- value_type = getattr(proto, key).WhichOneof("value")
58
- if value_type == "sequence":
59
- value = list(getattr(proto, key).sequence.value)
60
- elif value_type == "mapping":
61
- unpacked_mapping = {}
62
- for outer_key, outer_value in getattr(
63
- proto, key
64
- ).mapping.value.items():
65
- unpacked_inner = {}
66
- for inner_key, inner_value in outer_value.value.items():
67
- unpacked_inner[inner_key] = inner_value
68
- unpacked_mapping[outer_key] = unpacked_inner
69
- value = unpacked_mapping
70
- elif key == "fork_from" or key == "resume_from":
71
- value = getattr(proto, key)
72
- if value.run:
73
- value = RunMoment(
74
- run=value.run, value=value.value, metric=value.metric
75
- )
76
- forks_specified.append(key)
77
- else:
78
- value = None
79
- else:
80
- if proto.HasField(key): # type: ignore [arg-type]
81
- value = getattr(proto, key).value
82
- # Convert to list if the field is a sequence
83
- if any(t in annotation for t in ("tuple", "Sequence", "list")):
84
- value = list(value)
85
- else:
86
- value = None
87
-
88
- if value is not None:
89
- data[key] = value
90
-
91
- if len(forks_specified) > 1:
92
- raise ValueError(
93
- "Only one of fork_from or resume_from can be specified, not both"
94
- )
95
-
96
- return data
97
-
98
18
  def __setattr__(self, name: str, value: object) -> None:
99
19
  raise AttributeError("Error: SettingsStatic is a readonly object")
100
20
 
@@ -27,7 +27,11 @@ from wandb.sdk.launch.runner.kubernetes_monitor import (
27
27
  CustomResource,
28
28
  LaunchKubernetesMonitor,
29
29
  )
30
- from wandb.sdk.launch.utils import recursive_macro_sub
30
+ from wandb.sdk.launch.utils import (
31
+ recursive_macro_sub,
32
+ sanitize_identifiers_for_k8s,
33
+ yield_containers,
34
+ )
31
35
  from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
32
36
  from wandb.util import get_module
33
37
 
@@ -39,6 +43,7 @@ from ..utils import (
39
43
  MAX_ENV_LENGTHS,
40
44
  PROJECT_SYNCHRONOUS,
41
45
  get_kube_context_and_api_client,
46
+ make_k8s_label_safe,
42
47
  make_name_dns_safe,
43
48
  )
44
49
  from .abstract import AbstractRun, AbstractRunner
@@ -708,6 +713,7 @@ class KubernetesRunner(AbstractRunner):
708
713
  api_key_secret: Optional["V1Secret"] = None,
709
714
  wait_for_ready: bool = True,
710
715
  wait_timeout: int = 300,
716
+ auxiliary_resource_label_value: Optional[str] = None,
711
717
  ) -> None:
712
718
  """Prepare a service for launch.
713
719
 
@@ -725,6 +731,10 @@ class KubernetesRunner(AbstractRunner):
725
731
  config["metadata"].setdefault("labels", {})
726
732
  config["metadata"]["labels"][WANDB_K8S_RUN_ID] = run_id
727
733
  config["metadata"]["labels"]["wandb.ai/created-by"] = "launch-agent"
734
+ if auxiliary_resource_label_value:
735
+ config["metadata"]["labels"][WANDB_K8S_LABEL_AUXILIARY_RESOURCE] = (
736
+ auxiliary_resource_label_value
737
+ )
728
738
 
729
739
  env_vars = launch_project.get_env_vars_dict(
730
740
  self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
@@ -734,6 +744,13 @@ class KubernetesRunner(AbstractRunner):
734
744
  }
735
745
  add_wandb_env(config, wandb_config_env)
736
746
 
747
+ if auxiliary_resource_label_value:
748
+ add_label_to_pods(
749
+ config,
750
+ WANDB_K8S_LABEL_AUXILIARY_RESOURCE,
751
+ auxiliary_resource_label_value,
752
+ )
753
+
737
754
  if api_key_secret:
738
755
  for cont in yield_containers(config):
739
756
  env = cont.setdefault("env", [])
@@ -751,6 +768,8 @@ class KubernetesRunner(AbstractRunner):
751
768
  cont["env"] = env
752
769
 
753
770
  try:
771
+ sanitize_identifiers_for_k8s(config)
772
+
754
773
  await kubernetes_asyncio.utils.create_from_dict(
755
774
  api_client, config, namespace=namespace
756
775
  )
@@ -924,6 +943,9 @@ class KubernetesRunner(AbstractRunner):
924
943
  additional_services: List[Dict[str, Any]] = recursive_macro_sub(
925
944
  launch_project.launch_spec.get("additional_services", []), update_dict
926
945
  )
946
+ auxiliary_resource_label_value = make_k8s_label_safe(
947
+ f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}"
948
+ )
927
949
  if additional_services:
928
950
  wandb.termlog(
929
951
  f"{LOG_PREFIX}Creating additional services: {additional_services}"
@@ -943,6 +965,7 @@ class KubernetesRunner(AbstractRunner):
943
965
  secret,
944
966
  wait_for_ready,
945
967
  wait_timeout,
968
+ auxiliary_resource_label_value,
946
969
  )
947
970
  for resource in additional_services
948
971
  if resource.get("config", {})
@@ -980,7 +1003,7 @@ class KubernetesRunner(AbstractRunner):
980
1003
  job_name,
981
1004
  namespace,
982
1005
  secret,
983
- f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}",
1006
+ auxiliary_resource_label_value,
984
1007
  )
985
1008
  if self.backend_config[PROJECT_SYNCHRONOUS]:
986
1009
  await submitted_job.wait()
@@ -1131,24 +1154,6 @@ async def maybe_create_imagepull_secret(
1131
1154
  raise LaunchError(f"Exception when creating Kubernetes secret: {str(e)}\n")
1132
1155
 
1133
1156
 
1134
- def yield_containers(root: Any) -> Iterator[dict]:
1135
- """Yield all container specs in a manifest.
1136
-
1137
- Recursively traverses the manifest and yields all container specs. Container
1138
- specs are identified by the presence of a "containers" key in the value.
1139
- """
1140
- if isinstance(root, dict):
1141
- for k, v in root.items():
1142
- if k == "containers":
1143
- if isinstance(v, list):
1144
- yield from v
1145
- elif isinstance(v, (dict, list)):
1146
- yield from yield_containers(v)
1147
- elif isinstance(root, list):
1148
- for item in root:
1149
- yield from yield_containers(item)
1150
-
1151
-
1152
1157
  def add_wandb_env(root: Union[dict, list], env_vars: Dict[str, str]) -> None:
1153
1158
  """Injects wandb environment variables into specs.
1154
1159
 
wandb/sdk/launch/utils.py CHANGED
@@ -7,7 +7,17 @@ import re
7
7
  import subprocess
8
8
  import sys
9
9
  from collections import defaultdict
10
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
10
+ from typing import (
11
+ TYPE_CHECKING,
12
+ Any,
13
+ Dict,
14
+ Iterator,
15
+ List,
16
+ Optional,
17
+ Tuple,
18
+ Union,
19
+ cast,
20
+ )
11
21
 
12
22
  import click
13
23
 
@@ -599,6 +609,32 @@ def make_name_dns_safe(name: str) -> str:
599
609
  return resp
600
610
 
601
611
 
612
+ def make_k8s_label_safe(value: str) -> str:
613
+ """Return a Kubernetes label/identifier safe string (DNS-1123 label).
614
+
615
+ See:
616
+ https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names
617
+
618
+ Rules:
619
+ - lowercase alphanumeric and '-'
620
+ - must start and end with an alphanumeric
621
+ - max length 63
622
+ """
623
+ # Normalize common separators first
624
+ safe = value.replace("_", "-").lower()
625
+ # Remove any invalid characters
626
+ safe = re.sub(r"[^a-z0-9\-]", "", safe)
627
+ # Collapse consecutive '-'
628
+ safe = re.sub(r"-+", "-", safe)
629
+ # Trim to 63 and strip leading/trailing '-'
630
+ safe = safe[:63].strip("-")
631
+
632
+ if not safe:
633
+ raise LaunchError(f"Invalid value for Kubernetes label: {value}")
634
+
635
+ return safe
636
+
637
+
602
638
  def warn_failed_packages_from_build_logs(
603
639
  log: str, image_uri: str, api: Api, job_tracker: Optional["JobAndRunStatusTracker"]
604
640
  ) -> None:
@@ -744,3 +780,48 @@ def get_current_python_version() -> Tuple[str, str]:
744
780
  major = full_version[0]
745
781
  version = ".".join(full_version[:2]) if len(full_version) >= 2 else major + ".0"
746
782
  return version, major
783
+
784
+
785
+ def yield_containers(root: Union[dict, list]) -> Iterator[dict]:
786
+ """Yield all container specs in a manifest.
787
+
788
+ Recursively traverses the manifest and yields all container specs. Container
789
+ specs are identified by the presence of a "containers" key in the value.
790
+ """
791
+ if isinstance(root, dict):
792
+ for k, v in root.items():
793
+ if k == "containers":
794
+ if isinstance(v, list):
795
+ yield from v
796
+ elif isinstance(v, (dict, list)):
797
+ yield from yield_containers(v)
798
+ elif isinstance(root, list):
799
+ for item in root:
800
+ yield from yield_containers(item)
801
+
802
+
803
+ def sanitize_identifiers_for_k8s(root: Any) -> None:
804
+ if isinstance(root, list):
805
+ for item in root:
806
+ sanitize_identifiers_for_k8s(item)
807
+ return
808
+
809
+ # Only dicts have metadata and nested structures we need to sanitize.
810
+ if not isinstance(root, dict):
811
+ return
812
+
813
+ metadata = root.get("metadata")
814
+ if isinstance(metadata, dict):
815
+ if name := metadata.get("name"):
816
+ metadata["name"] = make_k8s_label_safe(str(name))
817
+
818
+ for container in yield_containers(root):
819
+ if name := container.get("name"):
820
+ container["name"] = make_k8s_label_safe(str(name))
821
+
822
+ # nested names
823
+ for key, value in root.items():
824
+ if isinstance(value, (dict, list)):
825
+ sanitize_identifiers_for_k8s(value)
826
+ elif key == "name" and isinstance(value, str):
827
+ root[key] = make_k8s_label_safe(value)
wandb/sdk/lib/progress.py CHANGED
@@ -45,7 +45,11 @@ async def loop_printing_operation_stats(
45
45
  while True:
46
46
  start_time = time.monotonic()
47
47
 
48
- handle = interface.deliver_operation_stats()
48
+ handle = await interface.deliver_async(
49
+ pb.Record(
50
+ request=pb.Request(operations=pb.OperationStatsRequest()),
51
+ )
52
+ )
49
53
  result = await handle.wait_async(timeout=None)
50
54
  stats = result.response.operations_response.operation_stats
51
55
 
@@ -141,10 +145,9 @@ class ProgressPrinter:
141
145
  if extra_operations > 0:
142
146
  line += f" (+ {extra_operations} more)"
143
147
 
144
- if line != self._last_printed_line:
148
+ if line and line != self._last_printed_line:
145
149
  self._printer.display(line)
146
-
147
- self._last_printed_line = line
150
+ self._last_printed_line = line
148
151
 
149
152
  def _update_single_run(
150
153
  self,
@@ -24,7 +24,6 @@ class ServiceClient:
24
24
  reader: asyncio.StreamReader,
25
25
  writer: asyncio.StreamWriter,
26
26
  ) -> None:
27
- self._asyncer = asyncer
28
27
  self._reader = reader
29
28
  self._writer = writer
30
29
  self._mailbox = Mailbox(asyncer)
@@ -34,11 +33,11 @@ class ServiceClient:
34
33
  name="ServiceClient._forward_responses",
35
34
  )
36
35
 
37
- def publish(self, request: spb.ServerRequest) -> None:
36
+ async def publish(self, request: spb.ServerRequest) -> None:
38
37
  """Send a request without waiting for a response."""
39
- self._asyncer.run_soon(lambda: self._send_server_request(request))
38
+ await self._send_server_request(request)
40
39
 
41
- def deliver(
40
+ async def deliver(
42
41
  self,
43
42
  request: spb.ServerRequest,
44
43
  ) -> MailboxHandle[spb.ServerResponse]:
@@ -52,7 +51,7 @@ class ServiceClient:
52
51
  stopped due to an error.
53
52
  """
54
53
  handle = self._mailbox.require_response(request)
55
- self._asyncer.run_soon(lambda: self._send_server_request(request))
54
+ await self._send_server_request(request)
56
55
  return handle
57
56
 
58
57
  async def _send_server_request(self, request: spb.ServerRequest) -> None:
@@ -64,11 +63,8 @@ class ServiceClient:
64
63
 
65
64
  await self._writer.drain()
66
65
 
67
- def close(self) -> None:
66
+ async def close(self) -> None:
68
67
  """Flush and close the socket."""
69
- self._asyncer.run_soon(self._close)
70
-
71
- async def _close(self) -> None:
72
68
  self._writer.close()
73
69
  await self._writer.wait_closed()
74
70
 
@@ -31,6 +31,7 @@ def connect_to_service(
31
31
 
32
32
  if token:
33
33
  return ServiceConnection(
34
+ asyncer=asyncer,
34
35
  client=token.connect(asyncer=asyncer),
35
36
  proc=None,
36
37
  )
@@ -60,6 +61,7 @@ def _start_and_connect_service(
60
61
  conn.teardown(hooks.exit_code)
61
62
 
62
63
  conn = ServiceConnection(
64
+ asyncer=asyncer,
63
65
  client=client,
64
66
  proc=proc,
65
67
  cleanup=lambda: atexit.unregister(teardown_atexit),
@@ -71,10 +73,14 @@ def _start_and_connect_service(
71
73
 
72
74
 
73
75
  class ServiceConnection:
74
- """A connection to the W&B internal service process."""
76
+ """A connection to the W&B internal service process.
77
+
78
+ None of the synchronous methods may be called in an asyncio context.
79
+ """
75
80
 
76
81
  def __init__(
77
82
  self,
83
+ asyncer: asyncio_manager.AsyncioManager,
78
84
  client: ServiceClient,
79
85
  proc: service_process.ServiceProcess | None,
80
86
  cleanup: Callable[[], None] | None = None,
@@ -82,13 +88,12 @@ class ServiceConnection:
82
88
  """Returns a new ServiceConnection.
83
89
 
84
90
  Args:
85
- mailbox: The mailbox to use for all communication over the socket.
86
- router: A handle to the thread that reads from the socket and
87
- updates the mailbox.
88
- client: A socket that's connected to the service.
91
+ asyncer: An asyncio runner.
92
+ client: A client for communicating with the service over a socket.
89
93
  proc: The service process if we own it, or None otherwise.
90
94
  cleanup: A callback to run on teardown before doing anything.
91
95
  """
96
+ self._asyncer = asyncer
92
97
  self._client = client
93
98
  self._proc = proc
94
99
  self._torn_down = False
@@ -96,9 +101,13 @@ class ServiceConnection:
96
101
 
97
102
  def make_interface(self, stream_id: str) -> InterfaceBase:
98
103
  """Returns an interface for communicating with the service."""
99
- return InterfaceSock(self._client, stream_id=stream_id)
104
+ return InterfaceSock(
105
+ self._asyncer,
106
+ self._client,
107
+ stream_id=stream_id,
108
+ )
100
109
 
101
- def init_sync(
110
+ async def init_sync(
102
111
  self,
103
112
  paths: set[pathlib.Path],
104
113
  settings: wandb_settings.Settings,
@@ -110,10 +119,10 @@ class ServiceConnection:
110
119
  )
111
120
  request = spb.ServerRequest(init_sync=init_sync)
112
121
 
113
- handle = self._client.deliver(request)
122
+ handle = await self._client.deliver(request)
114
123
  return handle.map(lambda r: r.init_sync_response)
115
124
 
116
- def sync(
125
+ async def sync(
117
126
  self,
118
127
  id: str,
119
128
  *,
@@ -123,10 +132,10 @@ class ServiceConnection:
123
132
  sync = wandb_sync_pb2.ServerSyncRequest(id=id, parallelism=parallelism)
124
133
  request = spb.ServerRequest(sync=sync)
125
134
 
126
- handle = self._client.deliver(request)
135
+ handle = await self._client.deliver(request)
127
136
  return handle.map(lambda r: r.sync_response)
128
137
 
129
- def sync_status(
138
+ async def sync_status(
130
139
  self,
131
140
  id: str,
132
141
  ) -> MailboxHandle[wandb_sync_pb2.ServerSyncStatusResponse]:
@@ -134,7 +143,7 @@ class ServiceConnection:
134
143
  sync_status = wandb_sync_pb2.ServerSyncStatusRequest(id=id)
135
144
  request = spb.ServerRequest(sync_status=sync_status)
136
145
 
137
- handle = self._client.deliver(request)
146
+ handle = await self._client.deliver(request)
138
147
  return handle.map(lambda r: r.sync_status_response)
139
148
 
140
149
  def inform_init(
@@ -146,13 +155,17 @@ class ServiceConnection:
146
155
  request = spb.ServerInformInitRequest()
147
156
  request.settings.CopyFrom(settings)
148
157
  request._info.stream_id = run_id
149
- self._client.publish(spb.ServerRequest(inform_init=request))
158
+ self._asyncer.run(
159
+ lambda: self._client.publish(spb.ServerRequest(inform_init=request))
160
+ )
150
161
 
151
162
  def inform_finish(self, run_id: str) -> None:
152
163
  """Send an finish request to the service."""
153
164
  request = spb.ServerInformFinishRequest()
154
165
  request._info.stream_id = run_id
155
- self._client.publish(spb.ServerRequest(inform_finish=request))
166
+ self._asyncer.run(
167
+ lambda: self._client.publish(spb.ServerRequest(inform_finish=request))
168
+ )
156
169
 
157
170
  def inform_attach(
158
171
  self,
@@ -166,7 +179,7 @@ class ServiceConnection:
166
179
  request.inform_attach._info.stream_id = attach_id
167
180
 
168
181
  try:
169
- handle = self._client.deliver(request)
182
+ handle = self._asyncer.run(lambda: self._client.deliver(request))
170
183
  response = handle.wait_or(timeout=10)
171
184
 
172
185
  except (MailboxClosedError, HandleAbandonedError):
@@ -210,13 +223,16 @@ class ServiceConnection:
210
223
  # Clear the service token to prevent new connections to the process.
211
224
  service_token.clear_service_in_env()
212
225
 
213
- self._client.publish(
214
- spb.ServerRequest(
215
- inform_teardown=spb.ServerInformTeardownRequest(
216
- exit_code=exit_code,
217
- )
218
- ),
219
- )
220
- self._client.close()
226
+ async def publish_teardown_and_close() -> None:
227
+ await self._client.publish(
228
+ spb.ServerRequest(
229
+ inform_teardown=spb.ServerInformTeardownRequest(
230
+ exit_code=exit_code,
231
+ )
232
+ ),
233
+ )
234
+ await self._client.close()
235
+
236
+ self._asyncer.run(publish_teardown_and_close)
221
237
 
222
238
  return self._proc.join()
@@ -58,6 +58,8 @@ class MailboxHandle(abc.ABC, Generic[_T]):
58
58
  def cancel(self, iface: interface.InterfaceBase) -> None:
59
59
  """Cancel the handle, requesting any associated work to not complete.
60
60
 
61
+ It is an error to call this from an async function.
62
+
61
63
  This automatically abandons the handle, as a response is no longer
62
64
  guaranteed.
63
65
 
@@ -1,47 +1,26 @@
1
1
  # Generated by ariadne-codegen
2
2
 
3
- from .delete_project import DeleteProject, DeleteProjectDeleteModel
4
- from .fetch_registry import FetchRegistry, FetchRegistryEntity
5
- from .fragments import (
6
- RegistryFragment,
7
- RegistryFragmentArtifactTypes,
8
- RegistryFragmentArtifactTypesEdges,
9
- RegistryFragmentArtifactTypesEdgesNode,
10
- )
11
- from .input_types import ArtifactTypeInput
12
- from .operations import (
13
- DELETE_PROJECT_GQL,
14
- FETCH_REGISTRY_GQL,
15
- RENAME_PROJECT_GQL,
16
- UPSERT_REGISTRY_PROJECT_GQL,
17
- )
18
- from .rename_project import (
19
- RenameProject,
20
- RenameProjectRenameProject,
21
- RenameProjectRenameProjectProject,
22
- )
23
- from .upsert_registry_project import (
24
- UpsertRegistryProject,
25
- UpsertRegistryProjectUpsertModel,
26
- )
27
-
28
3
  __all__ = [
29
4
  "DELETE_PROJECT_GQL",
30
5
  "FETCH_REGISTRY_GQL",
31
6
  "RENAME_PROJECT_GQL",
32
7
  "UPSERT_REGISTRY_PROJECT_GQL",
33
8
  "FetchRegistry",
34
- "FetchRegistryEntity",
35
9
  "RenameProject",
36
- "RenameProjectRenameProject",
37
- "RenameProjectRenameProjectProject",
38
10
  "UpsertRegistryProject",
39
- "UpsertRegistryProjectUpsertModel",
40
11
  "DeleteProject",
41
- "DeleteProjectDeleteModel",
42
12
  "ArtifactTypeInput",
43
13
  "RegistryFragment",
44
- "RegistryFragmentArtifactTypes",
45
- "RegistryFragmentArtifactTypesEdges",
46
- "RegistryFragmentArtifactTypesEdgesNode",
47
14
  ]
15
+ from .delete_project import DeleteProject
16
+ from .fetch_registry import FetchRegistry
17
+ from .fragments import RegistryFragment
18
+ from .input_types import ArtifactTypeInput
19
+ from .operations import (
20
+ DELETE_PROJECT_GQL,
21
+ FETCH_REGISTRY_GQL,
22
+ RENAME_PROJECT_GQL,
23
+ UPSERT_REGISTRY_PROJECT_GQL,
24
+ )
25
+ from .rename_project import RenameProject
26
+ from .upsert_registry_project import UpsertRegistryProject