wandb 0.22.0__py3-none-win32.whl → 0.22.1__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +3 -2
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +261 -57
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats.exe +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta_sync.py +9 -11
  35. wandb/errors/errors.py +3 -3
  36. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  37. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  38. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  39. wandb/proto/v6/wandb_sync_pb2.py +10 -6
  40. wandb/sdk/artifacts/_factories.py +7 -2
  41. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  42. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  43. wandb/sdk/artifacts/_generated/operations.py +52 -22
  44. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  45. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  46. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  47. wandb/sdk/artifacts/_gqlutils.py +47 -0
  48. wandb/sdk/artifacts/_models/__init__.py +4 -0
  49. wandb/sdk/artifacts/_models/base_model.py +20 -0
  50. wandb/sdk/artifacts/_validators.py +40 -12
  51. wandb/sdk/artifacts/artifact.py +69 -88
  52. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  53. wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
  54. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +10 -0
  55. wandb/sdk/data_types/bokeh.py +5 -1
  56. wandb/sdk/data_types/image.py +17 -6
  57. wandb/sdk/interface/interface.py +31 -4
  58. wandb/sdk/interface/interface_queue.py +10 -0
  59. wandb/sdk/interface/interface_shared.py +0 -7
  60. wandb/sdk/interface/interface_sock.py +9 -3
  61. wandb/sdk/internal/_generated/__init__.py +2 -12
  62. wandb/sdk/internal/sender.py +1 -1
  63. wandb/sdk/internal/settings_static.py +2 -82
  64. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  65. wandb/sdk/launch/utils.py +82 -1
  66. wandb/sdk/lib/progress.py +7 -4
  67. wandb/sdk/lib/service/service_client.py +5 -9
  68. wandb/sdk/lib/service/service_connection.py +39 -23
  69. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  70. wandb/sdk/projects/_generated/__init__.py +12 -33
  71. wandb/sdk/wandb_init.py +22 -2
  72. wandb/sdk/wandb_login.py +53 -27
  73. wandb/sdk/wandb_run.py +5 -3
  74. wandb/sdk/wandb_settings.py +50 -13
  75. wandb/sync/sync.py +7 -2
  76. wandb/util.py +1 -1
  77. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
  78. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/RECORD +81 -78
  79. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  80. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
  81. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
  82. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/runs.py CHANGED
@@ -85,6 +85,67 @@ RUN_FRAGMENT = """fragment RunFragment on Run {
85
85
  historyKeys
86
86
  }"""
87
87
 
88
+ # Lightweight fragment for listing operations - excludes heavy fields
89
+ LIGHTWEIGHT_RUN_FRAGMENT = """fragment LightweightRunFragment on Run {
90
+ id
91
+ tags
92
+ name
93
+ displayName
94
+ sweepName
95
+ state
96
+ group
97
+ jobType
98
+ commit
99
+ readOnly
100
+ createdAt
101
+ heartbeatAt
102
+ description
103
+ notes
104
+ historyLineCount
105
+ user {
106
+ name
107
+ username
108
+ }
109
+ }"""
110
+
111
+ # Fragment name constants to avoid string parsing
112
+ RUN_FRAGMENT_NAME = "RunFragment"
113
+ LIGHTWEIGHT_RUN_FRAGMENT_NAME = "LightweightRunFragment"
114
+
115
+
116
+ def _create_runs_query(
117
+ *, lazy: bool, with_internal_id: bool, with_project_id: bool
118
+ ) -> gql:
119
+ """Create GraphQL query for runs with appropriate fragment."""
120
+ fragment = LIGHTWEIGHT_RUN_FRAGMENT if lazy else RUN_FRAGMENT
121
+ fragment_name = LIGHTWEIGHT_RUN_FRAGMENT_NAME if lazy else RUN_FRAGMENT_NAME
122
+
123
+ return gql(
124
+ f"""#graphql
125
+ query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
126
+ project(name: $project, entityName: $entity) {{
127
+ {"internalId" if with_internal_id else ""}
128
+ runCount(filters: $filters)
129
+ readOnly
130
+ runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
131
+ edges {{
132
+ node {{
133
+ {"projectId" if with_project_id else ""}
134
+ ...{fragment_name}
135
+ }}
136
+ cursor
137
+ }}
138
+ pageInfo {{
139
+ endCursor
140
+ hasNextPage
141
+ }}
142
+ }}
143
+ }}
144
+ }}
145
+ {fragment}
146
+ """
147
+ )
148
+
88
149
 
89
150
  @normalize_exceptions
90
151
  def _server_provides_internal_id_for_project(client) -> bool:
@@ -219,34 +280,15 @@ class Runs(SizedPaginator["Run"]):
219
280
  order: str = "+created_at",
220
281
  per_page: int = 50,
221
282
  include_sweeps: bool = True,
283
+ lazy: bool = True,
222
284
  ):
223
285
  if not order:
224
286
  order = "+created_at"
225
287
 
226
- self.QUERY = gql(
227
- f"""#graphql
228
- query Runs($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
229
- project(name: $project, entityName: $entity) {{
230
- {"internalId" if _server_provides_internal_id_for_project(client) else ""}
231
- runCount(filters: $filters)
232
- readOnly
233
- runs(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
234
- edges {{
235
- node {{
236
- {"projectId" if _server_provides_project_id_for_run(client) else ""}
237
- ...RunFragment
238
- }}
239
- cursor
240
- }}
241
- pageInfo {{
242
- endCursor
243
- hasNextPage
244
- }}
245
- }}
246
- }}
247
- }}
248
- {RUN_FRAGMENT}
249
- """
288
+ self.QUERY = _create_runs_query(
289
+ lazy=lazy,
290
+ with_internal_id=_server_provides_internal_id_for_project(client),
291
+ with_project_id=_server_provides_project_id_for_run(client),
250
292
  )
251
293
 
252
294
  self.entity = entity
@@ -256,6 +298,7 @@ class Runs(SizedPaginator["Run"]):
256
298
  self.order = order
257
299
  self._sweeps = {}
258
300
  self._include_sweeps = include_sweeps
301
+ self._lazy = lazy
259
302
  variables = {
260
303
  "project": self.project,
261
304
  "entity": self.entity,
@@ -314,6 +357,7 @@ class Runs(SizedPaginator["Run"]):
314
357
  run_response["node"]["name"],
315
358
  run_response["node"],
316
359
  include_sweeps=self._include_sweeps,
360
+ lazy=self._lazy,
317
361
  )
318
362
  objs.append(run)
319
363
 
@@ -440,6 +484,39 @@ class Runs(SizedPaginator["Run"]):
440
484
  def __repr__(self):
441
485
  return f"<Runs {self.entity}/{self.project}>"
442
486
 
487
+ def upgrade_to_full(self):
488
+ """Upgrade this Runs collection from lazy to full mode.
489
+
490
+ This switches to fetching full run data and
491
+ upgrades any already-loaded Run objects to have full data.
492
+ Uses parallel loading for better performance when upgrading multiple runs.
493
+ """
494
+ if not self._lazy:
495
+ return # Already in full mode
496
+
497
+ # Switch to full mode
498
+ self._lazy = False
499
+
500
+ # Regenerate query with full fragment
501
+ self.QUERY = _create_runs_query(
502
+ lazy=False,
503
+ with_internal_id=_server_provides_internal_id_for_project(self.client),
504
+ with_project_id=_server_provides_project_id_for_run(self.client),
505
+ )
506
+
507
+ # Upgrade any existing runs that have been loaded - use parallel loading for performance
508
+ lazy_runs = [run for run in self.objects if run._lazy]
509
+ if lazy_runs:
510
+ from concurrent.futures import ThreadPoolExecutor
511
+
512
+ # Limit workers to avoid overwhelming the server
513
+ max_workers = min(len(lazy_runs), 10)
514
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
515
+ futures = [executor.submit(run.load_full_data) for run in lazy_runs]
516
+ # Wait for all to complete
517
+ for future in futures:
518
+ future.result()
519
+
443
520
 
444
521
  class Run(Attrs):
445
522
  """A single run associated with an entity and project.
@@ -483,6 +560,7 @@ class Run(Attrs):
483
560
  run_id: str,
484
561
  attrs: Mapping | None = None,
485
562
  include_sweeps: bool = True,
563
+ lazy: bool = True,
486
564
  ):
487
565
  """Initialize a Run object.
488
566
 
@@ -499,6 +577,8 @@ class Run(Attrs):
499
577
  self.id = run_id
500
578
  self.sweep = None
501
579
  self._include_sweeps = include_sweeps
580
+ self._lazy = lazy
581
+ self._full_data_loaded = False # Track if we've loaded full data
502
582
  self.dir = os.path.join(self._base_dir, *self.path)
503
583
  try:
504
584
  os.makedirs(self.dir)
@@ -508,6 +588,7 @@ class Run(Attrs):
508
588
  self._metadata: dict[str, Any] | None = None
509
589
  self._state = _attrs.get("state", "not found")
510
590
  self.server_provides_internal_id_field: bool | None = None
591
+ self._server_provides_project_id_field: bool | None = None
511
592
  self._is_loaded: bool = False
512
593
 
513
594
  self.load(force=not _attrs)
@@ -612,23 +693,34 @@ class Run(Attrs):
612
693
  "notes": None,
613
694
  "state": state,
614
695
  },
696
+ lazy=False, # Created runs should have full data available immediately
615
697
  )
616
698
 
617
- def load(self, force=False):
618
- if force or not self._attrs:
619
- self._is_loaded = False
620
- query = gql(f"""#graphql
621
- query Run($project: String!, $entity: String!, $name: String!) {{
622
- project(name: $project, entityName: $entity) {{
623
- run(name: $name) {{
624
- {"projectId" if _server_provides_project_id_for_run(self.client) else ""}
625
- ...RunFragment
626
- }}
699
+ def _load_with_fragment(
700
+ self, fragment: str, fragment_name: str, force: bool = False
701
+ ):
702
+ """Load run data using specified GraphQL fragment."""
703
+ # Cache the server capability check to avoid repeated network calls
704
+ if self._server_provides_project_id_field is None:
705
+ self._server_provides_project_id_field = (
706
+ _server_provides_project_id_for_run(self.client)
707
+ )
708
+
709
+ query = gql(
710
+ f"""
711
+ query Run($project: String!, $entity: String!, $name: String!) {{
712
+ project(name: $project, entityName: $entity) {{
713
+ run(name: $name) {{
714
+ {"projectId" if self._server_provides_project_id_field else ""}
715
+ ...{fragment_name}
627
716
  }}
628
717
  }}
629
- {RUN_FRAGMENT}
630
- """)
718
+ }}
719
+ {fragment}
720
+ """
721
+ )
631
722
 
723
+ if force or not self._attrs:
632
724
  response = self._exec(query)
633
725
  if (
634
726
  response is None
@@ -638,6 +730,10 @@ class Run(Attrs):
638
730
  raise ValueError("Could not find run {}".format(self))
639
731
  self._attrs = response["project"]["run"]
640
732
 
733
+ self._state = self._attrs["state"]
734
+ if self._attrs.get("user"):
735
+ self.user = public.User(self.client, self._attrs["user"])
736
+
641
737
  if self._include_sweeps and self.sweep_name and not self.sweep:
642
738
  # There may be a lot of runs. Don't bother pulling them all
643
739
  # just for the sake of this one.
@@ -650,39 +746,78 @@ class Run(Attrs):
650
746
  )
651
747
 
652
748
  if not self._is_loaded:
653
- self._load_from_attrs()
749
+ # Always set _project_internal_id if projectId is available, regardless of fragment type
750
+ if "projectId" in self._attrs:
751
+ self._project_internal_id = int(self._attrs["projectId"])
752
+ else:
753
+ self._project_internal_id = None
754
+
755
+ # Only call _load_from_attrs when using the full fragment or when the fields are actually present
756
+ if fragment_name == RUN_FRAGMENT_NAME or (
757
+ "config" in self._attrs
758
+ or "summaryMetrics" in self._attrs
759
+ or "systemMetrics" in self._attrs
760
+ ):
761
+ self._load_from_attrs()
654
762
  self._is_loaded = True
655
763
 
656
764
  return self._attrs
657
765
 
658
766
  def _load_from_attrs(self):
659
767
  self._state = self._attrs.get("state", None)
660
- self._attrs["config"] = _convert_to_dict(self._attrs.get("config"))
661
- self._attrs["summaryMetrics"] = _convert_to_dict(
662
- self._attrs.get("summaryMetrics")
663
- )
664
- self._attrs["systemMetrics"] = _convert_to_dict(
665
- self._attrs.get("systemMetrics")
666
- )
667
768
 
668
- if "projectId" in self._attrs:
669
- self._project_internal_id = int(self._attrs["projectId"])
670
- else:
671
- self._project_internal_id = None
769
+ # Only convert fields if they exist in _attrs
770
+ if "config" in self._attrs:
771
+ self._attrs["config"] = _convert_to_dict(self._attrs.get("config"))
772
+ if "summaryMetrics" in self._attrs:
773
+ self._attrs["summaryMetrics"] = _convert_to_dict(
774
+ self._attrs.get("summaryMetrics")
775
+ )
776
+ if "systemMetrics" in self._attrs:
777
+ self._attrs["systemMetrics"] = _convert_to_dict(
778
+ self._attrs.get("systemMetrics")
779
+ )
780
+
781
+ # Only check for sweeps if sweep_name is available (not in lazy mode or if it exists)
782
+ if self._include_sweeps and self._attrs.get("sweepName") and not self.sweep:
783
+ # There may be a lot of runs. Don't bother pulling them all
784
+ self.sweep = public.Sweep(
785
+ self.client,
786
+ self.entity,
787
+ self.project,
788
+ self._attrs["sweepName"],
789
+ withRuns=False,
790
+ )
672
791
 
673
- if self._attrs.get("user"):
674
- self.user = public.User(self.client, self._attrs["user"])
675
792
  config_user, config_raw = {}, {}
676
- for key, value in self._attrs.get("config").items():
677
- config = config_raw if key in WANDB_INTERNAL_KEYS else config_user
678
- if isinstance(value, dict) and "value" in value:
679
- config[key] = value["value"]
680
- else:
681
- config[key] = value
793
+ if self._attrs.get("config"):
794
+ try:
795
+ # config is already converted to dict by _convert_to_dict
796
+ for key, value in self._attrs.get("config", {}).items():
797
+ config = config_raw if key in WANDB_INTERNAL_KEYS else config_user
798
+ if isinstance(value, dict) and "value" in value:
799
+ config[key] = value["value"]
800
+ else:
801
+ config[key] = value
802
+ except (TypeError, AttributeError):
803
+ # Handle case where config is malformed or not a dict
804
+ pass
805
+
682
806
  config_raw.update(config_user)
683
807
  self._attrs["config"] = config_user
684
808
  self._attrs["rawconfig"] = config_raw
685
809
 
810
+ return self._attrs
811
+
812
+ def load(self, force=False):
813
+ """Load run data using appropriate fragment based on lazy mode."""
814
+ if self._lazy:
815
+ return self._load_with_fragment(
816
+ LIGHTWEIGHT_RUN_FRAGMENT, LIGHTWEIGHT_RUN_FRAGMENT_NAME, force
817
+ )
818
+ else:
819
+ return self._load_with_fragment(RUN_FRAGMENT, RUN_FRAGMENT_NAME, force)
820
+
686
821
  @normalize_exceptions
687
822
  def wait_until_finished(self):
688
823
  """Check the state of the run until it is finished."""
@@ -1144,9 +1279,43 @@ class Run(Attrs):
1144
1279
  )
1145
1280
  return artifact
1146
1281
 
1282
+ def load_full_data(self, force: bool = False) -> dict[str, Any]:
1283
+ """Load full run data including heavy fields like config, systemMetrics, summaryMetrics.
1284
+
1285
+ This method is useful when you initially used lazy=True for listing runs,
1286
+ but need access to the full data for specific runs.
1287
+
1288
+ Args:
1289
+ force: Force reload even if data is already loaded
1290
+
1291
+ Returns:
1292
+ The loaded run attributes
1293
+ """
1294
+ if not self._lazy and not force:
1295
+ # Already in full mode, no need to reload
1296
+ return self._attrs
1297
+
1298
+ # Load full data and mark as loaded
1299
+ result = self._load_with_fragment(RUN_FRAGMENT, RUN_FRAGMENT_NAME, force=True)
1300
+ self._full_data_loaded = True
1301
+ return result
1302
+
1303
+ @property
1304
+ def config(self):
1305
+ """Get run config. Auto-loads full data if in lazy mode."""
1306
+ if self._lazy and not self._full_data_loaded and "config" not in self._attrs:
1307
+ self.load_full_data()
1308
+ return self._attrs.get("config", {})
1309
+
1147
1310
  @property
1148
1311
  def summary(self):
1149
- """A mutable dict-like property that holds summary values associated with the run."""
1312
+ """Get run summary metrics. Auto-loads full data if in lazy mode."""
1313
+ if (
1314
+ self._lazy
1315
+ and not self._full_data_loaded
1316
+ and "summaryMetrics" not in self._attrs
1317
+ ):
1318
+ self.load_full_data()
1150
1319
  if self._summary is None:
1151
1320
  from wandb.old.summary import HTTPSummary
1152
1321
 
@@ -1154,6 +1323,41 @@ class Run(Attrs):
1154
1323
  self._summary = HTTPSummary(self, self.client, summary=self.summary_metrics)
1155
1324
  return self._summary
1156
1325
 
1326
+ @property
1327
+ def system_metrics(self):
1328
+ """Get run system metrics. Auto-loads full data if in lazy mode."""
1329
+ if (
1330
+ self._lazy
1331
+ and not self._full_data_loaded
1332
+ and "systemMetrics" not in self._attrs
1333
+ ):
1334
+ self.load_full_data()
1335
+ return self._attrs.get("systemMetrics", {})
1336
+
1337
+ @property
1338
+ def summary_metrics(self):
1339
+ """Get run summary metrics. Auto-loads full data if in lazy mode."""
1340
+ if (
1341
+ self._lazy
1342
+ and not self._full_data_loaded
1343
+ and "summaryMetrics" not in self._attrs
1344
+ ):
1345
+ self.load_full_data()
1346
+ return self._attrs.get("summaryMetrics", {})
1347
+
1348
+ @property
1349
+ def rawconfig(self):
1350
+ """Get raw run config including internal keys. Auto-loads full data if in lazy mode."""
1351
+ if self._lazy and not self._full_data_loaded and "rawconfig" not in self._attrs:
1352
+ self.load_full_data()
1353
+ return self._attrs.get("rawconfig", {})
1354
+
1355
+ @property
1356
+ def sweep_name(self):
1357
+ """Get sweep name. Always available since sweepName is in lightweight fragment."""
1358
+ # sweepName is included in lightweight fragment, so no need to load full data
1359
+ return self._attrs.get("sweepName")
1360
+
1157
1361
  @property
1158
1362
  def path(self):
1159
1363
  """The path of the run. The path is a list containing the entity, project, and run_id."""
@@ -27,8 +27,9 @@ Note:
27
27
  and wandb.agent() functions from the main wandb package.
28
28
  """
29
29
 
30
+ from __future__ import annotations
31
+
30
32
  import urllib
31
- from typing import Optional
32
33
 
33
34
  from wandb_gql import gql
34
35
 
@@ -103,7 +104,7 @@ class Sweeps(SizedPaginator["Sweep"]):
103
104
  entity: str,
104
105
  project: str,
105
106
  per_page: int = 50,
106
- ) -> "Sweeps":
107
+ ) -> Sweeps:
107
108
  """An iterable collection of `Sweep` objects.
108
109
 
109
110
  Args:
@@ -317,7 +318,7 @@ class Sweep(Attrs):
317
318
  return None
318
319
 
319
320
  @property
320
- def expected_run_count(self) -> Optional[int]:
321
+ def expected_run_count(self) -> int | None:
321
322
  """Return the number of expected runs in the sweep or None for infinite runs."""
322
323
  return self._attrs.get("runCountExpected")
323
324
 
@@ -360,12 +361,12 @@ class Sweep(Attrs):
360
361
  @classmethod
361
362
  def get(
362
363
  cls,
363
- client: "RetryingClient",
364
- entity: Optional[str] = None,
365
- project: Optional[str] = None,
366
- sid: Optional[str] = None,
367
- order: Optional[str] = None,
368
- query: Optional[str] = None,
364
+ client: RetryingClient,
365
+ entity: str | None = None,
366
+ project: str | None = None,
367
+ sid: str | None = None,
368
+ order: str | None = None,
369
+ query: str | None = None,
369
370
  **kwargs,
370
371
  ):
371
372
  """Execute a query against the cloud backend.
@@ -8,6 +8,8 @@ Note:
8
8
  permissions.
9
9
  """
10
10
 
11
+ from __future__ import annotations
12
+
11
13
  import requests
12
14
  from wandb_gql import gql
13
15
 
@@ -7,6 +7,8 @@ Note:
7
7
  users and their authentication. Some operations require admin privileges.
8
8
  """
9
9
 
10
+ from __future__ import annotations
11
+
10
12
  import requests
11
13
  from wandb_gql import gql
12
14
 
@@ -1,6 +1,8 @@
1
+ from __future__ import annotations
2
+
1
3
  import re
2
4
  from enum import Enum
3
- from typing import Any, Dict, Iterable, Mapping, Optional, Set
5
+ from typing import Any, Iterable, Mapping
4
6
  from urllib.parse import urlparse
5
7
 
6
8
  from wandb_gql import gql
@@ -75,7 +77,7 @@ def parse_org_from_registry_path(path: str, path_type: PathType) -> str:
75
77
 
76
78
 
77
79
  def fetch_org_from_settings_or_entity(
78
- settings: dict, default_entity: Optional[str] = None
80
+ settings: dict, default_entity: str | None = None
79
81
  ) -> str:
80
82
  """Fetch the org from either the settings or deriving it from the entity.
81
83
 
@@ -110,17 +112,17 @@ def fetch_org_from_settings_or_entity(
110
112
  class _GQLCompatRewriter(visitor.Visitor):
111
113
  """GraphQL AST visitor to rewrite queries/mutations to be compatible with older server versions."""
112
114
 
113
- omit_variables: Set[str]
114
- omit_fragments: Set[str]
115
- omit_fields: Set[str]
116
- rename_fields: Dict[str, str]
115
+ omit_variables: set[str]
116
+ omit_fragments: set[str]
117
+ omit_fields: set[str]
118
+ rename_fields: dict[str, str]
117
119
 
118
120
  def __init__(
119
121
  self,
120
- omit_variables: Optional[Iterable[str]] = None,
121
- omit_fragments: Optional[Iterable[str]] = None,
122
- omit_fields: Optional[Iterable[str]] = None,
123
- rename_fields: Optional[Mapping[str, str]] = None,
122
+ omit_variables: Iterable[str] | None = None,
123
+ omit_fragments: Iterable[str] | None = None,
124
+ omit_fields: Iterable[str] | None = None,
125
+ rename_fields: Mapping[str, str] | None = None,
124
126
  ):
125
127
  self.omit_variables = set(omit_variables or ())
126
128
  self.omit_fragments = set(omit_fragments or ())
@@ -130,7 +132,6 @@ class _GQLCompatRewriter(visitor.Visitor):
130
132
  def enter_VariableDefinition(self, node: ast.VariableDefinition, *_, **__) -> Any: # noqa: N802
131
133
  if node.variable.name.value in self.omit_variables:
132
134
  return visitor.REMOVE
133
- # return node
134
135
 
135
136
  def enter_ObjectField(self, node: ast.ObjectField, *_, **__) -> Any: # noqa: N802
136
137
  # For context, note that e.g.:
@@ -176,10 +177,10 @@ class _GQLCompatRewriter(visitor.Visitor):
176
177
 
177
178
  def gql_compat(
178
179
  request_string: str,
179
- omit_variables: Optional[Iterable[str]] = None,
180
- omit_fragments: Optional[Iterable[str]] = None,
181
- omit_fields: Optional[Iterable[str]] = None,
182
- rename_fields: Optional[Mapping[str, str]] = None,
180
+ omit_variables: Iterable[str] | None = None,
181
+ omit_fragments: Iterable[str] | None = None,
182
+ omit_fields: Iterable[str] | None = None,
183
+ rename_fields: Mapping[str, str] | None = None,
183
184
  ) -> ast.Document:
184
185
  """Rewrite a GraphQL request string to ensure compatibility with older server versions.
185
186