wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -2
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/__init__.py +1 -4
- wandb/apis/importers/internals/internal.py +386 -0
- wandb/apis/importers/internals/protocols.py +125 -0
- wandb/apis/importers/internals/util.py +78 -0
- wandb/apis/importers/mlflow.py +125 -88
- wandb/apis/importers/validation.py +108 -0
- wandb/apis/importers/wandb.py +1604 -0
- wandb/apis/public/api.py +7 -10
- wandb/apis/public/artifacts.py +38 -0
- wandb/apis/public/files.py +11 -2
- wandb/apis/reports/v2/__init__.py +0 -19
- wandb/apis/reports/v2/expr_parsing.py +0 -1
- wandb/apis/reports/v2/interface.py +15 -18
- wandb/apis/reports/v2/internal.py +12 -45
- wandb/cli/cli.py +52 -55
- wandb/integration/gym/__init__.py +2 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +6 -4
- wandb/integration/kfp/kfp_patch.py +2 -2
- wandb/integration/openai/fine_tuning.py +1 -2
- wandb/integration/ultralytics/callback.py +0 -1
- wandb/proto/v3/wandb_internal_pb2.py +332 -312
- wandb/proto/v3/wandb_settings_pb2.py +13 -3
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +316 -312
- wandb/proto/v4/wandb_settings_pb2.py +5 -3
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/artifact.py +75 -31
- wandb/sdk/artifacts/artifact_manifest.py +5 -2
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
- wandb/sdk/artifacts/artifact_saver.py +19 -47
- wandb/sdk/artifacts/storage_handler.py +2 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
- wandb/sdk/artifacts/storage_policy.py +4 -1
- wandb/sdk/data_types/base_types/wb_value.py +1 -1
- wandb/sdk/data_types/image.py +2 -2
- wandb/sdk/interface/interface.py +49 -13
- wandb/sdk/interface/interface_shared.py +17 -11
- wandb/sdk/internal/file_stream.py +20 -1
- wandb/sdk/internal/handler.py +1 -4
- wandb/sdk/internal/internal_api.py +3 -1
- wandb/sdk/internal/job_builder.py +49 -19
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/sender.py +96 -124
- wandb/sdk/internal/sender_config.py +197 -0
- wandb/sdk/internal/settings_static.py +9 -0
- wandb/sdk/internal/system/system_info.py +5 -3
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/_launch.py +3 -3
- wandb/sdk/launch/_launch_add.py +28 -29
- wandb/sdk/launch/_project_spec.py +148 -136
- wandb/sdk/launch/agent/agent.py +3 -7
- wandb/sdk/launch/agent/config.py +0 -27
- wandb/sdk/launch/builder/build.py +54 -28
- wandb/sdk/launch/builder/docker_builder.py +4 -15
- wandb/sdk/launch/builder/kaniko_builder.py +72 -45
- wandb/sdk/launch/create_job.py +6 -40
- wandb/sdk/launch/loader.py +10 -0
- wandb/sdk/launch/registry/anon.py +29 -0
- wandb/sdk/launch/registry/local_registry.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
- wandb/sdk/launch/runner/local_container.py +15 -10
- wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
- wandb/sdk/launch/sweeps/scheduler.py +11 -3
- wandb/sdk/launch/utils.py +14 -0
- wandb/sdk/lib/__init__.py +2 -5
- wandb/sdk/lib/_settings_toposort_generated.py +4 -1
- wandb/sdk/lib/apikey.py +0 -5
- wandb/sdk/lib/config_util.py +0 -31
- wandb/sdk/lib/filesystem.py +11 -1
- wandb/sdk/lib/run_moment.py +72 -0
- wandb/sdk/service/service.py +7 -2
- wandb/sdk/service/streams.py +1 -6
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +12 -1
- wandb/sdk/wandb_login.py +43 -26
- wandb/sdk/wandb_run.py +164 -110
- wandb/sdk/wandb_settings.py +58 -16
- wandb/testing/relay.py +5 -6
- wandb/util.py +50 -7
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
- wandb/apis/importers/base.py +0 -400
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
wandb/apis/importers/mlflow.py
CHANGED
@@ -1,33 +1,27 @@
|
|
1
|
+
import itertools
|
2
|
+
import logging
|
1
3
|
import re
|
2
4
|
from collections import defaultdict
|
3
|
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
4
5
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
5
|
-
from unittest.mock import patch
|
6
6
|
|
7
|
+
import mlflow
|
7
8
|
from packaging.version import Version # type: ignore
|
8
|
-
from tqdm.auto import tqdm
|
9
9
|
|
10
10
|
import wandb
|
11
11
|
from wandb import Artifact
|
12
|
-
from wandb.util import coalesce, get_module
|
13
12
|
|
14
|
-
from .
|
15
|
-
|
16
|
-
with patch("click.echo"):
|
17
|
-
from wandb.apis.reports import Report
|
18
|
-
|
19
|
-
mlflow = get_module(
|
20
|
-
"mlflow",
|
21
|
-
required="To use the MlflowImporter, please install mlflow: `pip install mlflow`",
|
22
|
-
)
|
13
|
+
from .internals import internal
|
14
|
+
from .internals.util import Namespace, for_each
|
23
15
|
|
24
16
|
mlflow_version = Version(mlflow.__version__)
|
25
17
|
|
18
|
+
logger = logging.getLogger("import_logger")
|
19
|
+
|
26
20
|
|
27
21
|
class MlflowRun:
|
28
22
|
def __init__(self, run, mlflow_client):
|
29
23
|
self.run = run
|
30
|
-
self.mlflow_client = mlflow_client
|
24
|
+
self.mlflow_client: mlflow.MlflowClient = mlflow_client
|
31
25
|
|
32
26
|
def run_id(self) -> str:
|
33
27
|
return self.run.info.run_id
|
@@ -39,7 +33,13 @@ class MlflowRun:
|
|
39
33
|
return "imported-from-mlflow"
|
40
34
|
|
41
35
|
def config(self) -> Dict[str, Any]:
|
42
|
-
|
36
|
+
conf = self.run.data.params
|
37
|
+
|
38
|
+
# Add tags here since mlflow supports very long tag names but we only support up to 64 chars
|
39
|
+
tags = {
|
40
|
+
k: v for k, v in self.run.data.tags.items() if not k.startswith("mlflow.")
|
41
|
+
}
|
42
|
+
return {**conf, "imported_mlflow_tags": tags}
|
43
43
|
|
44
44
|
def summary(self) -> Dict[str, float]:
|
45
45
|
return self.run.data.metrics
|
@@ -71,19 +71,22 @@ class MlflowRun:
|
|
71
71
|
return self.run.data.tags.get("mlflow.note.content")
|
72
72
|
|
73
73
|
def tags(self) -> Optional[List[str]]:
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
74
|
+
...
|
75
|
+
|
76
|
+
# W&B tags are different than mlflow tags.
|
77
|
+
# The full mlflow tags are added to config under key `imported_mlflow_tags` instead
|
78
78
|
|
79
79
|
def artifacts(self) -> Optional[Iterable[Artifact]]: # type: ignore
|
80
80
|
if mlflow_version < Version("2.0.0"):
|
81
81
|
dir_path = self.mlflow_client.download_artifacts(
|
82
|
-
run_id=self.run.info.run_id,
|
82
|
+
run_id=self.run.info.run_id,
|
83
|
+
path="",
|
83
84
|
)
|
84
85
|
else:
|
85
86
|
dir_path = mlflow.artifacts.download_artifacts(run_id=self.run.info.run_id)
|
86
87
|
|
88
|
+
# Since mlflow doesn't have extra metadata about the artifacts,
|
89
|
+
# we just lump them all together into a single wandb.Artifact
|
87
90
|
artifact_name = self._handle_incompatible_strings(self.display_name())
|
88
91
|
art = wandb.Artifact(artifact_name, "imported-artifacts")
|
89
92
|
art.add_dir(dir_path)
|
@@ -91,37 +94,37 @@ class MlflowRun:
|
|
91
94
|
return [art]
|
92
95
|
|
93
96
|
def used_artifacts(self) -> Optional[Iterable[Artifact]]: # type: ignore
|
94
|
-
...
|
97
|
+
... # pragma: no cover
|
95
98
|
|
96
99
|
def os_version(self) -> Optional[str]:
|
97
|
-
...
|
100
|
+
... # pragma: no cover
|
98
101
|
|
99
102
|
def python_version(self) -> Optional[str]:
|
100
|
-
...
|
103
|
+
... # pragma: no cover
|
101
104
|
|
102
105
|
def cuda_version(self) -> Optional[str]:
|
103
|
-
...
|
106
|
+
... # pragma: no cover
|
104
107
|
|
105
108
|
def program(self) -> Optional[str]:
|
106
|
-
...
|
109
|
+
... # pragma: no cover
|
107
110
|
|
108
111
|
def host(self) -> Optional[str]:
|
109
|
-
...
|
112
|
+
... # pragma: no cover
|
110
113
|
|
111
114
|
def username(self) -> Optional[str]:
|
112
|
-
...
|
115
|
+
... # pragma: no cover
|
113
116
|
|
114
117
|
def executable(self) -> Optional[str]:
|
115
|
-
...
|
118
|
+
... # pragma: no cover
|
116
119
|
|
117
120
|
def gpus_used(self) -> Optional[str]:
|
118
|
-
...
|
121
|
+
... # pragma: no cover
|
119
122
|
|
120
123
|
def cpus_used(self) -> Optional[int]: # can we get the model?
|
121
|
-
...
|
124
|
+
... # pragma: no cover
|
122
125
|
|
123
126
|
def memory_used(self) -> Optional[int]:
|
124
|
-
...
|
127
|
+
... # pragma: no cover
|
125
128
|
|
126
129
|
def runtime(self) -> Optional[int]:
|
127
130
|
end_time = (
|
@@ -135,16 +138,16 @@ class MlflowRun:
|
|
135
138
|
return self.run.info.start_time // 1000
|
136
139
|
|
137
140
|
def code_path(self) -> Optional[str]:
|
138
|
-
...
|
141
|
+
... # pragma: no cover
|
139
142
|
|
140
143
|
def cli_version(self) -> Optional[str]:
|
141
|
-
...
|
144
|
+
... # pragma: no cover
|
142
145
|
|
143
146
|
def files(self) -> Optional[Iterable[Tuple[str, str]]]:
|
144
|
-
...
|
147
|
+
... # pragma: no cover
|
145
148
|
|
146
149
|
def logs(self) -> Optional[Iterable[str]]:
|
147
|
-
...
|
150
|
+
... # pragma: no cover
|
148
151
|
|
149
152
|
@staticmethod
|
150
153
|
def _handle_incompatible_strings(s: str) -> str:
|
@@ -155,76 +158,110 @@ class MlflowRun:
|
|
155
158
|
|
156
159
|
|
157
160
|
class MlflowImporter:
|
158
|
-
def __init__(
|
159
|
-
self
|
161
|
+
def __init__(
|
162
|
+
self,
|
163
|
+
dst_base_url: str,
|
164
|
+
dst_api_key: str,
|
165
|
+
mlflow_tracking_uri: str,
|
166
|
+
mlflow_registry_uri: Optional[str] = None,
|
167
|
+
*,
|
168
|
+
custom_api_kwargs: Optional[Dict[str, Any]] = None,
|
169
|
+
) -> None:
|
170
|
+
self.dst_base_url = dst_base_url
|
171
|
+
self.dst_api_key = dst_api_key
|
172
|
+
|
173
|
+
if custom_api_kwargs is None:
|
174
|
+
custom_api_kwargs = {"timeout": 600}
|
160
175
|
|
176
|
+
self.dst_api = wandb.Api(
|
177
|
+
api_key=dst_api_key,
|
178
|
+
overrides={"base_url": dst_base_url},
|
179
|
+
**custom_api_kwargs,
|
180
|
+
)
|
181
|
+
self.mlflow_tracking_uri = mlflow_tracking_uri
|
161
182
|
mlflow.set_tracking_uri(self.mlflow_tracking_uri)
|
183
|
+
|
162
184
|
if mlflow_registry_uri:
|
163
185
|
mlflow.set_registry_uri(mlflow_registry_uri)
|
186
|
+
|
164
187
|
self.mlflow_client = mlflow.tracking.MlflowClient(mlflow_tracking_uri)
|
165
188
|
|
166
|
-
def
|
189
|
+
def __repr__(self):
|
190
|
+
return f"<MlflowImporter src={self.mlflow_tracking_uri}>"
|
191
|
+
|
192
|
+
def collect_runs(self, *, limit: Optional[int] = None) -> Iterable[MlflowRun]:
|
167
193
|
if mlflow_version < Version("1.28.0"):
|
168
194
|
experiments = self.mlflow_client.list_experiments()
|
169
195
|
else:
|
170
196
|
experiments = self.mlflow_client.search_experiments()
|
171
197
|
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
break
|
180
|
-
yield MlflowRun(run, self.mlflow_client)
|
198
|
+
def _runs():
|
199
|
+
for exp in experiments:
|
200
|
+
for run in self.mlflow_client.search_runs(exp.experiment_id):
|
201
|
+
yield MlflowRun(run, self.mlflow_client)
|
202
|
+
|
203
|
+
runs = itertools.islice(_runs(), limit)
|
204
|
+
yield from runs
|
181
205
|
|
182
|
-
def
|
206
|
+
def _import_run(
|
183
207
|
self,
|
184
|
-
run:
|
185
|
-
|
208
|
+
run: MlflowRun,
|
209
|
+
*,
|
210
|
+
artifacts: bool = True,
|
211
|
+
namespace: Optional[Namespace] = None,
|
212
|
+
config: Optional[internal.SendManagerConfig] = None,
|
186
213
|
) -> None:
|
214
|
+
if namespace is None:
|
215
|
+
namespace = Namespace(run.entity(), run.project())
|
216
|
+
|
217
|
+
if config is None:
|
218
|
+
config = internal.SendManagerConfig(
|
219
|
+
metadata=True,
|
220
|
+
files=True,
|
221
|
+
media=True,
|
222
|
+
code=True,
|
223
|
+
history=True,
|
224
|
+
summary=True,
|
225
|
+
terminal_output=True,
|
226
|
+
)
|
227
|
+
|
228
|
+
settings_override = {
|
229
|
+
"api_key": self.dst_api_key,
|
230
|
+
"base_url": self.dst_base_url,
|
231
|
+
"resume": "true",
|
232
|
+
"resumed": True,
|
233
|
+
}
|
234
|
+
|
187
235
|
mlflow.set_tracking_uri(self.mlflow_tracking_uri)
|
188
|
-
|
236
|
+
internal.send_run(
|
237
|
+
run,
|
238
|
+
overrides=namespace.send_manager_overrides,
|
239
|
+
settings_override=settings_override,
|
240
|
+
config=config,
|
241
|
+
)
|
242
|
+
|
243
|
+
# in mlflow, the artifacts come with the runs, so import them together
|
244
|
+
if artifacts:
|
245
|
+
arts = list(run.artifacts())
|
246
|
+
logger.debug(f"Importing history artifacts, {run=}")
|
247
|
+
internal.send_run(
|
248
|
+
run,
|
249
|
+
extra_arts=arts,
|
250
|
+
overrides=namespace.send_manager_overrides,
|
251
|
+
settings_override=settings_override,
|
252
|
+
config=internal.SendManagerConfig(log_artifacts=True),
|
253
|
+
)
|
189
254
|
|
190
255
|
def import_runs(
|
191
256
|
self,
|
192
|
-
runs: Iterable[
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
runs = list(self.collect_runs())
|
199
|
-
|
200
|
-
with ThreadPoolExecutor(**_pool_kwargs) as exc:
|
201
|
-
futures = {
|
202
|
-
exc.submit(self.import_run, run, overrides=_overrides): run
|
203
|
-
for run in runs
|
204
|
-
}
|
205
|
-
with tqdm(desc="Importing runs", total=len(futures), unit="run") as pbar:
|
206
|
-
for future in as_completed(futures):
|
207
|
-
run = futures[future]
|
208
|
-
try:
|
209
|
-
future.result()
|
210
|
-
except Exception as e:
|
211
|
-
wandb.termerror(f"Failed to import {run.display_name()}: {e}")
|
212
|
-
raise e
|
213
|
-
else:
|
214
|
-
pbar.set_description(
|
215
|
-
f"Imported Run: {run.run_group()} {run.display_name()}"
|
216
|
-
)
|
217
|
-
finally:
|
218
|
-
pbar.update(1)
|
219
|
-
|
220
|
-
def import_all_runs(
|
221
|
-
self,
|
222
|
-
limit: Optional[int] = None,
|
223
|
-
overrides: Optional[Dict[str, Any]] = None,
|
224
|
-
pool_kwargs: Optional[Dict[str, Any]] = None,
|
257
|
+
runs: Iterable[MlflowRun],
|
258
|
+
*,
|
259
|
+
artifacts: bool = True,
|
260
|
+
namespace: Optional[Namespace] = None,
|
261
|
+
parallel: bool = True,
|
262
|
+
max_workers: Optional[int] = None,
|
225
263
|
) -> None:
|
226
|
-
|
227
|
-
|
264
|
+
def _import_run_wrapped(run):
|
265
|
+
self._import_run(run, namespace=namespace, artifacts=artifacts)
|
228
266
|
|
229
|
-
|
230
|
-
raise NotImplementedError("MLFlow does not have a reports concept")
|
267
|
+
for_each(_import_run_wrapped, runs, parallel=parallel, max_workers=max_workers)
|
@@ -0,0 +1,108 @@
|
|
1
|
+
import filecmp
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
|
5
|
+
import requests
|
6
|
+
|
7
|
+
import wandb
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
logger.setLevel(logging.INFO)
|
11
|
+
|
12
|
+
|
13
|
+
def _compare_artifact_manifests(
    src_art: wandb.Artifact, dst_art: wandb.Artifact
) -> list:
    """Compare the manifests of a source artifact and its destination copy.

    Args:
        src_art: The artifact treated as the reference.
        dst_art: The artifact being checked. It may also be a
            ``wandb.CommError`` instance, in which case no comparison is done.

    Returns:
        A list of human-readable problem strings; empty when the artifact
        digests match and every source manifest entry has a destination entry
        with the same ``path``, ``digest`` and ``size``. Entries that exist
        only in the destination manifest are not reported here.
    """
    # An error object in place of the destination artifact short-circuits
    # the whole comparison with a single marker string.
    if isinstance(dst_art, wandb.CommError):
        return ["commError"]

    problems = []

    if src_art.digest != dst_art.digest:
        problems.append(f"digest mismatch {src_art.digest=}, {dst_art.digest=}")

    for name, src_entry in src_art.manifest.entries.items():
        dst_entry = dst_art.manifest.entries.get(name)

        if dst_entry is not None:
            # One problem string per attribute that differs between the two entries.
            problems.extend(
                f"manifest entry mismatch {attr=}, {getattr(src_entry, attr)=}, {getattr(dst_entry, attr)=}"
                for attr in ("path", "digest", "size")
                if getattr(src_entry, attr) != getattr(dst_entry, attr)
            )
        else:
            problems.append(f"missing manifest entry {name=}, {src_entry=}")

    return problems
|
36
|
+
|
37
|
+
|
38
|
+
def _compare_artifact_dirs(src_dir, dst_dir) -> "dict | None":
    """Recursively diff two directory trees.

    Args:
        src_dir: Path of the source ("left") directory.
        dst_dir: Path of the destination ("right") directory.

    Returns:
        ``None`` when no differences are found. Otherwise a dict with the keys

        * ``left_only``: names present only under ``src_dir``
        * ``right_only``: names present only under ``dst_dir``
        * ``diff_files``: names present in both whose contents differ
        * ``subdir_differences``: mapping of common subdirectory name to a
          nested dict of this same shape; subdirectories without differences
          are omitted.

        The name lists come straight from ``filecmp.dircmp``.
    """

    def summarize(comparison):
        differences = {
            "left_only": comparison.left_only,
            "right_only": comparison.right_only,
            "diff_files": comparison.diff_files,
            "subdir_differences": {},
        }

        # ``dircmp.subdirs`` already maps each common subdirectory to its own
        # dircmp object, so recurse on those instead of building new ones.
        for subdir, sub_comparison in comparison.subdirs.items():
            sub_differences = summarize(sub_comparison)
            # ``summarize`` returns None for identical trees, so anything else
            # is guaranteed to carry at least one difference.
            if sub_differences is not None:
                differences["subdir_differences"][subdir] = sub_differences

        if not any(differences.values()):
            return None

        return differences

    return summarize(filecmp.dircmp(src_dir, dst_dir))
|
63
|
+
|
64
|
+
|
65
|
+
def _check_entries_are_downloadable(art):
    """Return True only if every entry collected from ``art`` passes the download check.

    All entries are collected first; checking then stops at the first entry
    that fails.
    """
    return all(_check_entry_is_downloable(entry) for entry in _collect_entries(art))
|
71
|
+
|
72
|
+
|
73
|
+
def _collect_entries(art):
    """Gather every manifest entry of ``art``, tagging each with its direct URL.

    Pages through ``art._fetch_file_urls(cursor)`` starting with a ``None``
    cursor. For each returned edge, the matching entry is looked up with
    ``art.get_entry(name)`` and its ``_download_url`` attribute is set to the
    edge's ``directUrl``.

    Returns:
        The list of entries, in the order the pages returned them.
    """
    collected = []
    cursor = None

    while True:
        page = art._fetch_file_urls(cursor)

        # Read the paging info before touching the edges, then follow the
        # cursor on the next iteration.
        page_info = page["pageInfo"]
        more_pages = page_info["hasNextPage"]
        cursor = page_info["endCursor"]

        for edge in page["edges"]:
            node = edge["node"]
            entry = art.get_entry(node["name"])
            entry._download_url = node["directUrl"]
            collected.append(entry)

        if not more_pages:
            return collected
|
87
|
+
|
88
|
+
|
89
|
+
def _check_entry_is_downloable(entry, *, timeout=60):
    """Check that an entry's download URL responds and reports the expected size.

    Issues a HEAD request (following redirects) to ``entry._download_url`` and
    compares the ``content-length`` response header with ``entry.size``.

    Args:
        entry: Object exposing ``_download_url`` and ``size``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on an unresponsive
            host, which would stall the whole validation pass.

    Returns:
        True only if the request succeeds with status 200 and the reported
        content length equals ``entry.size``; False in every other case,
        including request errors and a missing or malformed
        ``content-length`` header.
    """
    url = entry._download_url
    expected_size = entry.size

    try:
        resp = requests.head(url, allow_redirects=True, timeout=timeout)
    except Exception as e:
        # Best-effort check: any failure to reach the URL counts as "not downloadable".
        logger.error(f"Problem validating {entry=}, {e=}")
        return False

    if resp.status_code != 200:
        return False

    try:
        # A missing header becomes -1, which cannot match a real size.
        actual_size = int(resp.headers.get("content-length", -1))
    except ValueError:
        # A non-numeric header means the size cannot be confirmed.
        return False

    return expected_size == actual_size
|