wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (90) hide show
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/importers/__init__.py +1 -4
  4. wandb/apis/importers/internals/internal.py +386 -0
  5. wandb/apis/importers/internals/protocols.py +125 -0
  6. wandb/apis/importers/internals/util.py +78 -0
  7. wandb/apis/importers/mlflow.py +125 -88
  8. wandb/apis/importers/validation.py +108 -0
  9. wandb/apis/importers/wandb.py +1604 -0
  10. wandb/apis/public/api.py +7 -10
  11. wandb/apis/public/artifacts.py +38 -0
  12. wandb/apis/public/files.py +11 -2
  13. wandb/apis/reports/v2/__init__.py +0 -19
  14. wandb/apis/reports/v2/expr_parsing.py +0 -1
  15. wandb/apis/reports/v2/interface.py +15 -18
  16. wandb/apis/reports/v2/internal.py +12 -45
  17. wandb/cli/cli.py +52 -55
  18. wandb/integration/gym/__init__.py +2 -1
  19. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  20. wandb/integration/keras/keras.py +6 -4
  21. wandb/integration/kfp/kfp_patch.py +2 -2
  22. wandb/integration/openai/fine_tuning.py +1 -2
  23. wandb/integration/ultralytics/callback.py +0 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  25. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  26. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  27. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  28. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  29. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  30. wandb/sdk/artifacts/artifact.py +75 -31
  31. wandb/sdk/artifacts/artifact_manifest.py +5 -2
  32. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  33. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
  34. wandb/sdk/artifacts/artifact_saver.py +19 -47
  35. wandb/sdk/artifacts/storage_handler.py +2 -1
  36. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
  37. wandb/sdk/artifacts/storage_policy.py +4 -1
  38. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  39. wandb/sdk/data_types/image.py +2 -2
  40. wandb/sdk/interface/interface.py +49 -13
  41. wandb/sdk/interface/interface_shared.py +17 -11
  42. wandb/sdk/internal/file_stream.py +20 -1
  43. wandb/sdk/internal/handler.py +1 -4
  44. wandb/sdk/internal/internal_api.py +3 -1
  45. wandb/sdk/internal/job_builder.py +49 -19
  46. wandb/sdk/internal/profiler.py +1 -1
  47. wandb/sdk/internal/sender.py +96 -124
  48. wandb/sdk/internal/sender_config.py +197 -0
  49. wandb/sdk/internal/settings_static.py +9 -0
  50. wandb/sdk/internal/system/system_info.py +5 -3
  51. wandb/sdk/internal/update.py +1 -1
  52. wandb/sdk/launch/_launch.py +3 -3
  53. wandb/sdk/launch/_launch_add.py +28 -29
  54. wandb/sdk/launch/_project_spec.py +148 -136
  55. wandb/sdk/launch/agent/agent.py +3 -7
  56. wandb/sdk/launch/agent/config.py +0 -27
  57. wandb/sdk/launch/builder/build.py +54 -28
  58. wandb/sdk/launch/builder/docker_builder.py +4 -15
  59. wandb/sdk/launch/builder/kaniko_builder.py +72 -45
  60. wandb/sdk/launch/create_job.py +6 -40
  61. wandb/sdk/launch/loader.py +10 -0
  62. wandb/sdk/launch/registry/anon.py +29 -0
  63. wandb/sdk/launch/registry/local_registry.py +4 -1
  64. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  65. wandb/sdk/launch/runner/local_container.py +15 -10
  66. wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
  67. wandb/sdk/launch/sweeps/scheduler.py +11 -3
  68. wandb/sdk/launch/utils.py +14 -0
  69. wandb/sdk/lib/__init__.py +2 -5
  70. wandb/sdk/lib/_settings_toposort_generated.py +4 -1
  71. wandb/sdk/lib/apikey.py +0 -5
  72. wandb/sdk/lib/config_util.py +0 -31
  73. wandb/sdk/lib/filesystem.py +11 -1
  74. wandb/sdk/lib/run_moment.py +72 -0
  75. wandb/sdk/service/service.py +7 -2
  76. wandb/sdk/service/streams.py +1 -6
  77. wandb/sdk/verify/verify.py +2 -1
  78. wandb/sdk/wandb_init.py +12 -1
  79. wandb/sdk/wandb_login.py +43 -26
  80. wandb/sdk/wandb_run.py +164 -110
  81. wandb/sdk/wandb_settings.py +58 -16
  82. wandb/testing/relay.py +5 -6
  83. wandb/util.py +50 -7
  84. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
  85. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
  86. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
  87. wandb/apis/importers/base.py +0 -400
  88. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
  89. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,27 @@
1
+ import itertools
2
+ import logging
1
3
  import re
2
4
  from collections import defaultdict
3
- from concurrent.futures import ThreadPoolExecutor, as_completed
4
5
  from typing import Any, Dict, Iterable, List, Optional, Tuple
5
- from unittest.mock import patch
6
6
 
7
+ import mlflow
7
8
  from packaging.version import Version # type: ignore
8
- from tqdm.auto import tqdm
9
9
 
10
10
  import wandb
11
11
  from wandb import Artifact
12
- from wandb.util import coalesce, get_module
13
12
 
14
- from .base import ImporterRun, send_run_with_send_manager
15
-
16
- with patch("click.echo"):
17
- from wandb.apis.reports import Report
18
-
19
- mlflow = get_module(
20
- "mlflow",
21
- required="To use the MlflowImporter, please install mlflow: `pip install mlflow`",
22
- )
13
+ from .internals import internal
14
+ from .internals.util import Namespace, for_each
23
15
 
24
16
  mlflow_version = Version(mlflow.__version__)
25
17
 
18
+ logger = logging.getLogger("import_logger")
19
+
26
20
 
27
21
  class MlflowRun:
28
22
  def __init__(self, run, mlflow_client):
29
23
  self.run = run
30
- self.mlflow_client = mlflow_client
24
+ self.mlflow_client: mlflow.MlflowClient = mlflow_client
31
25
 
32
26
  def run_id(self) -> str:
33
27
  return self.run.info.run_id
@@ -39,7 +33,13 @@ class MlflowRun:
39
33
  return "imported-from-mlflow"
40
34
 
41
35
  def config(self) -> Dict[str, Any]:
42
- return self.run.data.params
36
+ conf = self.run.data.params
37
+
38
+ # Add tags here since mlflow supports very long tag names but we only support up to 64 chars
39
+ tags = {
40
+ k: v for k, v in self.run.data.tags.items() if not k.startswith("mlflow.")
41
+ }
42
+ return {**conf, "imported_mlflow_tags": tags}
43
43
 
44
44
  def summary(self) -> Dict[str, float]:
45
45
  return self.run.data.metrics
@@ -71,19 +71,22 @@ class MlflowRun:
71
71
  return self.run.data.tags.get("mlflow.note.content")
72
72
 
73
73
  def tags(self) -> Optional[List[str]]:
74
- mlflow_tags = {
75
- k: v for k, v in self.run.data.tags.items() if not k.startswith("mlflow.")
76
- }
77
- return [f"{k}={v}" for k, v in mlflow_tags.items()]
74
+ ...
75
+
76
+ # W&B tags are different than mlflow tags.
77
+ # The full mlflow tags are added to config under key `imported_mlflow_tags` instead
78
78
 
79
79
  def artifacts(self) -> Optional[Iterable[Artifact]]: # type: ignore
80
80
  if mlflow_version < Version("2.0.0"):
81
81
  dir_path = self.mlflow_client.download_artifacts(
82
- run_id=self.run.info.run_id, path=""
82
+ run_id=self.run.info.run_id,
83
+ path="",
83
84
  )
84
85
  else:
85
86
  dir_path = mlflow.artifacts.download_artifacts(run_id=self.run.info.run_id)
86
87
 
88
+ # Since mlflow doesn't have extra metadata about the artifacts,
89
+ # we just lump them all together into a single wandb.Artifact
87
90
  artifact_name = self._handle_incompatible_strings(self.display_name())
88
91
  art = wandb.Artifact(artifact_name, "imported-artifacts")
89
92
  art.add_dir(dir_path)
@@ -91,37 +94,37 @@ class MlflowRun:
91
94
  return [art]
92
95
 
93
96
  def used_artifacts(self) -> Optional[Iterable[Artifact]]: # type: ignore
94
- ...
97
+ ... # pragma: no cover
95
98
 
96
99
  def os_version(self) -> Optional[str]:
97
- ...
100
+ ... # pragma: no cover
98
101
 
99
102
  def python_version(self) -> Optional[str]:
100
- ...
103
+ ... # pragma: no cover
101
104
 
102
105
  def cuda_version(self) -> Optional[str]:
103
- ...
106
+ ... # pragma: no cover
104
107
 
105
108
  def program(self) -> Optional[str]:
106
- ...
109
+ ... # pragma: no cover
107
110
 
108
111
  def host(self) -> Optional[str]:
109
- ...
112
+ ... # pragma: no cover
110
113
 
111
114
  def username(self) -> Optional[str]:
112
- ...
115
+ ... # pragma: no cover
113
116
 
114
117
  def executable(self) -> Optional[str]:
115
- ...
118
+ ... # pragma: no cover
116
119
 
117
120
  def gpus_used(self) -> Optional[str]:
118
- ...
121
+ ... # pragma: no cover
119
122
 
120
123
  def cpus_used(self) -> Optional[int]: # can we get the model?
121
- ...
124
+ ... # pragma: no cover
122
125
 
123
126
  def memory_used(self) -> Optional[int]:
124
- ...
127
+ ... # pragma: no cover
125
128
 
126
129
  def runtime(self) -> Optional[int]:
127
130
  end_time = (
@@ -135,16 +138,16 @@ class MlflowRun:
135
138
  return self.run.info.start_time // 1000
136
139
 
137
140
  def code_path(self) -> Optional[str]:
138
- ...
141
+ ... # pragma: no cover
139
142
 
140
143
  def cli_version(self) -> Optional[str]:
141
- ...
144
+ ... # pragma: no cover
142
145
 
143
146
  def files(self) -> Optional[Iterable[Tuple[str, str]]]:
144
- ...
147
+ ... # pragma: no cover
145
148
 
146
149
  def logs(self) -> Optional[Iterable[str]]:
147
- ...
150
+ ... # pragma: no cover
148
151
 
149
152
  @staticmethod
150
153
  def _handle_incompatible_strings(s: str) -> str:
@@ -155,76 +158,110 @@ class MlflowRun:
155
158
 
156
159
 
157
160
  class MlflowImporter:
158
- def __init__(self, mlflow_tracking_uri, mlflow_registry_uri=None) -> None:
159
- self.mlflow_tracking_uri = mlflow_tracking_uri
161
+ def __init__(
162
+ self,
163
+ dst_base_url: str,
164
+ dst_api_key: str,
165
+ mlflow_tracking_uri: str,
166
+ mlflow_registry_uri: Optional[str] = None,
167
+ *,
168
+ custom_api_kwargs: Optional[Dict[str, Any]] = None,
169
+ ) -> None:
170
+ self.dst_base_url = dst_base_url
171
+ self.dst_api_key = dst_api_key
172
+
173
+ if custom_api_kwargs is None:
174
+ custom_api_kwargs = {"timeout": 600}
160
175
 
176
+ self.dst_api = wandb.Api(
177
+ api_key=dst_api_key,
178
+ overrides={"base_url": dst_base_url},
179
+ **custom_api_kwargs,
180
+ )
181
+ self.mlflow_tracking_uri = mlflow_tracking_uri
161
182
  mlflow.set_tracking_uri(self.mlflow_tracking_uri)
183
+
162
184
  if mlflow_registry_uri:
163
185
  mlflow.set_registry_uri(mlflow_registry_uri)
186
+
164
187
  self.mlflow_client = mlflow.tracking.MlflowClient(mlflow_tracking_uri)
165
188
 
166
- def collect_runs(self, limit: Optional[int] = None) -> Iterable[MlflowRun]:
189
+ def __repr__(self):
190
+ return f"<MlflowImporter src={self.mlflow_tracking_uri}>"
191
+
192
+ def collect_runs(self, *, limit: Optional[int] = None) -> Iterable[MlflowRun]:
167
193
  if mlflow_version < Version("1.28.0"):
168
194
  experiments = self.mlflow_client.list_experiments()
169
195
  else:
170
196
  experiments = self.mlflow_client.search_experiments()
171
197
 
172
- runs = (
173
- run
174
- for exp in experiments
175
- for run in self.mlflow_client.search_runs(exp.experiment_id)
176
- )
177
- for i, run in enumerate(runs):
178
- if limit and i >= limit:
179
- break
180
- yield MlflowRun(run, self.mlflow_client)
198
+ def _runs():
199
+ for exp in experiments:
200
+ for run in self.mlflow_client.search_runs(exp.experiment_id):
201
+ yield MlflowRun(run, self.mlflow_client)
202
+
203
+ runs = itertools.islice(_runs(), limit)
204
+ yield from runs
181
205
 
182
- def import_run(
206
+ def _import_run(
183
207
  self,
184
- run: ImporterRun,
185
- overrides: Optional[Dict[str, Any]] = None,
208
+ run: MlflowRun,
209
+ *,
210
+ artifacts: bool = True,
211
+ namespace: Optional[Namespace] = None,
212
+ config: Optional[internal.SendManagerConfig] = None,
186
213
  ) -> None:
214
+ if namespace is None:
215
+ namespace = Namespace(run.entity(), run.project())
216
+
217
+ if config is None:
218
+ config = internal.SendManagerConfig(
219
+ metadata=True,
220
+ files=True,
221
+ media=True,
222
+ code=True,
223
+ history=True,
224
+ summary=True,
225
+ terminal_output=True,
226
+ )
227
+
228
+ settings_override = {
229
+ "api_key": self.dst_api_key,
230
+ "base_url": self.dst_base_url,
231
+ "resume": "true",
232
+ "resumed": True,
233
+ }
234
+
187
235
  mlflow.set_tracking_uri(self.mlflow_tracking_uri)
188
- send_run_with_send_manager(run, overrides)
236
+ internal.send_run(
237
+ run,
238
+ overrides=namespace.send_manager_overrides,
239
+ settings_override=settings_override,
240
+ config=config,
241
+ )
242
+
243
+ # in mlflow, the artifacts come with the runs, so import them together
244
+ if artifacts:
245
+ arts = list(run.artifacts())
246
+ logger.debug(f"Importing history artifacts, {run=}")
247
+ internal.send_run(
248
+ run,
249
+ extra_arts=arts,
250
+ overrides=namespace.send_manager_overrides,
251
+ settings_override=settings_override,
252
+ config=internal.SendManagerConfig(log_artifacts=True),
253
+ )
189
254
 
190
255
  def import_runs(
191
256
  self,
192
- runs: Iterable[ImporterRun],
193
- overrides: Optional[Dict[str, Any]] = None,
194
- pool_kwargs: Optional[Dict[str, Any]] = None,
195
- ) -> None:
196
- _overrides = coalesce(overrides, {})
197
- _pool_kwargs = coalesce(pool_kwargs, {})
198
- runs = list(self.collect_runs())
199
-
200
- with ThreadPoolExecutor(**_pool_kwargs) as exc:
201
- futures = {
202
- exc.submit(self.import_run, run, overrides=_overrides): run
203
- for run in runs
204
- }
205
- with tqdm(desc="Importing runs", total=len(futures), unit="run") as pbar:
206
- for future in as_completed(futures):
207
- run = futures[future]
208
- try:
209
- future.result()
210
- except Exception as e:
211
- wandb.termerror(f"Failed to import {run.display_name()}: {e}")
212
- raise e
213
- else:
214
- pbar.set_description(
215
- f"Imported Run: {run.run_group()} {run.display_name()}"
216
- )
217
- finally:
218
- pbar.update(1)
219
-
220
- def import_all_runs(
221
- self,
222
- limit: Optional[int] = None,
223
- overrides: Optional[Dict[str, Any]] = None,
224
- pool_kwargs: Optional[Dict[str, Any]] = None,
257
+ runs: Iterable[MlflowRun],
258
+ *,
259
+ artifacts: bool = True,
260
+ namespace: Optional[Namespace] = None,
261
+ parallel: bool = True,
262
+ max_workers: Optional[int] = None,
225
263
  ) -> None:
226
- runs = self.collect_runs(limit)
227
- self.import_runs(runs, overrides, pool_kwargs)
264
+ def _import_run_wrapped(run):
265
+ self._import_run(run, namespace=namespace, artifacts=artifacts)
228
266
 
229
- def import_report(self, report: Report):
230
- raise NotImplementedError("MLFlow does not have a reports concept")
267
+ for_each(_import_run_wrapped, runs, parallel=parallel, max_workers=max_workers)
@@ -0,0 +1,108 @@
1
+ import filecmp
2
+ import logging
3
+ import os
4
+
5
+ import requests
6
+
7
+ import wandb
8
+
9
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Set this logger's level to INFO as soon as the module is imported.
logger.setLevel(logging.INFO)
11
+
12
+
13
def _compare_artifact_manifests(
    src_art: wandb.Artifact, dst_art: wandb.Artifact
) -> list:
    """Describe how the destination artifact's manifest differs from the source.

    Args:
        src_art: The artifact treated as the reference.
        dst_art: The artifact to check against ``src_art``. If this is a
            ``wandb.CommError`` instance, no comparison is attempted.
            NOTE(review): presumably the caller passes the exception object
            through when the destination lookup failed -- confirm.

    Returns:
        A list of problem strings; empty when nothing differs. The single
        item ``"commError"`` is returned when ``dst_art`` is a
        ``wandb.CommError``.

    Only entries present in the source manifest are inspected; entries that
    exist solely in the destination manifest are not reported individually.
    """
    if isinstance(dst_art, wandb.CommError):
        return ["commError"]

    issues = []

    if src_art.digest != dst_art.digest:
        issues.append(f"digest mismatch {src_art.digest=}, {dst_art.digest=}")

    # NOTE: the variable names below appear verbatim in the rendered messages
    # (via the f-string `=` specifier), so they must not be renamed.
    for name, src_entry in src_art.manifest.entries.items():
        dst_entry = dst_art.manifest.entries.get(name)
        if dst_entry is not None:
            for attr in ("path", "digest", "size"):
                if getattr(src_entry, attr) == getattr(dst_entry, attr):
                    continue
                issues.append(
                    f"manifest entry mismatch {attr=}, {getattr(src_entry, attr)=}, {getattr(dst_entry, attr)=}"
                )
        else:
            issues.append(f"missing manifest entry {name=}, {src_entry=}")

    return issues
36
+
37
+
38
def _compare_artifact_dirs(src_dir, dst_dir) -> "dict | None":
    """Recursively compare two directory trees using ``filecmp.dircmp``.

    Args:
        src_dir: Path of the reference directory (the "left" side).
        dst_dir: Path of the directory being checked (the "right" side).

    Returns:
        ``None`` when no differences were found. Otherwise a dict with keys:

        * ``"left_only"``: names found only in ``src_dir``.
        * ``"right_only"``: names found only in ``dst_dir``.
        * ``"diff_files"``: names of files present in both whose contents
          differ according to ``filecmp.dircmp``.
        * ``"subdir_differences"``: mapping of common subdirectory name to a
          nested dict of this same shape (only subdirectories that actually
          differ are included).
    """

    def compare(src_dir, dst_dir):
        comparison = filecmp.dircmp(src_dir, dst_dir)
        differences = {
            "left_only": comparison.left_only,
            "right_only": comparison.right_only,
            "diff_files": comparison.diff_files,
            "subdir_differences": {},
        }

        # `subdirs` is keyed by the names of directories common to both sides.
        for subdir in comparison.subdirs:
            nested = compare(
                os.path.join(src_dir, subdir),
                os.path.join(dst_dir, subdir),
            )
            # `compare` returns None for identical trees and otherwise a dict
            # with at least one non-empty value, so truthiness is sufficient.
            if nested:
                differences["subdir_differences"][subdir] = nested

        return differences if any(differences.values()) else None

    return compare(src_dir, dst_dir)
63
+
64
+
65
def _check_entries_are_downloadable(art):
    """Return True only if every entry collected from ``art`` passes the download check.

    All entries are gathered first via ``_collect_entries``; they are then
    checked one at a time and checking stops at the first entry that fails.
    """
    return all(_check_entry_is_downloable(entry) for entry in _collect_entries(art))
71
+
72
+
73
def _collect_entries(art):
    """Gather every entry of ``art``, page by page, tagging each with its direct URL.

    Pages are requested through ``art._fetch_file_urls(cursor)``, starting with
    a ``None`` cursor and continuing with each page's ``endCursor`` for as long
    as the page reports ``hasNextPage``. For every edge on a page the entry is
    looked up with ``art.get_entry(name)`` and its ``_download_url`` attribute
    is overwritten with the node's ``directUrl``.

    Returns:
        A list of the entries in the order the pages returned them.
    """
    collected = []
    cursor = None

    while True:
        page = art._fetch_file_urls(cursor)
        page_info = page["pageInfo"]
        more_pages, cursor = page_info["hasNextPage"], page_info["endCursor"]

        for edge in page["edges"]:
            node = edge["node"]
            entry = art.get_entry(node["name"])
            entry._download_url = node["directUrl"]
            collected.append(entry)

        if not more_pages:
            return collected
87
+
88
+
89
def _check_entry_is_downloable(entry, timeout: float = 60):
    """Check that an entry's download URL responds and reports the expected size.

    A HEAD request (following redirects) is sent to ``entry._download_url``.

    Args:
        entry: Object exposing ``_download_url`` and ``size`` attributes.
        timeout: Seconds passed to ``requests.head`` as its ``timeout``.
            Without one, a stalled server would block the check indefinitely.

    Returns:
        True only when the request succeeds with status 200 and the
        ``content-length`` header equals ``entry.size``. Any request failure
        (including a timeout), a non-200 status, or a missing or malformed
        ``content-length`` header yields False.
    """
    url = entry._download_url
    expected_size = entry.size

    try:
        resp = requests.head(url, allow_redirects=True, timeout=timeout)
    except Exception as e:
        # Deliberately broad: this is a best-effort validation step, so any
        # failure to reach the URL is logged and reported as "not downloadable".
        logger.error(f"Problem validating {entry=}, {e=}")
        return False

    if resp.status_code != 200:
        return False

    try:
        # A missing header falls back to -1, which never matches a real size.
        actual_size = int(resp.headers.get("content-length", -1))
    except ValueError:
        # Malformed header value: treat as a failed check instead of crashing.
        return False

    return expected_size == actual_size