wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -2
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/__init__.py +1 -4
- wandb/apis/importers/internals/internal.py +386 -0
- wandb/apis/importers/internals/protocols.py +125 -0
- wandb/apis/importers/internals/util.py +78 -0
- wandb/apis/importers/mlflow.py +125 -88
- wandb/apis/importers/validation.py +108 -0
- wandb/apis/importers/wandb.py +1604 -0
- wandb/apis/public/api.py +7 -10
- wandb/apis/public/artifacts.py +38 -0
- wandb/apis/public/files.py +11 -2
- wandb/apis/reports/v2/__init__.py +0 -19
- wandb/apis/reports/v2/expr_parsing.py +0 -1
- wandb/apis/reports/v2/interface.py +15 -18
- wandb/apis/reports/v2/internal.py +12 -45
- wandb/cli/cli.py +52 -55
- wandb/integration/gym/__init__.py +2 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +6 -4
- wandb/integration/kfp/kfp_patch.py +2 -2
- wandb/integration/openai/fine_tuning.py +1 -2
- wandb/integration/ultralytics/callback.py +0 -1
- wandb/proto/v3/wandb_internal_pb2.py +332 -312
- wandb/proto/v3/wandb_settings_pb2.py +13 -3
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +316 -312
- wandb/proto/v4/wandb_settings_pb2.py +5 -3
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/artifact.py +75 -31
- wandb/sdk/artifacts/artifact_manifest.py +5 -2
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
- wandb/sdk/artifacts/artifact_saver.py +19 -47
- wandb/sdk/artifacts/storage_handler.py +2 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
- wandb/sdk/artifacts/storage_policy.py +4 -1
- wandb/sdk/data_types/base_types/wb_value.py +1 -1
- wandb/sdk/data_types/image.py +2 -2
- wandb/sdk/interface/interface.py +49 -13
- wandb/sdk/interface/interface_shared.py +17 -11
- wandb/sdk/internal/file_stream.py +20 -1
- wandb/sdk/internal/handler.py +1 -4
- wandb/sdk/internal/internal_api.py +3 -1
- wandb/sdk/internal/job_builder.py +49 -19
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/sender.py +96 -124
- wandb/sdk/internal/sender_config.py +197 -0
- wandb/sdk/internal/settings_static.py +9 -0
- wandb/sdk/internal/system/system_info.py +5 -3
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/_launch.py +3 -3
- wandb/sdk/launch/_launch_add.py +28 -29
- wandb/sdk/launch/_project_spec.py +148 -136
- wandb/sdk/launch/agent/agent.py +3 -7
- wandb/sdk/launch/agent/config.py +0 -27
- wandb/sdk/launch/builder/build.py +54 -28
- wandb/sdk/launch/builder/docker_builder.py +4 -15
- wandb/sdk/launch/builder/kaniko_builder.py +72 -45
- wandb/sdk/launch/create_job.py +6 -40
- wandb/sdk/launch/loader.py +10 -0
- wandb/sdk/launch/registry/anon.py +29 -0
- wandb/sdk/launch/registry/local_registry.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
- wandb/sdk/launch/runner/local_container.py +15 -10
- wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
- wandb/sdk/launch/sweeps/scheduler.py +11 -3
- wandb/sdk/launch/utils.py +14 -0
- wandb/sdk/lib/__init__.py +2 -5
- wandb/sdk/lib/_settings_toposort_generated.py +4 -1
- wandb/sdk/lib/apikey.py +0 -5
- wandb/sdk/lib/config_util.py +0 -31
- wandb/sdk/lib/filesystem.py +11 -1
- wandb/sdk/lib/run_moment.py +72 -0
- wandb/sdk/service/service.py +7 -2
- wandb/sdk/service/streams.py +1 -6
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +12 -1
- wandb/sdk/wandb_login.py +43 -26
- wandb/sdk/wandb_run.py +164 -110
- wandb/sdk/wandb_settings.py +58 -16
- wandb/testing/relay.py +5 -6
- wandb/util.py +50 -7
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
- wandb/apis/importers/base.py +0 -400
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
wandb/apis/importers/mlflow.py
CHANGED
@@ -1,33 +1,27 @@
|
|
1
|
+
import itertools
|
2
|
+
import logging
|
1
3
|
import re
|
2
4
|
from collections import defaultdict
|
3
|
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
4
5
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
5
|
-
from unittest.mock import patch
|
6
6
|
|
7
|
+
import mlflow
|
7
8
|
from packaging.version import Version # type: ignore
|
8
|
-
from tqdm.auto import tqdm
|
9
9
|
|
10
10
|
import wandb
|
11
11
|
from wandb import Artifact
|
12
|
-
from wandb.util import coalesce, get_module
|
13
12
|
|
14
|
-
from .
|
15
|
-
|
16
|
-
with patch("click.echo"):
|
17
|
-
from wandb.apis.reports import Report
|
18
|
-
|
19
|
-
mlflow = get_module(
|
20
|
-
"mlflow",
|
21
|
-
required="To use the MlflowImporter, please install mlflow: `pip install mlflow`",
|
22
|
-
)
|
13
|
+
from .internals import internal
|
14
|
+
from .internals.util import Namespace, for_each
|
23
15
|
|
24
16
|
mlflow_version = Version(mlflow.__version__)
|
25
17
|
|
18
|
+
logger = logging.getLogger("import_logger")
|
19
|
+
|
26
20
|
|
27
21
|
class MlflowRun:
|
28
22
|
def __init__(self, run, mlflow_client):
|
29
23
|
self.run = run
|
30
|
-
self.mlflow_client = mlflow_client
|
24
|
+
self.mlflow_client: mlflow.MlflowClient = mlflow_client
|
31
25
|
|
32
26
|
def run_id(self) -> str:
|
33
27
|
return self.run.info.run_id
|
@@ -39,7 +33,13 @@ class MlflowRun:
|
|
39
33
|
return "imported-from-mlflow"
|
40
34
|
|
41
35
|
def config(self) -> Dict[str, Any]:
|
42
|
-
|
36
|
+
conf = self.run.data.params
|
37
|
+
|
38
|
+
# Add tags here since mlflow supports very long tag names but we only support up to 64 chars
|
39
|
+
tags = {
|
40
|
+
k: v for k, v in self.run.data.tags.items() if not k.startswith("mlflow.")
|
41
|
+
}
|
42
|
+
return {**conf, "imported_mlflow_tags": tags}
|
43
43
|
|
44
44
|
def summary(self) -> Dict[str, float]:
|
45
45
|
return self.run.data.metrics
|
@@ -71,19 +71,22 @@ class MlflowRun:
|
|
71
71
|
return self.run.data.tags.get("mlflow.note.content")
|
72
72
|
|
73
73
|
def tags(self) -> Optional[List[str]]:
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
74
|
+
...
|
75
|
+
|
76
|
+
# W&B tags are different than mlflow tags.
|
77
|
+
# The full mlflow tags are added to config under key `imported_mlflow_tags` instead
|
78
78
|
|
79
79
|
def artifacts(self) -> Optional[Iterable[Artifact]]: # type: ignore
|
80
80
|
if mlflow_version < Version("2.0.0"):
|
81
81
|
dir_path = self.mlflow_client.download_artifacts(
|
82
|
-
run_id=self.run.info.run_id,
|
82
|
+
run_id=self.run.info.run_id,
|
83
|
+
path="",
|
83
84
|
)
|
84
85
|
else:
|
85
86
|
dir_path = mlflow.artifacts.download_artifacts(run_id=self.run.info.run_id)
|
86
87
|
|
88
|
+
# Since mlflow doesn't have extra metadata about the artifacts,
|
89
|
+
# we just lump them all together into a single wandb.Artifact
|
87
90
|
artifact_name = self._handle_incompatible_strings(self.display_name())
|
88
91
|
art = wandb.Artifact(artifact_name, "imported-artifacts")
|
89
92
|
art.add_dir(dir_path)
|
@@ -91,37 +94,37 @@ class MlflowRun:
|
|
91
94
|
return [art]
|
92
95
|
|
93
96
|
def used_artifacts(self) -> Optional[Iterable[Artifact]]: # type: ignore
|
94
|
-
...
|
97
|
+
... # pragma: no cover
|
95
98
|
|
96
99
|
def os_version(self) -> Optional[str]:
|
97
|
-
...
|
100
|
+
... # pragma: no cover
|
98
101
|
|
99
102
|
def python_version(self) -> Optional[str]:
|
100
|
-
...
|
103
|
+
... # pragma: no cover
|
101
104
|
|
102
105
|
def cuda_version(self) -> Optional[str]:
|
103
|
-
...
|
106
|
+
... # pragma: no cover
|
104
107
|
|
105
108
|
def program(self) -> Optional[str]:
|
106
|
-
...
|
109
|
+
... # pragma: no cover
|
107
110
|
|
108
111
|
def host(self) -> Optional[str]:
|
109
|
-
...
|
112
|
+
... # pragma: no cover
|
110
113
|
|
111
114
|
def username(self) -> Optional[str]:
|
112
|
-
...
|
115
|
+
... # pragma: no cover
|
113
116
|
|
114
117
|
def executable(self) -> Optional[str]:
|
115
|
-
...
|
118
|
+
... # pragma: no cover
|
116
119
|
|
117
120
|
def gpus_used(self) -> Optional[str]:
|
118
|
-
...
|
121
|
+
... # pragma: no cover
|
119
122
|
|
120
123
|
def cpus_used(self) -> Optional[int]: # can we get the model?
|
121
|
-
...
|
124
|
+
... # pragma: no cover
|
122
125
|
|
123
126
|
def memory_used(self) -> Optional[int]:
|
124
|
-
...
|
127
|
+
... # pragma: no cover
|
125
128
|
|
126
129
|
def runtime(self) -> Optional[int]:
|
127
130
|
end_time = (
|
@@ -135,16 +138,16 @@ class MlflowRun:
|
|
135
138
|
return self.run.info.start_time // 1000
|
136
139
|
|
137
140
|
def code_path(self) -> Optional[str]:
|
138
|
-
...
|
141
|
+
... # pragma: no cover
|
139
142
|
|
140
143
|
def cli_version(self) -> Optional[str]:
|
141
|
-
...
|
144
|
+
... # pragma: no cover
|
142
145
|
|
143
146
|
def files(self) -> Optional[Iterable[Tuple[str, str]]]:
|
144
|
-
...
|
147
|
+
... # pragma: no cover
|
145
148
|
|
146
149
|
def logs(self) -> Optional[Iterable[str]]:
|
147
|
-
...
|
150
|
+
... # pragma: no cover
|
148
151
|
|
149
152
|
@staticmethod
|
150
153
|
def _handle_incompatible_strings(s: str) -> str:
|
@@ -155,76 +158,110 @@ class MlflowRun:
|
|
155
158
|
|
156
159
|
|
157
160
|
class MlflowImporter:
|
158
|
-
def __init__(
|
159
|
-
self
|
161
|
+
def __init__(
|
162
|
+
self,
|
163
|
+
dst_base_url: str,
|
164
|
+
dst_api_key: str,
|
165
|
+
mlflow_tracking_uri: str,
|
166
|
+
mlflow_registry_uri: Optional[str] = None,
|
167
|
+
*,
|
168
|
+
custom_api_kwargs: Optional[Dict[str, Any]] = None,
|
169
|
+
) -> None:
|
170
|
+
self.dst_base_url = dst_base_url
|
171
|
+
self.dst_api_key = dst_api_key
|
172
|
+
|
173
|
+
if custom_api_kwargs is None:
|
174
|
+
custom_api_kwargs = {"timeout": 600}
|
160
175
|
|
176
|
+
self.dst_api = wandb.Api(
|
177
|
+
api_key=dst_api_key,
|
178
|
+
overrides={"base_url": dst_base_url},
|
179
|
+
**custom_api_kwargs,
|
180
|
+
)
|
181
|
+
self.mlflow_tracking_uri = mlflow_tracking_uri
|
161
182
|
mlflow.set_tracking_uri(self.mlflow_tracking_uri)
|
183
|
+
|
162
184
|
if mlflow_registry_uri:
|
163
185
|
mlflow.set_registry_uri(mlflow_registry_uri)
|
186
|
+
|
164
187
|
self.mlflow_client = mlflow.tracking.MlflowClient(mlflow_tracking_uri)
|
165
188
|
|
166
|
-
def
|
189
|
+
def __repr__(self):
|
190
|
+
return f"<MlflowImporter src={self.mlflow_tracking_uri}>"
|
191
|
+
|
192
|
+
def collect_runs(self, *, limit: Optional[int] = None) -> Iterable[MlflowRun]:
|
167
193
|
if mlflow_version < Version("1.28.0"):
|
168
194
|
experiments = self.mlflow_client.list_experiments()
|
169
195
|
else:
|
170
196
|
experiments = self.mlflow_client.search_experiments()
|
171
197
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
break
|
180
|
-
yield MlflowRun(run, self.mlflow_client)
|
198
|
+
def _runs():
|
199
|
+
for exp in experiments:
|
200
|
+
for run in self.mlflow_client.search_runs(exp.experiment_id):
|
201
|
+
yield MlflowRun(run, self.mlflow_client)
|
202
|
+
|
203
|
+
runs = itertools.islice(_runs(), limit)
|
204
|
+
yield from runs
|
181
205
|
|
182
|
-
def
|
206
|
+
def _import_run(
|
183
207
|
self,
|
184
|
-
run:
|
185
|
-
|
208
|
+
run: MlflowRun,
|
209
|
+
*,
|
210
|
+
artifacts: bool = True,
|
211
|
+
namespace: Optional[Namespace] = None,
|
212
|
+
config: Optional[internal.SendManagerConfig] = None,
|
186
213
|
) -> None:
|
214
|
+
if namespace is None:
|
215
|
+
namespace = Namespace(run.entity(), run.project())
|
216
|
+
|
217
|
+
if config is None:
|
218
|
+
config = internal.SendManagerConfig(
|
219
|
+
metadata=True,
|
220
|
+
files=True,
|
221
|
+
media=True,
|
222
|
+
code=True,
|
223
|
+
history=True,
|
224
|
+
summary=True,
|
225
|
+
terminal_output=True,
|
226
|
+
)
|
227
|
+
|
228
|
+
settings_override = {
|
229
|
+
"api_key": self.dst_api_key,
|
230
|
+
"base_url": self.dst_base_url,
|
231
|
+
"resume": "true",
|
232
|
+
"resumed": True,
|
233
|
+
}
|
234
|
+
|
187
235
|
mlflow.set_tracking_uri(self.mlflow_tracking_uri)
|
188
|
-
|
236
|
+
internal.send_run(
|
237
|
+
run,
|
238
|
+
overrides=namespace.send_manager_overrides,
|
239
|
+
settings_override=settings_override,
|
240
|
+
config=config,
|
241
|
+
)
|
242
|
+
|
243
|
+
# in mlflow, the artifacts come with the runs, so import them together
|
244
|
+
if artifacts:
|
245
|
+
arts = list(run.artifacts())
|
246
|
+
logger.debug(f"Importing history artifacts, {run=}")
|
247
|
+
internal.send_run(
|
248
|
+
run,
|
249
|
+
extra_arts=arts,
|
250
|
+
overrides=namespace.send_manager_overrides,
|
251
|
+
settings_override=settings_override,
|
252
|
+
config=internal.SendManagerConfig(log_artifacts=True),
|
253
|
+
)
|
189
254
|
|
190
255
|
def import_runs(
|
191
256
|
self,
|
192
|
-
runs: Iterable[
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
runs = list(self.collect_runs())
|
199
|
-
|
200
|
-
with ThreadPoolExecutor(**_pool_kwargs) as exc:
|
201
|
-
futures = {
|
202
|
-
exc.submit(self.import_run, run, overrides=_overrides): run
|
203
|
-
for run in runs
|
204
|
-
}
|
205
|
-
with tqdm(desc="Importing runs", total=len(futures), unit="run") as pbar:
|
206
|
-
for future in as_completed(futures):
|
207
|
-
run = futures[future]
|
208
|
-
try:
|
209
|
-
future.result()
|
210
|
-
except Exception as e:
|
211
|
-
wandb.termerror(f"Failed to import {run.display_name()}: {e}")
|
212
|
-
raise e
|
213
|
-
else:
|
214
|
-
pbar.set_description(
|
215
|
-
f"Imported Run: {run.run_group()} {run.display_name()}"
|
216
|
-
)
|
217
|
-
finally:
|
218
|
-
pbar.update(1)
|
219
|
-
|
220
|
-
def import_all_runs(
|
221
|
-
self,
|
222
|
-
limit: Optional[int] = None,
|
223
|
-
overrides: Optional[Dict[str, Any]] = None,
|
224
|
-
pool_kwargs: Optional[Dict[str, Any]] = None,
|
257
|
+
runs: Iterable[MlflowRun],
|
258
|
+
*,
|
259
|
+
artifacts: bool = True,
|
260
|
+
namespace: Optional[Namespace] = None,
|
261
|
+
parallel: bool = True,
|
262
|
+
max_workers: Optional[int] = None,
|
225
263
|
) -> None:
|
226
|
-
|
227
|
-
|
264
|
+
def _import_run_wrapped(run):
|
265
|
+
self._import_run(run, namespace=namespace, artifacts=artifacts)
|
228
266
|
|
229
|
-
|
230
|
-
raise NotImplementedError("MLFlow does not have a reports concept")
|
267
|
+
for_each(_import_run_wrapped, runs, parallel=parallel, max_workers=max_workers)
|
@@ -0,0 +1,108 @@
|
|
1
|
+
import filecmp
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
|
5
|
+
import requests
|
6
|
+
|
7
|
+
import wandb
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
logger.setLevel(logging.INFO)
|
11
|
+
|
12
|
+
|
13
|
+
def _compare_artifact_manifests(
|
14
|
+
src_art: wandb.Artifact, dst_art: wandb.Artifact
|
15
|
+
) -> list:
|
16
|
+
problems = []
|
17
|
+
if isinstance(dst_art, wandb.CommError):
|
18
|
+
return ["commError"]
|
19
|
+
|
20
|
+
if src_art.digest != dst_art.digest:
|
21
|
+
problems.append(f"digest mismatch {src_art.digest=}, {dst_art.digest=}")
|
22
|
+
|
23
|
+
for name, src_entry in src_art.manifest.entries.items():
|
24
|
+
dst_entry = dst_art.manifest.entries.get(name)
|
25
|
+
if dst_entry is None:
|
26
|
+
problems.append(f"missing manifest entry {name=}, {src_entry=}")
|
27
|
+
continue
|
28
|
+
|
29
|
+
for attr in ["path", "digest", "size"]:
|
30
|
+
if getattr(src_entry, attr) != getattr(dst_entry, attr):
|
31
|
+
problems.append(
|
32
|
+
f"manifest entry mismatch {attr=}, {getattr(src_entry, attr)=}, {getattr(dst_entry, attr)=}"
|
33
|
+
)
|
34
|
+
|
35
|
+
return problems
|
36
|
+
|
37
|
+
|
38
|
+
def _compare_artifact_dirs(src_dir, dst_dir) -> list:
|
39
|
+
def compare(src_dir, dst_dir):
|
40
|
+
comparison = filecmp.dircmp(src_dir, dst_dir)
|
41
|
+
differences = {
|
42
|
+
"left_only": comparison.left_only,
|
43
|
+
"right_only": comparison.right_only,
|
44
|
+
"diff_files": comparison.diff_files,
|
45
|
+
"subdir_differences": {},
|
46
|
+
}
|
47
|
+
|
48
|
+
# Recursively find differences in subdirectories
|
49
|
+
for subdir in comparison.subdirs:
|
50
|
+
subdir_src = os.path.join(src_dir, subdir)
|
51
|
+
subdir_dst = os.path.join(dst_dir, subdir)
|
52
|
+
subdir_differences = compare(subdir_src, subdir_dst)
|
53
|
+
# If there are differences, add them to the result
|
54
|
+
if subdir_differences and any(subdir_differences.values()):
|
55
|
+
differences["subdir_differences"][subdir] = subdir_differences
|
56
|
+
|
57
|
+
if all(not diff for diff in differences.values()):
|
58
|
+
return None
|
59
|
+
|
60
|
+
return differences
|
61
|
+
|
62
|
+
return compare(src_dir, dst_dir)
|
63
|
+
|
64
|
+
|
65
|
+
def _check_entries_are_downloadable(art):
|
66
|
+
entries = _collect_entries(art)
|
67
|
+
for entry in entries:
|
68
|
+
if not _check_entry_is_downloable(entry):
|
69
|
+
return False
|
70
|
+
return True
|
71
|
+
|
72
|
+
|
73
|
+
def _collect_entries(art):
|
74
|
+
has_next_page = True
|
75
|
+
cursor = None
|
76
|
+
entries = []
|
77
|
+
while has_next_page:
|
78
|
+
attrs = art._fetch_file_urls(cursor)
|
79
|
+
has_next_page = attrs["pageInfo"]["hasNextPage"]
|
80
|
+
cursor = attrs["pageInfo"]["endCursor"]
|
81
|
+
for edge in attrs["edges"]:
|
82
|
+
name = edge["node"]["name"]
|
83
|
+
entry = art.get_entry(name)
|
84
|
+
entry._download_url = edge["node"]["directUrl"]
|
85
|
+
entries.append(entry)
|
86
|
+
return entries
|
87
|
+
|
88
|
+
|
89
|
+
def _check_entry_is_downloable(entry):
|
90
|
+
url = entry._download_url
|
91
|
+
expected_size = entry.size
|
92
|
+
|
93
|
+
try:
|
94
|
+
resp = requests.head(url, allow_redirects=True)
|
95
|
+
except Exception as e:
|
96
|
+
logger.error(f"Problem validating {entry=}, {e=}")
|
97
|
+
return False
|
98
|
+
|
99
|
+
if resp.status_code != 200:
|
100
|
+
return False
|
101
|
+
|
102
|
+
actual_size = resp.headers.get("content-length", -1)
|
103
|
+
actual_size = int(actual_size)
|
104
|
+
|
105
|
+
if expected_size == actual_size:
|
106
|
+
return True
|
107
|
+
|
108
|
+
return False
|