wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/importers/__init__.py +1 -4
  4. wandb/apis/importers/internals/internal.py +386 -0
  5. wandb/apis/importers/internals/protocols.py +125 -0
  6. wandb/apis/importers/internals/util.py +78 -0
  7. wandb/apis/importers/mlflow.py +125 -88
  8. wandb/apis/importers/validation.py +108 -0
  9. wandb/apis/importers/wandb.py +1604 -0
  10. wandb/apis/public/api.py +7 -10
  11. wandb/apis/public/artifacts.py +38 -0
  12. wandb/apis/public/files.py +11 -2
  13. wandb/apis/reports/v2/__init__.py +0 -19
  14. wandb/apis/reports/v2/expr_parsing.py +0 -1
  15. wandb/apis/reports/v2/interface.py +15 -18
  16. wandb/apis/reports/v2/internal.py +12 -45
  17. wandb/cli/cli.py +52 -55
  18. wandb/integration/gym/__init__.py +2 -1
  19. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  20. wandb/integration/keras/keras.py +6 -4
  21. wandb/integration/kfp/kfp_patch.py +2 -2
  22. wandb/integration/openai/fine_tuning.py +1 -2
  23. wandb/integration/ultralytics/callback.py +0 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  25. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  26. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  27. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  28. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  29. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  30. wandb/sdk/artifacts/artifact.py +75 -31
  31. wandb/sdk/artifacts/artifact_manifest.py +5 -2
  32. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  33. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
  34. wandb/sdk/artifacts/artifact_saver.py +19 -47
  35. wandb/sdk/artifacts/storage_handler.py +2 -1
  36. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
  37. wandb/sdk/artifacts/storage_policy.py +4 -1
  38. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  39. wandb/sdk/data_types/image.py +2 -2
  40. wandb/sdk/interface/interface.py +49 -13
  41. wandb/sdk/interface/interface_shared.py +17 -11
  42. wandb/sdk/internal/file_stream.py +20 -1
  43. wandb/sdk/internal/handler.py +1 -4
  44. wandb/sdk/internal/internal_api.py +3 -1
  45. wandb/sdk/internal/job_builder.py +49 -19
  46. wandb/sdk/internal/profiler.py +1 -1
  47. wandb/sdk/internal/sender.py +96 -124
  48. wandb/sdk/internal/sender_config.py +197 -0
  49. wandb/sdk/internal/settings_static.py +9 -0
  50. wandb/sdk/internal/system/system_info.py +5 -3
  51. wandb/sdk/internal/update.py +1 -1
  52. wandb/sdk/launch/_launch.py +3 -3
  53. wandb/sdk/launch/_launch_add.py +28 -29
  54. wandb/sdk/launch/_project_spec.py +148 -136
  55. wandb/sdk/launch/agent/agent.py +3 -7
  56. wandb/sdk/launch/agent/config.py +0 -27
  57. wandb/sdk/launch/builder/build.py +54 -28
  58. wandb/sdk/launch/builder/docker_builder.py +4 -15
  59. wandb/sdk/launch/builder/kaniko_builder.py +72 -45
  60. wandb/sdk/launch/create_job.py +6 -40
  61. wandb/sdk/launch/loader.py +10 -0
  62. wandb/sdk/launch/registry/anon.py +29 -0
  63. wandb/sdk/launch/registry/local_registry.py +4 -1
  64. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  65. wandb/sdk/launch/runner/local_container.py +15 -10
  66. wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
  67. wandb/sdk/launch/sweeps/scheduler.py +11 -3
  68. wandb/sdk/launch/utils.py +14 -0
  69. wandb/sdk/lib/__init__.py +2 -5
  70. wandb/sdk/lib/_settings_toposort_generated.py +4 -1
  71. wandb/sdk/lib/apikey.py +0 -5
  72. wandb/sdk/lib/config_util.py +0 -31
  73. wandb/sdk/lib/filesystem.py +11 -1
  74. wandb/sdk/lib/run_moment.py +72 -0
  75. wandb/sdk/service/service.py +7 -2
  76. wandb/sdk/service/streams.py +1 -6
  77. wandb/sdk/verify/verify.py +2 -1
  78. wandb/sdk/wandb_init.py +12 -1
  79. wandb/sdk/wandb_login.py +43 -26
  80. wandb/sdk/wandb_run.py +164 -110
  81. wandb/sdk/wandb_settings.py +58 -16
  82. wandb/testing/relay.py +5 -6
  83. wandb/util.py +50 -7
  84. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
  85. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
  86. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
  87. wandb/apis/importers/base.py +0 -400
  88. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
  89. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,27 @@
1
+ import itertools
2
+ import logging
1
3
  import re
2
4
  from collections import defaultdict
3
- from concurrent.futures import ThreadPoolExecutor, as_completed
4
5
  from typing import Any, Dict, Iterable, List, Optional, Tuple
5
- from unittest.mock import patch
6
6
 
7
+ import mlflow
7
8
  from packaging.version import Version # type: ignore
8
- from tqdm.auto import tqdm
9
9
 
10
10
  import wandb
11
11
  from wandb import Artifact
12
- from wandb.util import coalesce, get_module
13
12
 
14
- from .base import ImporterRun, send_run_with_send_manager
15
-
16
- with patch("click.echo"):
17
- from wandb.apis.reports import Report
18
-
19
- mlflow = get_module(
20
- "mlflow",
21
- required="To use the MlflowImporter, please install mlflow: `pip install mlflow`",
22
- )
13
+ from .internals import internal
14
+ from .internals.util import Namespace, for_each
23
15
 
24
16
  mlflow_version = Version(mlflow.__version__)
25
17
 
18
+ logger = logging.getLogger("import_logger")
19
+
26
20
 
27
21
  class MlflowRun:
28
22
  def __init__(self, run, mlflow_client):
29
23
  self.run = run
30
- self.mlflow_client = mlflow_client
24
+ self.mlflow_client: mlflow.MlflowClient = mlflow_client
31
25
 
32
26
  def run_id(self) -> str:
33
27
  return self.run.info.run_id
@@ -39,7 +33,13 @@ class MlflowRun:
39
33
  return "imported-from-mlflow"
40
34
 
41
35
  def config(self) -> Dict[str, Any]:
42
- return self.run.data.params
36
+ conf = self.run.data.params
37
+
38
+ # Add tags here since mlflow supports very long tag names but we only support up to 64 chars
39
+ tags = {
40
+ k: v for k, v in self.run.data.tags.items() if not k.startswith("mlflow.")
41
+ }
42
+ return {**conf, "imported_mlflow_tags": tags}
43
43
 
44
44
  def summary(self) -> Dict[str, float]:
45
45
  return self.run.data.metrics
@@ -71,19 +71,22 @@ class MlflowRun:
71
71
  return self.run.data.tags.get("mlflow.note.content")
72
72
 
73
73
  def tags(self) -> Optional[List[str]]:
74
- mlflow_tags = {
75
- k: v for k, v in self.run.data.tags.items() if not k.startswith("mlflow.")
76
- }
77
- return [f"{k}={v}" for k, v in mlflow_tags.items()]
74
+ ...
75
+
76
+ # W&B tags are different than mlflow tags.
77
+ # The full mlflow tags are added to config under key `imported_mlflow_tags` instead
78
78
 
79
79
  def artifacts(self) -> Optional[Iterable[Artifact]]: # type: ignore
80
80
  if mlflow_version < Version("2.0.0"):
81
81
  dir_path = self.mlflow_client.download_artifacts(
82
- run_id=self.run.info.run_id, path=""
82
+ run_id=self.run.info.run_id,
83
+ path="",
83
84
  )
84
85
  else:
85
86
  dir_path = mlflow.artifacts.download_artifacts(run_id=self.run.info.run_id)
86
87
 
88
+ # Since mlflow doesn't have extra metadata about the artifacts,
89
+ # we just lump them all together into a single wandb.Artifact
87
90
  artifact_name = self._handle_incompatible_strings(self.display_name())
88
91
  art = wandb.Artifact(artifact_name, "imported-artifacts")
89
92
  art.add_dir(dir_path)
@@ -91,37 +94,37 @@ class MlflowRun:
91
94
  return [art]
92
95
 
93
96
  def used_artifacts(self) -> Optional[Iterable[Artifact]]: # type: ignore
94
- ...
97
+ ... # pragma: no cover
95
98
 
96
99
  def os_version(self) -> Optional[str]:
97
- ...
100
+ ... # pragma: no cover
98
101
 
99
102
  def python_version(self) -> Optional[str]:
100
- ...
103
+ ... # pragma: no cover
101
104
 
102
105
  def cuda_version(self) -> Optional[str]:
103
- ...
106
+ ... # pragma: no cover
104
107
 
105
108
  def program(self) -> Optional[str]:
106
- ...
109
+ ... # pragma: no cover
107
110
 
108
111
  def host(self) -> Optional[str]:
109
- ...
112
+ ... # pragma: no cover
110
113
 
111
114
  def username(self) -> Optional[str]:
112
- ...
115
+ ... # pragma: no cover
113
116
 
114
117
  def executable(self) -> Optional[str]:
115
- ...
118
+ ... # pragma: no cover
116
119
 
117
120
  def gpus_used(self) -> Optional[str]:
118
- ...
121
+ ... # pragma: no cover
119
122
 
120
123
  def cpus_used(self) -> Optional[int]: # can we get the model?
121
- ...
124
+ ... # pragma: no cover
122
125
 
123
126
  def memory_used(self) -> Optional[int]:
124
- ...
127
+ ... # pragma: no cover
125
128
 
126
129
  def runtime(self) -> Optional[int]:
127
130
  end_time = (
@@ -135,16 +138,16 @@ class MlflowRun:
135
138
  return self.run.info.start_time // 1000
136
139
 
137
140
  def code_path(self) -> Optional[str]:
138
- ...
141
+ ... # pragma: no cover
139
142
 
140
143
  def cli_version(self) -> Optional[str]:
141
- ...
144
+ ... # pragma: no cover
142
145
 
143
146
  def files(self) -> Optional[Iterable[Tuple[str, str]]]:
144
- ...
147
+ ... # pragma: no cover
145
148
 
146
149
  def logs(self) -> Optional[Iterable[str]]:
147
- ...
150
+ ... # pragma: no cover
148
151
 
149
152
  @staticmethod
150
153
  def _handle_incompatible_strings(s: str) -> str:
@@ -155,76 +158,110 @@ class MlflowRun:
155
158
 
156
159
 
157
160
  class MlflowImporter:
158
- def __init__(self, mlflow_tracking_uri, mlflow_registry_uri=None) -> None:
159
- self.mlflow_tracking_uri = mlflow_tracking_uri
161
+ def __init__(
162
+ self,
163
+ dst_base_url: str,
164
+ dst_api_key: str,
165
+ mlflow_tracking_uri: str,
166
+ mlflow_registry_uri: Optional[str] = None,
167
+ *,
168
+ custom_api_kwargs: Optional[Dict[str, Any]] = None,
169
+ ) -> None:
170
+ self.dst_base_url = dst_base_url
171
+ self.dst_api_key = dst_api_key
172
+
173
+ if custom_api_kwargs is None:
174
+ custom_api_kwargs = {"timeout": 600}
160
175
 
176
+ self.dst_api = wandb.Api(
177
+ api_key=dst_api_key,
178
+ overrides={"base_url": dst_base_url},
179
+ **custom_api_kwargs,
180
+ )
181
+ self.mlflow_tracking_uri = mlflow_tracking_uri
161
182
  mlflow.set_tracking_uri(self.mlflow_tracking_uri)
183
+
162
184
  if mlflow_registry_uri:
163
185
  mlflow.set_registry_uri(mlflow_registry_uri)
186
+
164
187
  self.mlflow_client = mlflow.tracking.MlflowClient(mlflow_tracking_uri)
165
188
 
166
- def collect_runs(self, limit: Optional[int] = None) -> Iterable[MlflowRun]:
189
+ def __repr__(self):
190
+ return f"<MlflowImporter src={self.mlflow_tracking_uri}>"
191
+
192
+ def collect_runs(self, *, limit: Optional[int] = None) -> Iterable[MlflowRun]:
167
193
  if mlflow_version < Version("1.28.0"):
168
194
  experiments = self.mlflow_client.list_experiments()
169
195
  else:
170
196
  experiments = self.mlflow_client.search_experiments()
171
197
 
172
- runs = (
173
- run
174
- for exp in experiments
175
- for run in self.mlflow_client.search_runs(exp.experiment_id)
176
- )
177
- for i, run in enumerate(runs):
178
- if limit and i >= limit:
179
- break
180
- yield MlflowRun(run, self.mlflow_client)
198
+ def _runs():
199
+ for exp in experiments:
200
+ for run in self.mlflow_client.search_runs(exp.experiment_id):
201
+ yield MlflowRun(run, self.mlflow_client)
202
+
203
+ runs = itertools.islice(_runs(), limit)
204
+ yield from runs
181
205
 
182
- def import_run(
206
+ def _import_run(
183
207
  self,
184
- run: ImporterRun,
185
- overrides: Optional[Dict[str, Any]] = None,
208
+ run: MlflowRun,
209
+ *,
210
+ artifacts: bool = True,
211
+ namespace: Optional[Namespace] = None,
212
+ config: Optional[internal.SendManagerConfig] = None,
186
213
  ) -> None:
214
+ if namespace is None:
215
+ namespace = Namespace(run.entity(), run.project())
216
+
217
+ if config is None:
218
+ config = internal.SendManagerConfig(
219
+ metadata=True,
220
+ files=True,
221
+ media=True,
222
+ code=True,
223
+ history=True,
224
+ summary=True,
225
+ terminal_output=True,
226
+ )
227
+
228
+ settings_override = {
229
+ "api_key": self.dst_api_key,
230
+ "base_url": self.dst_base_url,
231
+ "resume": "true",
232
+ "resumed": True,
233
+ }
234
+
187
235
  mlflow.set_tracking_uri(self.mlflow_tracking_uri)
188
- send_run_with_send_manager(run, overrides)
236
+ internal.send_run(
237
+ run,
238
+ overrides=namespace.send_manager_overrides,
239
+ settings_override=settings_override,
240
+ config=config,
241
+ )
242
+
243
+ # in mlflow, the artifacts come with the runs, so import them together
244
+ if artifacts:
245
+ arts = list(run.artifacts())
246
+ logger.debug(f"Importing history artifacts, {run=}")
247
+ internal.send_run(
248
+ run,
249
+ extra_arts=arts,
250
+ overrides=namespace.send_manager_overrides,
251
+ settings_override=settings_override,
252
+ config=internal.SendManagerConfig(log_artifacts=True),
253
+ )
189
254
 
190
255
  def import_runs(
191
256
  self,
192
- runs: Iterable[ImporterRun],
193
- overrides: Optional[Dict[str, Any]] = None,
194
- pool_kwargs: Optional[Dict[str, Any]] = None,
195
- ) -> None:
196
- _overrides = coalesce(overrides, {})
197
- _pool_kwargs = coalesce(pool_kwargs, {})
198
- runs = list(self.collect_runs())
199
-
200
- with ThreadPoolExecutor(**_pool_kwargs) as exc:
201
- futures = {
202
- exc.submit(self.import_run, run, overrides=_overrides): run
203
- for run in runs
204
- }
205
- with tqdm(desc="Importing runs", total=len(futures), unit="run") as pbar:
206
- for future in as_completed(futures):
207
- run = futures[future]
208
- try:
209
- future.result()
210
- except Exception as e:
211
- wandb.termerror(f"Failed to import {run.display_name()}: {e}")
212
- raise e
213
- else:
214
- pbar.set_description(
215
- f"Imported Run: {run.run_group()} {run.display_name()}"
216
- )
217
- finally:
218
- pbar.update(1)
219
-
220
- def import_all_runs(
221
- self,
222
- limit: Optional[int] = None,
223
- overrides: Optional[Dict[str, Any]] = None,
224
- pool_kwargs: Optional[Dict[str, Any]] = None,
257
+ runs: Iterable[MlflowRun],
258
+ *,
259
+ artifacts: bool = True,
260
+ namespace: Optional[Namespace] = None,
261
+ parallel: bool = True,
262
+ max_workers: Optional[int] = None,
225
263
  ) -> None:
226
- runs = self.collect_runs(limit)
227
- self.import_runs(runs, overrides, pool_kwargs)
264
+ def _import_run_wrapped(run):
265
+ self._import_run(run, namespace=namespace, artifacts=artifacts)
228
266
 
229
- def import_report(self, report: Report):
230
- raise NotImplementedError("MLFlow does not have a reports concept")
267
+ for_each(_import_run_wrapped, runs, parallel=parallel, max_workers=max_workers)
@@ -0,0 +1,108 @@
1
+ import filecmp
2
+ import logging
3
+ import os
4
+
5
+ import requests
6
+
7
+ import wandb
8
+
9
+ logger = logging.getLogger(__name__)
10
+ logger.setLevel(logging.INFO)
11
+
12
+
13
+ def _compare_artifact_manifests(
14
+ src_art: wandb.Artifact, dst_art: wandb.Artifact
15
+ ) -> list:
16
+ problems = []
17
+ if isinstance(dst_art, wandb.CommError):
18
+ return ["commError"]
19
+
20
+ if src_art.digest != dst_art.digest:
21
+ problems.append(f"digest mismatch {src_art.digest=}, {dst_art.digest=}")
22
+
23
+ for name, src_entry in src_art.manifest.entries.items():
24
+ dst_entry = dst_art.manifest.entries.get(name)
25
+ if dst_entry is None:
26
+ problems.append(f"missing manifest entry {name=}, {src_entry=}")
27
+ continue
28
+
29
+ for attr in ["path", "digest", "size"]:
30
+ if getattr(src_entry, attr) != getattr(dst_entry, attr):
31
+ problems.append(
32
+ f"manifest entry mismatch {attr=}, {getattr(src_entry, attr)=}, {getattr(dst_entry, attr)=}"
33
+ )
34
+
35
+ return problems
36
+
37
+
38
+ def _compare_artifact_dirs(src_dir, dst_dir) -> list:
39
+ def compare(src_dir, dst_dir):
40
+ comparison = filecmp.dircmp(src_dir, dst_dir)
41
+ differences = {
42
+ "left_only": comparison.left_only,
43
+ "right_only": comparison.right_only,
44
+ "diff_files": comparison.diff_files,
45
+ "subdir_differences": {},
46
+ }
47
+
48
+ # Recursively find differences in subdirectories
49
+ for subdir in comparison.subdirs:
50
+ subdir_src = os.path.join(src_dir, subdir)
51
+ subdir_dst = os.path.join(dst_dir, subdir)
52
+ subdir_differences = compare(subdir_src, subdir_dst)
53
+ # If there are differences, add them to the result
54
+ if subdir_differences and any(subdir_differences.values()):
55
+ differences["subdir_differences"][subdir] = subdir_differences
56
+
57
+ if all(not diff for diff in differences.values()):
58
+ return None
59
+
60
+ return differences
61
+
62
+ return compare(src_dir, dst_dir)
63
+
64
+
65
+ def _check_entries_are_downloadable(art):
66
+ entries = _collect_entries(art)
67
+ for entry in entries:
68
+ if not _check_entry_is_downloable(entry):
69
+ return False
70
+ return True
71
+
72
+
73
+ def _collect_entries(art):
74
+ has_next_page = True
75
+ cursor = None
76
+ entries = []
77
+ while has_next_page:
78
+ attrs = art._fetch_file_urls(cursor)
79
+ has_next_page = attrs["pageInfo"]["hasNextPage"]
80
+ cursor = attrs["pageInfo"]["endCursor"]
81
+ for edge in attrs["edges"]:
82
+ name = edge["node"]["name"]
83
+ entry = art.get_entry(name)
84
+ entry._download_url = edge["node"]["directUrl"]
85
+ entries.append(entry)
86
+ return entries
87
+
88
+
89
+ def _check_entry_is_downloable(entry):
90
+ url = entry._download_url
91
+ expected_size = entry.size
92
+
93
+ try:
94
+ resp = requests.head(url, allow_redirects=True)
95
+ except Exception as e:
96
+ logger.error(f"Problem validating {entry=}, {e=}")
97
+ return False
98
+
99
+ if resp.status_code != 200:
100
+ return False
101
+
102
+ actual_size = resp.headers.get("content-length", -1)
103
+ actual_size = int(actual_size)
104
+
105
+ if expected_size == actual_size:
106
+ return True
107
+
108
+ return False