wandb 0.21.4__py3-none-musllinux_1_2_aarch64.whl → 0.22.1__py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +44 -1
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +282 -60
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta_sync.py +9 -11
  35. wandb/errors/errors.py +3 -3
  36. wandb/proto/v3/wandb_internal_pb2.py +234 -224
  37. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  38. wandb/proto/v4/wandb_internal_pb2.py +226 -224
  39. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  40. wandb/proto/v5/wandb_internal_pb2.py +226 -224
  41. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  42. wandb/proto/v6/wandb_base_pb2.py +3 -3
  43. wandb/proto/v6/wandb_internal_pb2.py +229 -227
  44. wandb/proto/v6/wandb_server_pb2.py +3 -3
  45. wandb/proto/v6/wandb_settings_pb2.py +3 -3
  46. wandb/proto/v6/wandb_sync_pb2.py +13 -9
  47. wandb/proto/v6/wandb_telemetry_pb2.py +3 -3
  48. wandb/sdk/artifacts/_factories.py +7 -2
  49. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  50. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  51. wandb/sdk/artifacts/_generated/operations.py +52 -22
  52. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  53. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  54. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  55. wandb/sdk/artifacts/_gqlutils.py +47 -0
  56. wandb/sdk/artifacts/_models/__init__.py +4 -0
  57. wandb/sdk/artifacts/_models/base_model.py +20 -0
  58. wandb/sdk/artifacts/_validators.py +40 -12
  59. wandb/sdk/artifacts/artifact.py +69 -88
  60. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  61. wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
  62. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
  63. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -3
  64. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -1
  65. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  66. wandb/sdk/artifacts/storage_policies/_factories.py +63 -0
  67. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +69 -124
  68. wandb/sdk/data_types/bokeh.py +5 -1
  69. wandb/sdk/data_types/image.py +17 -6
  70. wandb/sdk/interface/interface.py +41 -4
  71. wandb/sdk/interface/interface_queue.py +10 -0
  72. wandb/sdk/interface/interface_shared.py +9 -7
  73. wandb/sdk/interface/interface_sock.py +9 -3
  74. wandb/sdk/internal/_generated/__init__.py +2 -12
  75. wandb/sdk/internal/sender.py +1 -1
  76. wandb/sdk/internal/settings_static.py +2 -82
  77. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  78. wandb/sdk/launch/utils.py +82 -1
  79. wandb/sdk/lib/progress.py +7 -4
  80. wandb/sdk/lib/service/service_client.py +5 -9
  81. wandb/sdk/lib/service/service_connection.py +39 -23
  82. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  83. wandb/sdk/projects/_generated/__init__.py +12 -33
  84. wandb/sdk/wandb_init.py +31 -3
  85. wandb/sdk/wandb_login.py +53 -27
  86. wandb/sdk/wandb_run.py +5 -3
  87. wandb/sdk/wandb_settings.py +50 -13
  88. wandb/sync/sync.py +7 -2
  89. wandb/util.py +1 -1
  90. wandb/wandb_agent.py +35 -4
  91. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
  92. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/RECORD +818 -814
  93. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  94. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
  95. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
  96. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/runs.py CHANGED
@@ -85,13 +85,92 @@ RUN_FRAGMENT = """fragment RunFragment on Run {
85
85
  historyKeys
86
86
  }"""
87
87
 
88
+ # Lightweight fragment for listing operations - excludes heavy fields
89
+ LIGHTWEIGHT_RUN_FRAGMENT = """fragment LightweightRunFragment on Run {
90
+ id
91
+ tags
92
+ name
93
+ displayName
94
+ sweepName
95
+ state
96
+ group
97
+ jobType
98
+ commit
99
+ readOnly
100
+ createdAt
101
+ heartbeatAt
102
+ description
103
+ notes
104
+ historyLineCount
105
+ user {
106
+ name
107
+ username
108
+ }
109
+ }"""
110
+
111
+ # Fragment name constants to avoid string parsing
112
+ RUN_FRAGMENT_NAME = "RunFragment"
113
+ LIGHTWEIGHT_RUN_FRAGMENT_NAME = "LightweightRunFragment"
114
+
115
+
116
+ def _create_runs_query(
117
+ *, lazy: bool, with_internal_id: bool, with_project_id: bool
118
+ ) -> gql:
119
+ """Create GraphQL query for runs with appropriate fragment."""
120
+ fragment = LIGHTWEIGHT_RUN_FRAGMENT if lazy else RUN_FRAGMENT
121
+ fragment_name = LIGHTWEIGHT_RUN_FRAGMENT_NAME if lazy else RUN_FRAGMENT_NAME
122
+
123
+ return gql(
124
+ f"""#graphql
125
+ query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
126
+ project(name: $project, entityName: $entity) {{
127
+ {"internalId" if with_internal_id else ""}
128
+ runCount(filters: $filters)
129
+ readOnly
130
+ runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
131
+ edges {{
132
+ node {{
133
+ {"projectId" if with_project_id else ""}
134
+ ...{fragment_name}
135
+ }}
136
+ cursor
137
+ }}
138
+ pageInfo {{
139
+ endCursor
140
+ hasNextPage
141
+ }}
142
+ }}
143
+ }}
144
+ }}
145
+ {fragment}
146
+ """
147
+ )
148
+
88
149
 
89
150
  @normalize_exceptions
90
151
  def _server_provides_internal_id_for_project(client) -> bool:
91
- """Returns True if the server allows us to query the internalId field for a project.
92
-
93
- This check is done by utilizing GraphQL introspection in the available fields on the Project type.
152
+ """Returns True if the server allows us to query the internalId field for a project."""
153
+ query_string = """
154
+ query ProbeProjectInput {
155
+ ProjectType: __type(name:"Project") {
156
+ fields {
157
+ name
158
+ }
159
+ }
160
+ }
94
161
  """
162
+
163
+ # Only perform the query once to avoid extra network calls
164
+ query = gql(query_string)
165
+ res = client.execute(query)
166
+ return "internalId" in [
167
+ x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))
168
+ ]
169
+
170
+
171
+ @normalize_exceptions
172
+ def _server_provides_project_id_for_run(client) -> bool:
173
+ """Returns True if the server allows us to query the projectId field for a run."""
95
174
  query_string = """
96
175
  query ProbeRunInput {
97
176
  RunType: __type(name:"Run") {
@@ -201,34 +280,15 @@ class Runs(SizedPaginator["Run"]):
201
280
  order: str = "+created_at",
202
281
  per_page: int = 50,
203
282
  include_sweeps: bool = True,
283
+ lazy: bool = True,
204
284
  ):
205
285
  if not order:
206
286
  order = "+created_at"
207
287
 
208
- self.QUERY = gql(
209
- f"""#graphql
210
- query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
211
- project(name: $project, entityName: $entity) {{
212
- internalId
213
- runCount(filters: $filters)
214
- readOnly
215
- runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
216
- edges {{
217
- node {{
218
- {"projectId" if _server_provides_internal_id_for_project(client) else ""}
219
- ...RunFragment
220
- }}
221
- cursor
222
- }}
223
- pageInfo {{
224
- endCursor
225
- hasNextPage
226
- }}
227
- }}
228
- }}
229
- }}
230
- {RUN_FRAGMENT}
231
- """
288
+ self.QUERY = _create_runs_query(
289
+ lazy=lazy,
290
+ with_internal_id=_server_provides_internal_id_for_project(client),
291
+ with_project_id=_server_provides_project_id_for_run(client),
232
292
  )
233
293
 
234
294
  self.entity = entity
@@ -238,6 +298,7 @@ class Runs(SizedPaginator["Run"]):
238
298
  self.order = order
239
299
  self._sweeps = {}
240
300
  self._include_sweeps = include_sweeps
301
+ self._lazy = lazy
241
302
  variables = {
242
303
  "project": self.project,
243
304
  "entity": self.entity,
@@ -296,6 +357,7 @@ class Runs(SizedPaginator["Run"]):
296
357
  run_response["node"]["name"],
297
358
  run_response["node"],
298
359
  include_sweeps=self._include_sweeps,
360
+ lazy=self._lazy,
299
361
  )
300
362
  objs.append(run)
301
363
 
@@ -422,6 +484,39 @@ class Runs(SizedPaginator["Run"]):
422
484
  def __repr__(self):
423
485
  return f"<Runs {self.entity}/{self.project}>"
424
486
 
487
+ def upgrade_to_full(self):
488
+ """Upgrade this Runs collection from lazy to full mode.
489
+
490
+ This switches to fetching full run data and
491
+ upgrades any already-loaded Run objects to have full data.
492
+ Uses parallel loading for better performance when upgrading multiple runs.
493
+ """
494
+ if not self._lazy:
495
+ return # Already in full mode
496
+
497
+ # Switch to full mode
498
+ self._lazy = False
499
+
500
+ # Regenerate query with full fragment
501
+ self.QUERY = _create_runs_query(
502
+ lazy=False,
503
+ with_internal_id=_server_provides_internal_id_for_project(self.client),
504
+ with_project_id=_server_provides_project_id_for_run(self.client),
505
+ )
506
+
507
+ # Upgrade any existing runs that have been loaded - use parallel loading for performance
508
+ lazy_runs = [run for run in self.objects if run._lazy]
509
+ if lazy_runs:
510
+ from concurrent.futures import ThreadPoolExecutor
511
+
512
+ # Limit workers to avoid overwhelming the server
513
+ max_workers = min(len(lazy_runs), 10)
514
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
515
+ futures = [executor.submit(run.load_full_data) for run in lazy_runs]
516
+ # Wait for all to complete
517
+ for future in futures:
518
+ future.result()
519
+
425
520
 
426
521
  class Run(Attrs):
427
522
  """A single run associated with an entity and project.
@@ -465,6 +560,7 @@ class Run(Attrs):
465
560
  run_id: str,
466
561
  attrs: Mapping | None = None,
467
562
  include_sweeps: bool = True,
563
+ lazy: bool = True,
468
564
  ):
469
565
  """Initialize a Run object.
470
566
 
@@ -481,6 +577,8 @@ class Run(Attrs):
481
577
  self.id = run_id
482
578
  self.sweep = None
483
579
  self._include_sweeps = include_sweeps
580
+ self._lazy = lazy
581
+ self._full_data_loaded = False # Track if we've loaded full data
484
582
  self.dir = os.path.join(self._base_dir, *self.path)
485
583
  try:
486
584
  os.makedirs(self.dir)
@@ -490,6 +588,7 @@ class Run(Attrs):
490
588
  self._metadata: dict[str, Any] | None = None
491
589
  self._state = _attrs.get("state", "not found")
492
590
  self.server_provides_internal_id_field: bool | None = None
591
+ self._server_provides_project_id_field: bool | None = None
493
592
  self._is_loaded: bool = False
494
593
 
495
594
  self.load(force=not _attrs)
@@ -594,23 +693,34 @@ class Run(Attrs):
594
693
  "notes": None,
595
694
  "state": state,
596
695
  },
696
+ lazy=False, # Created runs should have full data available immediately
597
697
  )
598
698
 
599
- def load(self, force=False):
600
- if force or not self._attrs:
601
- self._is_loaded = False
602
- query = gql(f"""#graphql
603
- query Run($project: String!, $entity: String!, $name: String!) {{
604
- project(name: $project, entityName: $entity) {{
605
- run(name: $name) {{
606
- {"projectId" if _server_provides_internal_id_for_project(self.client) else ""}
607
- ...RunFragment
608
- }}
699
+ def _load_with_fragment(
700
+ self, fragment: str, fragment_name: str, force: bool = False
701
+ ):
702
+ """Load run data using specified GraphQL fragment."""
703
+ # Cache the server capability check to avoid repeated network calls
704
+ if self._server_provides_project_id_field is None:
705
+ self._server_provides_project_id_field = (
706
+ _server_provides_project_id_for_run(self.client)
707
+ )
708
+
709
+ query = gql(
710
+ f"""
711
+ query Run($project: String!, $entity: String!, $name: String!) {{
712
+ project(name: $project, entityName: $entity) {{
713
+ run(name: $name) {{
714
+ {"projectId" if self._server_provides_project_id_field else ""}
715
+ ...{fragment_name}
609
716
  }}
610
717
  }}
611
- {RUN_FRAGMENT}
612
- """)
718
+ }}
719
+ {fragment}
720
+ """
721
+ )
613
722
 
723
+ if force or not self._attrs:
614
724
  response = self._exec(query)
615
725
  if (
616
726
  response is None
@@ -620,6 +730,10 @@ class Run(Attrs):
620
730
  raise ValueError("Could not find run {}".format(self))
621
731
  self._attrs = response["project"]["run"]
622
732
 
733
+ self._state = self._attrs["state"]
734
+ if self._attrs.get("user"):
735
+ self.user = public.User(self.client, self._attrs["user"])
736
+
623
737
  if self._include_sweeps and self.sweep_name and not self.sweep:
624
738
  # There may be a lot of runs. Don't bother pulling them all
625
739
  # just for the sake of this one.
@@ -632,39 +746,78 @@ class Run(Attrs):
632
746
  )
633
747
 
634
748
  if not self._is_loaded:
635
- self._load_from_attrs()
749
+ # Always set _project_internal_id if projectId is available, regardless of fragment type
750
+ if "projectId" in self._attrs:
751
+ self._project_internal_id = int(self._attrs["projectId"])
752
+ else:
753
+ self._project_internal_id = None
754
+
755
+ # Only call _load_from_attrs when using the full fragment or when the fields are actually present
756
+ if fragment_name == RUN_FRAGMENT_NAME or (
757
+ "config" in self._attrs
758
+ or "summaryMetrics" in self._attrs
759
+ or "systemMetrics" in self._attrs
760
+ ):
761
+ self._load_from_attrs()
636
762
  self._is_loaded = True
637
763
 
638
764
  return self._attrs
639
765
 
640
766
  def _load_from_attrs(self):
641
767
  self._state = self._attrs.get("state", None)
642
- self._attrs["config"] = _convert_to_dict(self._attrs.get("config"))
643
- self._attrs["summaryMetrics"] = _convert_to_dict(
644
- self._attrs.get("summaryMetrics")
645
- )
646
- self._attrs["systemMetrics"] = _convert_to_dict(
647
- self._attrs.get("systemMetrics")
648
- )
649
768
 
650
- if "projectId" in self._attrs:
651
- self._project_internal_id = int(self._attrs["projectId"])
652
- else:
653
- self._project_internal_id = None
769
+ # Only convert fields if they exist in _attrs
770
+ if "config" in self._attrs:
771
+ self._attrs["config"] = _convert_to_dict(self._attrs.get("config"))
772
+ if "summaryMetrics" in self._attrs:
773
+ self._attrs["summaryMetrics"] = _convert_to_dict(
774
+ self._attrs.get("summaryMetrics")
775
+ )
776
+ if "systemMetrics" in self._attrs:
777
+ self._attrs["systemMetrics"] = _convert_to_dict(
778
+ self._attrs.get("systemMetrics")
779
+ )
780
+
781
+ # Only check for sweeps if sweep_name is available (not in lazy mode or if it exists)
782
+ if self._include_sweeps and self._attrs.get("sweepName") and not self.sweep:
783
+ # There may be a lot of runs. Don't bother pulling them all
784
+ self.sweep = public.Sweep(
785
+ self.client,
786
+ self.entity,
787
+ self.project,
788
+ self._attrs["sweepName"],
789
+ withRuns=False,
790
+ )
654
791
 
655
- if self._attrs.get("user"):
656
- self.user = public.User(self.client, self._attrs["user"])
657
792
  config_user, config_raw = {}, {}
658
- for key, value in self._attrs.get("config").items():
659
- config = config_raw if key in WANDB_INTERNAL_KEYS else config_user
660
- if isinstance(value, dict) and "value" in value:
661
- config[key] = value["value"]
662
- else:
663
- config[key] = value
793
+ if self._attrs.get("config"):
794
+ try:
795
+ # config is already converted to dict by _convert_to_dict
796
+ for key, value in self._attrs.get("config", {}).items():
797
+ config = config_raw if key in WANDB_INTERNAL_KEYS else config_user
798
+ if isinstance(value, dict) and "value" in value:
799
+ config[key] = value["value"]
800
+ else:
801
+ config[key] = value
802
+ except (TypeError, AttributeError):
803
+ # Handle case where config is malformed or not a dict
804
+ pass
805
+
664
806
  config_raw.update(config_user)
665
807
  self._attrs["config"] = config_user
666
808
  self._attrs["rawconfig"] = config_raw
667
809
 
810
+ return self._attrs
811
+
812
+ def load(self, force=False):
813
+ """Load run data using appropriate fragment based on lazy mode."""
814
+ if self._lazy:
815
+ return self._load_with_fragment(
816
+ LIGHTWEIGHT_RUN_FRAGMENT, LIGHTWEIGHT_RUN_FRAGMENT_NAME, force
817
+ )
818
+ else:
819
+ return self._load_with_fragment(RUN_FRAGMENT, RUN_FRAGMENT_NAME, force)
820
+
668
821
  @normalize_exceptions
669
822
  def wait_until_finished(self):
670
823
  """Check the state of the run until it is finished."""
@@ -1126,9 +1279,43 @@ class Run(Attrs):
1126
1279
  )
1127
1280
  return artifact
1128
1281
 
1282
+ def load_full_data(self, force: bool = False) -> dict[str, Any]:
1283
+ """Load full run data including heavy fields like config, systemMetrics, summaryMetrics.
1284
+
1285
+ This method is useful when you initially used lazy=True for listing runs,
1286
+ but need access to the full data for specific runs.
1287
+
1288
+ Args:
1289
+ force: Force reload even if data is already loaded
1290
+
1291
+ Returns:
1292
+ The loaded run attributes
1293
+ """
1294
+ if not self._lazy and not force:
1295
+ # Already in full mode, no need to reload
1296
+ return self._attrs
1297
+
1298
+ # Load full data and mark as loaded
1299
+ result = self._load_with_fragment(RUN_FRAGMENT, RUN_FRAGMENT_NAME, force=True)
1300
+ self._full_data_loaded = True
1301
+ return result
1302
+
1303
+ @property
1304
+ def config(self):
1305
+ """Get run config. Auto-loads full data if in lazy mode."""
1306
+ if self._lazy and not self._full_data_loaded and "config" not in self._attrs:
1307
+ self.load_full_data()
1308
+ return self._attrs.get("config", {})
1309
+
1129
1310
  @property
1130
1311
  def summary(self):
1131
- """A mutable dict-like property that holds summary values associated with the run."""
1312
+ """Get run summary metrics. Auto-loads full data if in lazy mode."""
1313
+ if (
1314
+ self._lazy
1315
+ and not self._full_data_loaded
1316
+ and "summaryMetrics" not in self._attrs
1317
+ ):
1318
+ self.load_full_data()
1132
1319
  if self._summary is None:
1133
1320
  from wandb.old.summary import HTTPSummary
1134
1321
 
@@ -1136,6 +1323,41 @@ class Run(Attrs):
1136
1323
  self._summary = HTTPSummary(self, self.client, summary=self.summary_metrics)
1137
1324
  return self._summary
1138
1325
 
1326
+ @property
1327
+ def system_metrics(self):
1328
+ """Get run system metrics. Auto-loads full data if in lazy mode."""
1329
+ if (
1330
+ self._lazy
1331
+ and not self._full_data_loaded
1332
+ and "systemMetrics" not in self._attrs
1333
+ ):
1334
+ self.load_full_data()
1335
+ return self._attrs.get("systemMetrics", {})
1336
+
1337
+ @property
1338
+ def summary_metrics(self):
1339
+ """Get run summary metrics. Auto-loads full data if in lazy mode."""
1340
+ if (
1341
+ self._lazy
1342
+ and not self._full_data_loaded
1343
+ and "summaryMetrics" not in self._attrs
1344
+ ):
1345
+ self.load_full_data()
1346
+ return self._attrs.get("summaryMetrics", {})
1347
+
1348
+ @property
1349
+ def rawconfig(self):
1350
+ """Get raw run config including internal keys. Auto-loads full data if in lazy mode."""
1351
+ if self._lazy and not self._full_data_loaded and "rawconfig" not in self._attrs:
1352
+ self.load_full_data()
1353
+ return self._attrs.get("rawconfig", {})
1354
+
1355
+ @property
1356
+ def sweep_name(self):
1357
+ """Get sweep name. Always available since sweepName is in lightweight fragment."""
1358
+ # sweepName is included in lightweight fragment, so no need to load full data
1359
+ return self._attrs.get("sweepName")
1360
+
1139
1361
  @property
1140
1362
  def path(self):
1141
1363
  """The path of the run. The path is a list containing the entity, project, and run_id."""
wandb/apis/public/sweeps.py CHANGED
@@ -27,8 +27,9 @@ Note:
27
27
  and wandb.agent() functions from the main wandb package.
28
28
  """
29
29
 
30
+ from __future__ import annotations
31
+
30
32
  import urllib
31
- from typing import Optional
32
33
 
33
34
  from wandb_gql import gql
34
35
 
@@ -103,7 +104,7 @@ class Sweeps(SizedPaginator["Sweep"]):
103
104
  entity: str,
104
105
  project: str,
105
106
  per_page: int = 50,
106
- ) -> "Sweeps":
107
+ ) -> Sweeps:
107
108
  """An iterable collection of `Sweep` objects.
108
109
 
109
110
  Args:
@@ -317,7 +318,7 @@ class Sweep(Attrs):
317
318
  return None
318
319
 
319
320
  @property
320
- def expected_run_count(self) -> Optional[int]:
321
+ def expected_run_count(self) -> int | None:
321
322
  """Return the number of expected runs in the sweep or None for infinite runs."""
322
323
  return self._attrs.get("runCountExpected")
323
324
 
@@ -360,12 +361,12 @@ class Sweep(Attrs):
360
361
  @classmethod
361
362
  def get(
362
363
  cls,
363
- client: "RetryingClient",
364
- entity: Optional[str] = None,
365
- project: Optional[str] = None,
366
- sid: Optional[str] = None,
367
- order: Optional[str] = None,
368
- query: Optional[str] = None,
364
+ client: RetryingClient,
365
+ entity: str | None = None,
366
+ project: str | None = None,
367
+ sid: str | None = None,
368
+ order: str | None = None,
369
+ query: str | None = None,
369
370
  **kwargs,
370
371
  ):
371
372
  """Execute a query against the cloud backend.
wandb/apis/public/teams.py CHANGED
@@ -8,6 +8,8 @@ Note:
8
8
  permissions.
9
9
  """
10
10
 
11
+ from __future__ import annotations
12
+
11
13
  import requests
12
14
  from wandb_gql import gql
13
15
 
wandb/apis/public/users.py CHANGED
@@ -7,6 +7,8 @@ Note:
7
7
  users and their authentication. Some operations require admin privileges.
8
8
  """
9
9
 
10
+ from __future__ import annotations
11
+
10
12
  import requests
11
13
  from wandb_gql import gql
12
14
 
wandb/apis/public/utils.py CHANGED
@@ -1,6 +1,8 @@
1
+ from __future__ import annotations
2
+
1
3
  import re
2
4
  from enum import Enum
3
- from typing import Any, Dict, Iterable, Mapping, Optional, Set
5
+ from typing import Any, Iterable, Mapping
4
6
  from urllib.parse import urlparse
5
7
 
6
8
  from wandb_gql import gql
@@ -75,7 +77,7 @@ def parse_org_from_registry_path(path: str, path_type: PathType) -> str:
75
77
 
76
78
 
77
79
  def fetch_org_from_settings_or_entity(
78
- settings: dict, default_entity: Optional[str] = None
80
+ settings: dict, default_entity: str | None = None
79
81
  ) -> str:
80
82
  """Fetch the org from either the settings or deriving it from the entity.
81
83
 
@@ -110,17 +112,17 @@ def fetch_org_from_settings_or_entity(
110
112
  class _GQLCompatRewriter(visitor.Visitor):
111
113
  """GraphQL AST visitor to rewrite queries/mutations to be compatible with older server versions."""
112
114
 
113
- omit_variables: Set[str]
114
- omit_fragments: Set[str]
115
- omit_fields: Set[str]
116
- rename_fields: Dict[str, str]
115
+ omit_variables: set[str]
116
+ omit_fragments: set[str]
117
+ omit_fields: set[str]
118
+ rename_fields: dict[str, str]
117
119
 
118
120
  def __init__(
119
121
  self,
120
- omit_variables: Optional[Iterable[str]] = None,
121
- omit_fragments: Optional[Iterable[str]] = None,
122
- omit_fields: Optional[Iterable[str]] = None,
123
- rename_fields: Optional[Mapping[str, str]] = None,
122
+ omit_variables: Iterable[str] | None = None,
123
+ omit_fragments: Iterable[str] | None = None,
124
+ omit_fields: Iterable[str] | None = None,
125
+ rename_fields: Mapping[str, str] | None = None,
124
126
  ):
125
127
  self.omit_variables = set(omit_variables or ())
126
128
  self.omit_fragments = set(omit_fragments or ())
@@ -130,7 +132,6 @@ class _GQLCompatRewriter(visitor.Visitor):
130
132
  def enter_VariableDefinition(self, node: ast.VariableDefinition, *_, **__) -> Any: # noqa: N802
131
133
  if node.variable.name.value in self.omit_variables:
132
134
  return visitor.REMOVE
133
- # return node
134
135
 
135
136
  def enter_ObjectField(self, node: ast.ObjectField, *_, **__) -> Any: # noqa: N802
136
137
  # For context, note that e.g.:
@@ -176,10 +177,10 @@ class _GQLCompatRewriter(visitor.Visitor):
176
177
 
177
178
  def gql_compat(
178
179
  request_string: str,
179
- omit_variables: Optional[Iterable[str]] = None,
180
- omit_fragments: Optional[Iterable[str]] = None,
181
- omit_fields: Optional[Iterable[str]] = None,
182
- rename_fields: Optional[Mapping[str, str]] = None,
180
+ omit_variables: Iterable[str] | None = None,
181
+ omit_fragments: Iterable[str] | None = None,
182
+ omit_fields: Iterable[str] | None = None,
183
+ rename_fields: Mapping[str, str] | None = None,
183
184
  ) -> ast.Document:
184
185
  """Rewrite a GraphQL request string to ensure compatibility with older server versions.
185
186