wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/importers/__init__.py +1 -4
  4. wandb/apis/importers/internals/internal.py +386 -0
  5. wandb/apis/importers/internals/protocols.py +125 -0
  6. wandb/apis/importers/internals/util.py +78 -0
  7. wandb/apis/importers/mlflow.py +125 -88
  8. wandb/apis/importers/validation.py +108 -0
  9. wandb/apis/importers/wandb.py +1604 -0
  10. wandb/apis/public/api.py +7 -10
  11. wandb/apis/public/artifacts.py +38 -0
  12. wandb/apis/public/files.py +11 -2
  13. wandb/apis/reports/v2/__init__.py +0 -19
  14. wandb/apis/reports/v2/expr_parsing.py +0 -1
  15. wandb/apis/reports/v2/interface.py +15 -18
  16. wandb/apis/reports/v2/internal.py +12 -45
  17. wandb/cli/cli.py +52 -55
  18. wandb/integration/gym/__init__.py +2 -1
  19. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  20. wandb/integration/keras/keras.py +6 -4
  21. wandb/integration/kfp/kfp_patch.py +2 -2
  22. wandb/integration/openai/fine_tuning.py +1 -2
  23. wandb/integration/ultralytics/callback.py +0 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  25. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  26. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  27. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  28. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  29. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  30. wandb/sdk/artifacts/artifact.py +75 -31
  31. wandb/sdk/artifacts/artifact_manifest.py +5 -2
  32. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  33. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
  34. wandb/sdk/artifacts/artifact_saver.py +19 -47
  35. wandb/sdk/artifacts/storage_handler.py +2 -1
  36. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
  37. wandb/sdk/artifacts/storage_policy.py +4 -1
  38. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  39. wandb/sdk/data_types/image.py +2 -2
  40. wandb/sdk/interface/interface.py +49 -13
  41. wandb/sdk/interface/interface_shared.py +17 -11
  42. wandb/sdk/internal/file_stream.py +20 -1
  43. wandb/sdk/internal/handler.py +1 -4
  44. wandb/sdk/internal/internal_api.py +3 -1
  45. wandb/sdk/internal/job_builder.py +49 -19
  46. wandb/sdk/internal/profiler.py +1 -1
  47. wandb/sdk/internal/sender.py +96 -124
  48. wandb/sdk/internal/sender_config.py +197 -0
  49. wandb/sdk/internal/settings_static.py +9 -0
  50. wandb/sdk/internal/system/system_info.py +5 -3
  51. wandb/sdk/internal/update.py +1 -1
  52. wandb/sdk/launch/_launch.py +3 -3
  53. wandb/sdk/launch/_launch_add.py +28 -29
  54. wandb/sdk/launch/_project_spec.py +148 -136
  55. wandb/sdk/launch/agent/agent.py +3 -7
  56. wandb/sdk/launch/agent/config.py +0 -27
  57. wandb/sdk/launch/builder/build.py +54 -28
  58. wandb/sdk/launch/builder/docker_builder.py +4 -15
  59. wandb/sdk/launch/builder/kaniko_builder.py +72 -45
  60. wandb/sdk/launch/create_job.py +6 -40
  61. wandb/sdk/launch/loader.py +10 -0
  62. wandb/sdk/launch/registry/anon.py +29 -0
  63. wandb/sdk/launch/registry/local_registry.py +4 -1
  64. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  65. wandb/sdk/launch/runner/local_container.py +15 -10
  66. wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
  67. wandb/sdk/launch/sweeps/scheduler.py +11 -3
  68. wandb/sdk/launch/utils.py +14 -0
  69. wandb/sdk/lib/__init__.py +2 -5
  70. wandb/sdk/lib/_settings_toposort_generated.py +4 -1
  71. wandb/sdk/lib/apikey.py +0 -5
  72. wandb/sdk/lib/config_util.py +0 -31
  73. wandb/sdk/lib/filesystem.py +11 -1
  74. wandb/sdk/lib/run_moment.py +72 -0
  75. wandb/sdk/service/service.py +7 -2
  76. wandb/sdk/service/streams.py +1 -6
  77. wandb/sdk/verify/verify.py +2 -1
  78. wandb/sdk/wandb_init.py +12 -1
  79. wandb/sdk/wandb_login.py +43 -26
  80. wandb/sdk/wandb_run.py +164 -110
  81. wandb/sdk/wandb_settings.py +58 -16
  82. wandb/testing/relay.py +5 -6
  83. wandb/util.py +50 -7
  84. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
  85. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
  86. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
  87. wandb/apis/importers/base.py +0 -400
  88. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
  89. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1604 @@
1
+ """Tooling for the W&B Importer."""
2
+ import itertools
3
+ import json
4
+ import logging
5
+ import numbers
6
+ import os
7
+ import re
8
+ import shutil
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime as dt
11
+ from pathlib import Path
12
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
13
+ from unittest.mock import patch
14
+
15
+ import filelock
16
+ import polars as pl
17
+ import requests
18
+ import urllib3
19
+ import yaml
20
+ from wandb_gql import gql
21
+
22
+ import wandb
23
+ import wandb.apis.reports as wr
24
+ from wandb.apis.public import ArtifactCollection, Run
25
+ from wandb.apis.public.files import File
26
+ from wandb.apis.reports import Report
27
+ from wandb.util import coalesce, remove_keys_with_none_values
28
+
29
+ from . import validation
30
+ from .internals import internal
31
+ from .internals.protocols import PathStr, Policy
32
+ from .internals.util import Namespace, for_each
33
+
34
# Short aliases for the public wandb classes referenced in this module.
Artifact = wandb.Artifact
Api = wandb.Api
Project = wandb.apis.public.Project

# File names of the importer's error/success records.
ARTIFACT_ERRORS_FNAME = "artifact_errors.jsonl"
ARTIFACT_SUCCESSES_FNAME = "artifact_successes.jsonl"
RUN_ERRORS_FNAME = "run_errors.jsonl"
RUN_SUCCESSES_FNAME = "run_successes.jsonl"

# Marker values used to recognise placeholder (dummy) artifacts and runs.
ART_SEQUENCE_DUMMY_PLACEHOLDER = "__ART_SEQUENCE_DUMMY_PLACEHOLDER__"
RUN_DUMMY_PLACEHOLDER = "__RUN_DUMMY_PLACEHOLDER__"
ART_DUMMY_PLACEHOLDER_PATH = "__importer_temp__"
ART_DUMMY_PLACEHOLDER_TYPE = "__temp__"

# Local directories for artifact downloads.
SRC_ART_PATH = "./artifacts/src"
DST_ART_PATH = "./artifacts/dst"


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

if not os.getenv("WANDB_IMPORTER_ENABLE_RICH_LOGGING"):
    # Default: a plain stream handler with a timestamped format.
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(logging.INFO)

    logger.addHandler(console_handler)
else:
    # Opt-in: `rich` is imported only when the environment variable is set.
    from rich.logging import RichHandler

    logger.addHandler(RichHandler(rich_tracebacks=True, tracebacks_show_locals=True))
67
+
68
+
69
@dataclass
class ArtifactSequence:
    """A group of artifact versions together with the collection they belong to.

    Iterating over the sequence iterates over `artifacts`.
    """

    artifacts: Iterable[wandb.Artifact]
    entity: str
    project: str
    type_: str
    name: str

    def __iter__(self) -> Iterator:
        return iter(self.artifacts)

    def __repr__(self) -> str:
        return "ArtifactSequence({})".format(self.identifier)

    @property
    def identifier(self) -> str:
        """Return `entity/project/type/name` for this sequence."""
        parts = (self.entity, self.project, self.type_, self.name)
        return "/".join(parts)

    @classmethod
    def from_collection(cls, collection: ArtifactCollection):
        """Build a sequence from a collection, ordered by numeric version."""

        def _version_number(artifact) -> int:
            # Versions look like "v12"; order by the integer part.
            return int(artifact.version.lstrip("v"))

        ordered = sorted(collection.artifacts(), key=_version_number)
        return ArtifactSequence(
            artifacts=ordered,
            entity=collection.entity,
            project=collection.project,
            type_=collection.type,
            name=collection.name,
        )
98
+
99
+
100
class WandbRun:
    """Wrap a source `Run` and expose its data through accessor methods.

    Instances are what the importer hands to `internal.send_run`. Downloaded
    files, cloned artifacts, and parquet history paths are cached on the
    instance after they are first computed.
    """

    def __init__(
        self,
        run: Run,
        *,
        src_base_url: str,
        src_api_key: str,
        dst_base_url: str,
        dst_api_key: str,
    ) -> None:
        self.run = run
        self.api = wandb.Api(
            api_key=src_api_key,
            overrides={"base_url": src_base_url},
        )
        self.dst_api = wandb.Api(
            api_key=dst_api_key,
            overrides={"base_url": dst_base_url},
        )

        # For caching
        self._files: Optional[Iterable[Tuple[str, str]]] = None
        self._artifacts: Optional[Iterable[Artifact]] = None
        self._used_artifacts: Optional[Iterable[Artifact]] = None
        self._parquet_history_paths: Optional[Iterable[str]] = None

    def __repr__(self) -> str:
        # entity/project/run_id is a W&B path, not a filesystem path, so it is
        # joined with "/" regardless of platform.
        s = "/".join([self.entity(), self.project(), self.run_id()])
        return f"WandbRun({s})"

    def run_id(self) -> str:
        return self.run.id

    def entity(self) -> str:
        return self.run.entity

    def project(self) -> str:
        return self.run.project

    def config(self) -> Dict[str, Any]:
        return self.run.config

    def summary(self) -> Dict[str, float]:
        return self.run.summary

    def metrics(self) -> Iterable[Dict[str, float]]:
        """Yield history rows with None-valued keys removed.

        Rows come from downloaded parquet history when any exists; otherwise
        they come from `run.scan_history()`.
        """
        if self._parquet_history_paths is None:
            self._parquet_history_paths = list(self._get_parquet_history_paths())

        if self._parquet_history_paths:
            rows = self._get_rows_from_parquet_history_paths()
        else:
            logger.warning(
                "No parquet files detected; using scan history (this may not be reliable)"
            )
            rows = self.run.scan_history()

        for row in rows:
            yield remove_keys_with_none_values(row)

    def run_group(self) -> Optional[str]:
        return self.run.group

    def job_type(self) -> Optional[str]:
        return self.run.job_type

    def display_name(self) -> str:
        return self.run.display_name

    def notes(self) -> Optional[str]:
        """Return the source notes prefixed with the source URL and author."""
        # Notes includes the previous notes and serves as a catch-all for things we missed or can't add back
        previous_link = f"Imported from: {self.run.url}"
        previous_author = f"Author: {self.run.user.username}"

        header = [previous_link, previous_author]
        previous_notes = self.run.notes or ""

        return "\n".join(header) + "\n---\n" + previous_notes

    def tags(self) -> Optional[List[str]]:
        return self.run.tags

    def artifacts(self) -> Optional[Iterable[Artifact]]:
        """Yield clones of the artifacts logged by the run (cached)."""
        if self._artifacts is None:
            self._artifacts = [_clone_art(art) for art in self.run.logged_artifacts()]

        yield from self._artifacts

    def used_artifacts(self) -> Optional[Iterable[Artifact]]:
        """Yield clones of the artifacts used by the run (cached)."""
        if self._used_artifacts is None:
            self._used_artifacts = [
                _clone_art(art) for art in self.run.used_artifacts()
            ]

        yield from self._used_artifacts

    def os_version(self) -> Optional[str]:
        ...  # pragma: no cover

    def python_version(self) -> Optional[str]:
        return self._metadata_file().get("python")

    def cuda_version(self) -> Optional[str]:
        ...  # pragma: no cover

    def program(self) -> Optional[str]:
        ...  # pragma: no cover

    def host(self) -> Optional[str]:
        return self._metadata_file().get("host")

    def username(self) -> Optional[str]:
        ...  # pragma: no cover

    def executable(self) -> Optional[str]:
        ...  # pragma: no cover

    def gpus_used(self) -> Optional[str]:
        ...  # pragma: no cover

    def cpus_used(self) -> Optional[int]:  # can we get the model?
        ...  # pragma: no cover

    def memory_used(self) -> Optional[int]:
        ...  # pragma: no cover

    def runtime(self) -> Optional[int]:
        """Return the runtime from the summary as an int, or None if absent."""
        wandb_runtime = self.run.summary.get("_wandb", {}).get("runtime")
        base_runtime = self.run.summary.get("_runtime")

        if (t := coalesce(wandb_runtime, base_runtime)) is None:
            return t
        return int(t)

    def start_time(self) -> Optional[int]:
        """Return `run.created_at` as a timestamp in milliseconds."""
        t = dt.fromisoformat(self.run.created_at).timestamp() * 1000
        return int(t)

    def code_path(self) -> Optional[str]:
        path = self._metadata_file().get("codePath", "")
        return f"code/{path}"

    def cli_version(self) -> Optional[str]:
        return self._config_file().get("_wandb", {}).get("value", {}).get("cli_version")

    def files(self) -> Optional[Iterable[Tuple[PathStr, Policy]]]:
        """Download the run's files once and yield `(local_name, "end")` pairs."""
        if self._files is None:
            files_dir = f"{internal.ROOT_DIR}/{self.run_id()}/files"
            _files = []
            for f in self.run.files():
                f: File
                # These optimizations are intended to avoid rate limiting when importing many runs in parallel
                # Don't carry over empty files
                if f.size == 0:
                    continue
                # Skip deadlist to avoid overloading S3
                if "wandb_manifest.json.deadlist" in f.name:
                    continue

                result = f.download(files_dir, exist_ok=True, api=self.api)
                _files.append((result.name, "end"))
            self._files = _files

        yield from self._files

    def logs(self) -> Optional[Iterable[str]]:
        """Yield the lines of the downloaded `output.log`, if there is one."""
        if (fname := self._find_in_files("output.log")) is None:
            return

        with open(fname) as f:
            yield from f.readlines()

    def _metadata_file(self) -> Dict[str, Any]:
        """Return the parsed `wandb-metadata.json`, or `{}` if not downloaded."""
        if (fname := self._find_in_files("wandb-metadata.json")) is None:
            return {}

        with open(fname) as f:
            return json.loads(f.read())

    def _config_file(self) -> Dict[str, Any]:
        """Return the parsed `config.yaml`, or `{}` if missing or empty."""
        if (fname := self._find_in_files("config.yaml")) is None:
            return {}

        with open(fname) as f:
            return yaml.safe_load(f) or {}

    def _get_rows_from_parquet_history_paths(self) -> Iterable[Dict[str, Any]]:
        """Yield rows merged from every `*.parquet` file under the history paths."""
        # Unfortunately, it's not feasible to validate non-parquet history
        # Materialise the paths first: a generator object is always truthy, so
        # testing it directly would never detect "no paths".
        paths = list(self._get_parquet_history_paths())
        if not paths:
            yield {}
            return

        # Collect and merge parquet history
        dfs = [
            pl.read_parquet(p) for path in paths for p in Path(path).glob("*.parquet")
        ]
        if "_step" in (df := _merge_dfs(dfs)):
            df = df.with_columns(pl.col("_step").cast(pl.Int64))
        yield from df.iter_rows(named=True)

    def _get_parquet_history_paths(self) -> Iterable[str]:
        """Download the run's `wandb-history` artifacts and yield their local paths (cached)."""
        if self._parquet_history_paths is None:
            paths = []
            # self.artifacts() returns a copy of the artifacts; use this to get raw
            for art in self.run.logged_artifacts():
                if art.type != "wandb-history":
                    continue
                if (
                    path := _download_art(art, root=f"{SRC_ART_PATH}/{art.name}")
                ) is None:
                    continue
                paths.append(path)
            self._parquet_history_paths = paths

        yield from self._parquet_history_paths

    def _find_in_files(self, name: str) -> Optional[str]:
        """Return the first downloaded file path containing `name`, else None."""
        # `self.files()` is a generator, so it is iterated directly rather than
        # tested for truthiness.
        for path, _ in self.files():
            if name in path:
                return path
        return None
331
+
332
+
333
+ class WandbImporter:
334
+ """Transfers runs, reports, and artifact sequences between W&B instances."""
335
+
336
def __init__(
    self,
    src_base_url: str,
    src_api_key: str,
    dst_base_url: str,
    dst_api_key: str,
    *,
    custom_api_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """Create API clients for the source and destination instances.

    `custom_api_kwargs` is forwarded to both `wandb.Api` constructors; when
    omitted, `{"timeout": 600}` is used.
    """
    api_kwargs = {"timeout": 600} if custom_api_kwargs is None else custom_api_kwargs

    self.src_base_url, self.src_api_key = src_base_url, src_api_key
    self.dst_base_url, self.dst_api_key = dst_base_url, dst_api_key

    self.src_api = wandb.Api(
        api_key=src_api_key,
        overrides={"base_url": src_base_url},
        **api_kwargs,
    )
    self.dst_api = wandb.Api(
        api_key=dst_api_key,
        overrides={"base_url": dst_base_url},
        **api_kwargs,
    )

    # Keyword arguments reused whenever a WandbRun wrapper is constructed.
    self.run_api_kwargs = dict(
        src_base_url=src_base_url,
        src_api_key=src_api_key,
        dst_base_url=dst_base_url,
        dst_api_key=dst_api_key,
    )
370
+
371
def __repr__(self):
    """Show the source and destination base URLs."""
    src, dst = self.src_base_url, self.dst_base_url
    return f"<WandbImporter src={src}, dst={dst}>"  # pragma: no cover
373
+
374
def _import_run(
    self,
    run: WandbRun,
    *,
    namespace: Optional[Namespace] = None,
    config: Optional[internal.SendManagerConfig] = None,
) -> None:
    """Send a single WandbRun to the destination instance.

    `namespace` selects where the run is uploaded and defaults to the run's
    own entity and project. `config` selects what is sent; when omitted,
    metadata, files, media, code, history, summary, and terminal output are
    all enabled.
    """
    if namespace is None:
        namespace = Namespace(run.entity(), run.project())

    if config is None:
        everything = (
            "metadata",
            "files",
            "media",
            "code",
            "history",
            "summary",
            "terminal_output",
        )
        config = internal.SendManagerConfig(**{flag: True for flag in everything})

    settings_override = {
        "api_key": self.dst_api_key,
        "base_url": self.dst_base_url,
        "resume": "true",
        "resumed": True,
    }

    # First pass: send the run with the base config.
    logger.debug(f"Importing run, {run=}")
    internal.send_run(
        run,
        overrides=namespace.send_manager_overrides,
        settings_override=settings_override,
        config=config,
    )

    if not config.history:
        return

    # Second pass: send the run again with its history artifacts attached, in
    # case config has history=True, artifacts=False. The history artifact must
    # come with the actual history data.
    logger.debug(f"Collecting history artifacts, {run=}")
    history_arts = []
    for art in run.run.logged_artifacts():
        if art.type == "wandb-history":
            logger.debug(f"Collecting history artifact {art.name=}")
            history_arts.append(_clone_art(art))

    logger.debug(f"Importing history artifacts, {run=}")
    internal.send_run(
        run,
        extra_arts=history_arts,
        overrides=namespace.send_manager_overrides,
        settings_override=settings_override,
        config=config,
    )
436
+
437
def _delete_collection_in_dst(
    self,
    seq: ArtifactSequence,
    namespace: Optional[Namespace] = None,
):
    """Deletes the equivalent artifact collection in destination.

    Intended to clear the destination when an uploaded artifact does not pass validation.

    `namespace` defaults to the sequence's own entity and project. A missing
    collection, or one that cannot be deleted, is logged as a warning and
    otherwise ignored.
    """
    # `namespace` is optional in the signature, so fall back to the sequence's
    # own location instead of dereferencing None.
    if namespace is None:
        namespace = Namespace(seq.entity, seq.project)

    entity = coalesce(namespace.entity, seq.entity)
    project = coalesce(namespace.project, seq.project)
    art_type = f"{entity}/{project}/{seq.type_}"
    art_name = seq.name

    logger.info(
        f"Deleting collection {entity=}, {project=}, {art_type=}, {art_name=}"
    )
    try:
        dst_collection = self.dst_api.artifact_collection(art_type, art_name)
    except (wandb.CommError, ValueError):
        logger.warning(f"Collection doesn't exist {art_type=}, {art_name=}")
        return

    try:
        dst_collection.delete()
    except (wandb.CommError, ValueError) as e:
        logger.warning(f"Collection can't be deleted, {art_type=}, {art_name=}, {e=}")
        return
465
+
466
def _import_artifact_sequence(
    self,
    seq: ArtifactSequence,
    *,
    namespace: Optional[Namespace] = None,
) -> None:
    """Import one artifact sequence.

    Use `namespace` to specify alternate settings like where the artifact sequence should be uploaded

    The destination collection is deleted first, the sequence is then uploaded
    group by group, and placeholder versions are removed at the end.
    """
    if not seq.artifacts:
        # The artifact sequence has no versions. This usually means all artifacts versions were deleted intentionally,
        # but it can also happen if the sequence represents run history and that run was deleted.
        logger.warning(f"Artifact {seq=} has no artifacts, skipping.")
        return

    if namespace is None:
        namespace = Namespace(seq.entity, seq.project)

    settings_override = {
        "api_key": self.dst_api_key,
        "base_url": self.dst_base_url,
        "resume": "true",
        "resumed": True,
    }

    send_manager_config = internal.SendManagerConfig(log_artifacts=True)

    # Delete any existing artifact sequence, otherwise versions will be out of order
    # Unfortunately, you can't delete only part of the sequence because versions are "remembered" even after deletion
    self._delete_collection_in_dst(seq, namespace)

    # Get a placeholder run for dummy artifacts we'll upload later
    art = seq.artifacts[0]
    run_or_dummy: Optional[Run] = _get_run_or_dummy_from_art(art, self.src_api)

    # Each `group_of_artifacts` is either:
    # 1. A single "real" artifact in a list; or
    # 2. A list of dummy artifacts that are uploaded together.
    # This guarantees the real artifacts have the correct version numbers while allowing for parallel upload of dummies.
    groups_of_artifacts = list(_make_groups_of_artifacts(seq))
    for i, group in enumerate(groups_of_artifacts, 1):
        art = group[0]
        if art.description == ART_SEQUENCE_DUMMY_PLACEHOLDER:
            run = WandbRun(run_or_dummy, **self.run_api_kwargs)
        else:
            try:
                wandb_run = art.logged_by()
            except ValueError:
                # The run used to exist but has since been deleted.
                # Bind the name explicitly: otherwise it is unbound on the
                # first iteration or still holds the previous artifact's run.
                wandb_run = None

            # Could be logged by None (rare) or ValueError
            if wandb_run is None:
                logger.warning(
                    f"Run for {art.name=} does not exist (deleted?), using {run_or_dummy=}"
                )
                wandb_run = run_or_dummy

            new_art = _clone_art(art)
            group = [new_art]
            run = WandbRun(wandb_run, **self.run_api_kwargs)

        logger.info(
            f"Uploading partial artifact {seq=}, {i}/{len(groups_of_artifacts)}"
        )
        internal.send_run(
            run,
            extra_arts=group,
            overrides=namespace.send_manager_overrides,
            settings_override=settings_override,
            config=send_manager_config,
        )
    logger.info(f"Finished uploading {seq=}")

    # query it back and remove placeholders
    self._remove_placeholders(seq)
544
+
545
def _remove_placeholders(self, seq: ArtifactSequence) -> None:
    """Delete placeholder artifact versions of `seq` from the destination.

    Only versions whose description equals `ART_SEQUENCE_DUMMY_PLACEHOLDER`
    are deleted; artifacts of type "wandb-history" or "job" are left alone.

    Raises:
        wandb.CommError: if a deletion fails for any reason other than the
            artifact being system managed.
    """
    try:
        # The destination client is stored as `self.dst_api` by `__init__`.
        retry_arts_func = internal.exp_retry(self.dst_api.artifacts)
        dst_arts = list(retry_arts_func(seq.type_, seq.name))
    except wandb.CommError:
        logger.warning(f"{seq=} does not exist in dst. Has it already been deleted?")
        return
    except TypeError as e:
        logger.error(f"Problem getting dst versions (try again later) {e=}")
        return

    for art in dst_arts:
        if art.description != ART_SEQUENCE_DUMMY_PLACEHOLDER:
            continue
        if art.type in ("wandb-history", "job"):
            continue

        try:
            art.delete(delete_aliases=True)
        except wandb.CommError as e:
            if "cannot delete system managed artifact" in str(e):
                logger.warning("Cannot delete system managed artifact")
            else:
                # Bare raise keeps the original traceback.
                raise
569
+
570
def _get_dst_art(
    self, src_art: Run, entity: Optional[str] = None, project: Optional[str] = None
) -> Artifact:
    """Fetch the destination artifact with the same name as `src_art`.

    `entity` and `project` default to those of `src_art`.
    """
    # NOTE(review): `src_art` is annotated as Run but is read like an artifact
    # (entity/project/name) -- confirm the intended type.
    dst_entity = coalesce(entity, src_art.entity)
    dst_project = coalesce(project, src_art.project)
    dst_path = f"{dst_entity}/{dst_project}/{src_art.name}"

    return self.dst_api.artifact(dst_path)
578
+
579
def _get_run_problems(
    self, src_run: Run, dst_run: Run, force_retry: bool = False
) -> List[dict]:
    """Collect descriptions of mismatches between a source and destination run.

    Each entry is a string: "__force_retry__" when `force_retry` is set, then
    "metadata:..." and "summary:..." when the respective comparison returns
    a non-empty result.
    """
    problems = ["__force_retry__"] if force_retry else []

    non_matching_metadata = self._compare_run_metadata(src_run, dst_run)
    if non_matching_metadata:
        problems.append(f"metadata:{non_matching_metadata!s}")

    non_matching_summary = self._compare_run_summary(src_run, dst_run)
    if non_matching_summary:
        problems.append(f"summary:{non_matching_summary!s}")

    # TODO: Compare files?

    return problems
596
+
597
def _compare_run_metadata(self, src_run: Run, dst_run: Run) -> dict:
    """Compare source run metadata against the destination's `wandb-metadata.json`.

    Returns a dict keyed by metadata key with the differing `src`/`dst`
    values, a single-entry dict describing a download problem, or `{}` when
    nothing differs or the source file is empty.
    """
    fname = "wandb-metadata.json"

    if src_run.file(fname).size == 0:
        # the src was corrupted so no comparisons here will ever work
        return {}

    dst_f = dst_run.file(fname)
    try:
        contents = wandb.util.download_file_into_memory(
            dst_f.url, self.dst_api.api_key
        )
    except urllib3.exceptions.ReadTimeoutError:
        return {"Error checking": "Timeout"}
    except requests.HTTPError as e:
        if e.response.status_code != 404:
            return {"http problem": f"{fname}: ({e})"}
        return {"Bad upload": f"File not found: {fname}"}

    dst_meta = wandb.wandb_sdk.lib.json_util.loads(contents)

    src_meta = src_run.metadata
    if not src_meta:
        return {}

    non_matching = {}
    for k, src_v in src_meta.items():
        if k not in dst_meta:
            non_matching[k] = {"src": src_v, "dst": "KEY NOT FOUND"}
        elif src_v != dst_meta[k]:
            non_matching[k] = {"src": src_v, "dst": dst_meta[k]}

    return non_matching
631
+
632
def _compare_run_summary(self, src_run: Run, dst_run: Run) -> dict:
    """Return the summary entries that differ between source and destination.

    Keys "_wandb" and "_runtime" and string values starting with
    "wandb-client-artifact://" are skipped. When both sides of an entry are
    dicts, they are compared one level down and reported as "<key>-<subkey>".
    """

    def _is_client_artifact_ref(value) -> bool:
        # These won't match between systems and that's ok
        return isinstance(value, str) and value.startswith("wandb-client-artifact://")

    non_matching = {}
    for k, raw_src in src_run.summary.items():
        if _is_client_artifact_ref(raw_src) or k in ("_wandb", "_runtime"):
            continue

        src_v = _recursive_cast_to_dict(raw_src)
        dst_v = _recursive_cast_to_dict(dst_run.summary.get(k))

        both_dicts = isinstance(src_v, dict) and isinstance(dst_v, dict)
        if not both_dicts:
            if not _almost_equal(src_v, dst_v):
                non_matching[k] = {"src": src_v, "dst": dst_v}
            continue

        for kk, sv in src_v.items():
            if _is_client_artifact_ref(sv):
                continue
            dv = dst_v.get(kk)
            if not _almost_equal(sv, dv):
                non_matching[f"{k}-{kk}"] = {"src": sv, "dst": dv}

    return non_matching
661
+
662
def _collect_failed_artifact_sequences(self) -> Iterable[ArtifactSequence]:
    """Yield one ArtifactSequence per distinct sequence in the artifact errors file.

    Yields nothing when the errors file has no content to read.
    """
    df = _read_ndjson(ARTIFACT_ERRORS_FNAME)
    if df is None:
        logger.debug(f"{ARTIFACT_ERRORS_FNAME=} is empty, returning nothing")
        return

    key_columns = ["src_entity", "src_project", "name", "type"]
    unique_failed_sequences = df[key_columns].unique()

    for row in unique_failed_sequences.iter_rows(named=True):
        entity, project = row["src_entity"], row["src_project"]
        name, _type = row["name"], row["type"]

        # Order by type first, then by numeric version within each type.
        arts = sorted(
            self.src_api.artifacts(_type, f"{entity}/{project}/{name}"),
            key=lambda a: (a.type, int(a.version.lstrip("v"))),
        )

        yield ArtifactSequence(arts, entity, project, _type, name)
683
+
684
def _cleanup_dummy_runs(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    api: Optional[Api] = None,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
) -> None:
    """Delete runs whose display name is `RUN_DUMMY_PLACEHOLDER`.

    `api` defaults to the destination API and `namespaces` to
    `self._all_namespaces()`. Namespaces present in `remapping` are
    translated before querying. Artifacts of deleted runs are kept.

    Raises:
        ValueError: if listing runs fails for a reason other than the
            project not being found.
    """
    api = coalesce(api, self.dst_api)
    namespaces = coalesce(namespaces, self._all_namespaces())

    for ns in namespaces:
        if remapping and ns in remapping:
            ns = remapping[ns]

        logger.debug(f"Cleaning up, {ns=}")
        try:
            runs = list(
                api.runs(ns.path, filters={"displayName": RUN_DUMMY_PLACEHOLDER})
            )
        except ValueError as e:
            if "Could not find project" in str(e):
                logger.error("Could not find project, does it exist?")
                continue
            # Any other failure must not fall through: `runs` would be unbound
            # or would still hold the previous namespace's runs.
            raise

        for run in runs:
            logger.debug(f"Deleting dummy {run=}")
            run.delete(delete_artifacts=False)
711
+
712
def _import_report(
    self, report: Report, *, namespace: Optional[Namespace] = None
) -> None:
    """Import one wandb.Report.

    Use `namespace` to specify alternate settings like where the report should be uploaded

    The destination project is created first; an HTTP 409 from that call is
    ignored and any other HTTP error is logged as a warning. The report is
    then upserted through the destination API client.
    """
    if namespace is None:
        namespace = Namespace(report.entity, report.project)

    entity = coalesce(namespace.entity, report.entity)
    project = coalesce(namespace.project, report.project)
    name = report.name
    title = report.title
    description = report.description

    api = self.dst_api

    # We shouldn't need to upsert the project for every report
    logger.debug(f"Upserting {entity=}/{project=}")
    try:
        api.create_project(project, entity)
    except requests.exceptions.HTTPError as e:
        if e.response.status_code != 409:
            logger.warning(f"Issue upserting {entity=}/{project=}, {e=}")

    logger.debug(f"Upserting report {entity=}, {project=}, {name=}, {title=}")
    api.client.execute(
        wr.report.UPSERT_VIEW,
        variable_values={
            "id": None,  # Is there any benefit for this to be the same as default report?
            "name": name,
            "entityName": entity,
            "projectName": project,
            "description": description,
            "displayName": title,
            "type": "runs",
            "spec": json.dumps(report.spec),
        },
    )
752
+
753
def _use_artifact_sequence(
    self,
    sequence: ArtifactSequence,
    *,
    namespace: Optional[Namespace] = None,
):
    """Re-send every run that used an artifact of `sequence`, with use_artifacts enabled.

    `namespace` defaults to the sequence's own entity and project.
    """
    namespace = (
        Namespace(sequence.entity, sequence.project) if namespace is None else namespace
    )

    settings_override = {
        "api_key": self.dst_api_key,
        "base_url": self.dst_base_url,
        "resume": "true",
        "resumed": True,
    }
    logger.debug(f"Using artifact sequence with {settings_override=}, {namespace=}")

    send_manager_config = internal.SendManagerConfig(use_artifacts=True)

    for art in sequence:
        used_by = art.used_by()
        if used_by is None:
            continue

        for wandb_run in used_by:
            internal.send_run(
                WandbRun(wandb_run, **self.run_api_kwargs),
                overrides=namespace.send_manager_overrides,
                settings_override=settings_override,
                config=send_manager_config,
            )
785
+
786
def import_runs(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
    parallel: bool = True,
    incremental: bool = True,
    max_workers: Optional[int] = None,
    limit: Optional[int] = None,
    metadata: bool = True,
    files: bool = True,
    media: bool = True,
    code: bool = True,
    history: bool = True,
    summary: bool = True,
    terminal_output: bool = True,
):
    """Collect, validate, and import runs.

    Runs are collected from `namespaces` (up to `limit`) and validated; the
    runs reported as failed are then imported. `remapping` redirects a
    run's namespace, `incremental` skips previously validated runs, and the
    boolean flags select which parts of each run are sent.
    """
    # Flags forwarded to SendManagerConfig for every imported run.
    config_kwargs = dict(
        metadata=metadata,
        files=files,
        media=media,
        code=code,
        history=history,
        summary=summary,
        terminal_output=terminal_output,
    )

    def _import_run_wrapped(run):
        namespace = Namespace(run.entity(), run.project())
        if remapping is not None:
            namespace = remapping.get(namespace, namespace)

        config = internal.SendManagerConfig(**config_kwargs)

        logger.debug(f"Importing {run=}, {namespace=}, {config=}")
        self._import_run(run, namespace=namespace, config=config)
        logger.debug(f"Finished importing {run=}, {namespace=}, {config=}")

    logger.info("START: Import runs")

    logger.info("Setting up for import")
    _create_files_if_not_exists()
    _clear_fname(RUN_ERRORS_FNAME)

    logger.info("Collecting runs")
    runs = list(self._collect_runs(namespaces=namespaces, limit=limit))

    logger.info(f"Validating runs, {len(runs)=}")
    self._validate_runs(
        runs,
        skip_previously_validated=incremental,
        remapping=remapping,
    )

    logger.info("Collecting failed runs")
    runs = list(self._collect_failed_runs())

    logger.info(f"Importing runs, {len(runs)=}")
    for_each(_import_run_wrapped, runs, max_workers=max_workers, parallel=parallel)
    logger.info("END: Importing runs")
845
+
846
def import_reports(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    limit: Optional[int] = None,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
):
    """Collect reports from `namespaces` (up to `limit`) and import each one.

    `remapping` redirects a report's namespace before it is uploaded.
    """

    def _import_report_wrapped(report):
        namespace = Namespace(report.entity, report.project)
        if remapping is not None:
            namespace = remapping.get(namespace, namespace)

        logger.debug(f"Importing {report=}, {namespace=}")
        self._import_report(report, namespace=namespace)
        logger.debug(f"Finished importing {report=}, {namespace=}")

    logger.info("START: Importing reports")

    logger.info("Collecting reports")
    reports = self._collect_reports(namespaces=namespaces, limit=limit)

    logger.info("Importing reports")
    for_each(_import_report_wrapped, reports)

    logger.info("END: Importing reports")
872
+
873
def import_artifact_sequences(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    incremental: bool = True,
    max_workers: Optional[int] = None,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
):
    """Import all artifact sequences from `namespaces`.

    The sequences are first validated; only the ones recorded as failed are
    then imported and "used", after which dummy runs are cleaned up.

    Note: There is a known bug with the AWS backend where artifacts > 2048MB
    will fail to upload. This seems to be related to multipart uploads, but we
    don't have a fix yet.

    Args:
        namespaces: Source namespaces to collect artifact sequences from.
        incremental: Forwarded to `_validate_artifact_sequences`.
        max_workers: Forwarded to `for_each` for the import and use passes.
        remapping: Optional source -> destination namespace mapping.
    """
    logger.info("START: Importing artifact sequences")
    _clear_fname(ARTIFACT_ERRORS_FNAME)

    logger.info("Collecting artifact sequences")
    seqs = list(self._collect_artifact_sequences(namespaces=namespaces))

    logger.info("Validating artifact sequences")
    self._validate_artifact_sequences(
        seqs,
        incremental=incremental,
        remapping=remapping,
    )

    logger.info("Collecting failed artifact sequences")
    seqs = list(self._collect_failed_artifact_sequences())

    logger.info(f"Importing artifact sequences, {len(seqs)=}")

    def _destination_namespace(seq):
        # A sequence lands in its own namespace unless a remapping overrides it.
        src_ns = Namespace(seq.entity, seq.project)
        if remapping is not None and src_ns in remapping:
            return remapping[src_ns]
        return src_ns

    def _import_artifact_sequence_wrapped(seq):
        namespace = _destination_namespace(seq)

        logger.debug(f"Importing artifact sequence {seq=}, {namespace=}")
        self._import_artifact_sequence(seq, namespace=namespace)
        logger.debug(f"Finished importing artifact sequence {seq=}, {namespace=}")

    for_each(_import_artifact_sequence_wrapped, seqs, max_workers=max_workers)

    # it's safer to just use artifact on all seqs to make sure we don't miss anything
    # For seqs that have already been used, this is a no-op.
    logger.debug(f"Using artifact sequences, {len(seqs)=}")

    def _use_artifact_sequence_wrapped(seq):
        namespace = _destination_namespace(seq)

        logger.debug(f"Using artifact sequence {seq=}, {namespace=}")
        self._use_artifact_sequence(seq, namespace=namespace)
        logger.debug(f"Finished using artifact sequence {seq=}, {namespace=}")

    for_each(_use_artifact_sequence_wrapped, seqs, max_workers=max_workers)

    # Artifacts whose parent runs have been deleted should have that run deleted in the
    # destination as well
    logger.info("Cleaning up dummy runs")
    self._cleanup_dummy_runs(
        namespaces=namespaces,
        remapping=remapping,
    )

    logger.info("END: Importing artifact sequences")
939
+
940
def import_all(
    self,
    *,
    runs: bool = True,
    artifacts: bool = True,
    reports: bool = True,
    namespaces: Optional[Iterable[Namespace]] = None,
    incremental: bool = True,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
):
    """Run the selected import stages in order: runs, reports, then artifacts.

    Args:
        runs: Whether to call `import_runs`.
        artifacts: Whether to call `import_artifact_sequences`.
        reports: Whether to call `import_reports`.
        namespaces: Forwarded to every enabled stage.
        incremental: Forwarded to the run and artifact stages.
        remapping: Forwarded to every enabled stage.
    """
    logger.info(f"START: Importing all, {runs=}, {artifacts=}, {reports=}")

    # Arguments shared by all three stages.
    shared = {"namespaces": namespaces, "remapping": remapping}

    if runs:
        self.import_runs(incremental=incremental, **shared)

    if reports:
        self.import_reports(**shared)

    if artifacts:
        self.import_artifact_sequences(incremental=incremental, **shared)

    logger.info("END: Importing all")
972
+
973
def _validate_run(
    self,
    src_run: Run,
    *,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
) -> None:
    """Compare one source run with its destination copy and record the outcome.

    A JSON line describing the run is appended to `RUN_ERRORS_FNAME` when
    problems were found (or the destination run cannot be fetched), and to
    `RUN_SUCCESSES_FNAME` otherwise. The write is done while holding the
    "runs.lock" file lock.
    """
    src_ns = Namespace(src_run.entity, src_run.project)
    dst_ns = src_ns
    if remapping is not None and src_ns in remapping:
        dst_ns = remapping[src_ns]

    dst_entity, dst_project = dst_ns.entity, dst_ns.project
    run_id = src_run.id

    try:
        dst_run = self.dst_api.run(f"{dst_entity}/{dst_project}/{run_id}")
    except wandb.CommError:
        problems = [f"run does not exist in dst at {dst_entity=}/{dst_project=}"]
    else:
        problems = self._get_run_problems(src_run, dst_run)

    record = {
        "src_entity": src_run.entity,
        "src_project": src_run.project,
        "dst_entity": dst_entity,
        "dst_project": dst_project,
        "run_id": run_id,
    }

    # Default to the success log; switch to the error log if anything is wrong.
    fname = RUN_SUCCESSES_FNAME
    if problems:
        record["problems"] = problems
        fname = RUN_ERRORS_FNAME

    with filelock.FileLock("runs.lock"), open(fname, "a") as f:
        f.write(json.dumps(record) + "\n")
1010
+
1011
def _filter_previously_checked_runs(
    self,
    runs: Iterable[Run],
    *,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
) -> Iterable[Run]:
    """Yield only the runs that are not already recorded as successfully validated.

    Runs are matched against the rows of `RUN_SUCCESSES_FNAME` on
    (src_entity, src_project, dst_entity, dst_project, run_id). If that file
    cannot be read as a DataFrame, every run is yielded unchanged.

    Args:
        runs: Candidate runs; any iterable, including a one-shot iterator.
        remapping: Optional source -> destination namespace mapping used to
            compute each run's destination entity/project.
    """
    if (df := _read_ndjson(RUN_SUCCESSES_FNAME)) is None:
        logger.debug(f"{RUN_SUCCESSES_FNAME=} is empty, yielding all runs")
        yield from runs
        return

    # `runs` is both iterated and measured with len(); materialize it once so
    # a generator argument is neither exhausted early nor passed to len().
    runs = list(runs)
    logger.debug(f"Starting with {len(runs)=} in namespaces")
    if not runs:
        # A DataFrame built from no rows has no columns, so the anti-join on
        # named key columns below could not be evaluated. Nothing to filter.
        return

    data = []
    for r in runs:
        namespace = Namespace(r.entity, r.project)
        if remapping is not None and namespace in remapping:
            namespace = remapping[namespace]

        data.append(
            {
                "src_entity": r.entity,
                "src_project": r.project,
                "dst_entity": namespace.entity,
                "dst_project": namespace.project,
                "run_id": r.id,
                # Carry the run object along so it can be yielded after the join.
                "data": r,
            }
        )
    df2 = pl.DataFrame(data)

    key_cols = ["src_entity", "src_project", "dst_entity", "dst_project", "run_id"]

    # Anti-join: keep candidate rows that have no match in the success log.
    results = df2.join(df, how="anti", on=key_cols)
    logger.debug(f"After filtering out already successful runs, {len(results)=}")

    if not results.is_empty():
        results = results.filter(~results["run_id"].is_null())
        results = results.unique(key_cols)

    for r in results.iter_rows(named=True):
        yield r["data"]
1056
+
1057
def _validate_artifact(
    self,
    src_art: Artifact,
    dst_entity: str,
    dst_project: str,
    download_files_and_compare: bool = False,
    check_entries_are_downloadable: bool = True,
):
    """Check that `src_art` has a matching artifact in the destination.

    Args:
        src_art: The source artifact to validate.
        dst_entity: Destination entity the artifact should exist under.
        dst_project: Destination project the artifact should exist under.
        download_files_and_compare: If true, download both artifacts and
            compare the resulting directories.
        check_entries_are_downloadable: If true, run the downloadability
            check against the destination artifact.

    Returns:
        A tuple `(src_art, dst_entity, dst_project, problems)` where
        `problems` is a list of human-readable problem descriptions (empty
        when nothing was found).
    """
    problems = []

    # These patterns of artifacts are special and should not be validated
    ignore_patterns = [
        r"^job-(.*?)\.py(:v\d+)?$",
    ]
    for pattern in ignore_patterns:
        if re.search(pattern, src_art.name):
            return (src_art, dst_entity, dst_project, problems)

    try:
        dst_art = self._get_dst_art(src_art, dst_entity, dst_project)
    except Exception:
        problems.append("destination artifact not found")
        return (src_art, dst_entity, dst_project, problems)

    logger.debug("Comparing artifact manifests")
    try:
        # The comparison itself is what can fail, so it is the guarded call;
        # a failure is recorded as a problem instead of aborting validation.
        problems += validation._compare_artifact_manifests(src_art, dst_art)
    except Exception as e:
        problems.append(
            f"Problem getting problems! problem with {src_art.entity=}, {src_art.project=}, {src_art.name=} {e=}"
        )

    if check_entries_are_downloadable:
        # NOTE(review): the helper's return value is discarded here; if it
        # signals failure by return value, that is not recorded -- confirm.
        validation._check_entries_are_downloadable(dst_art)

    if download_files_and_compare:
        # `_download_art` logs and returns None on failure, so a missing
        # directory must be treated as a problem rather than compared.
        src_dir = None
        dst_dir = None

        logger.debug(f"Downloading {src_art=}")
        try:
            src_dir = _download_art(src_art, root=f"{SRC_ART_PATH}/{src_art.name}")
        except requests.HTTPError as e:
            problems.append(
                f"Invalid download link for src {src_art.entity=}, {src_art.project=}, {src_art.name=}, {e}"
            )
        else:
            if src_dir is None:
                problems.append(
                    f"Unable to download src {src_art.entity=}, {src_art.project=}, {src_art.name=}"
                )

        logger.debug(f"Downloading {dst_art=}")
        try:
            dst_dir = _download_art(dst_art, root=f"{DST_ART_PATH}/{dst_art.name}")
        except requests.HTTPError as e:
            problems.append(
                f"Invalid download link for dst {dst_art.entity=}, {dst_art.project=}, {dst_art.name=}, {e}"
            )
        else:
            if dst_dir is None:
                problems.append(
                    f"Unable to download dst {dst_art.entity=}, {dst_art.project=}, {dst_art.name=}"
                )

        # Only compare when both downloads actually produced a directory.
        if src_dir is not None and dst_dir is not None:
            logger.debug(f"Comparing artifact dirs {src_dir=}, {dst_dir=}")
            if problem := validation._compare_artifact_dirs(src_dir, dst_dir):
                problems.append(problem)

    return (src_art, dst_entity, dst_project, problems)
1117
+
1118
def _validate_runs(
    self,
    runs: Iterable[WandbRun],
    *,
    skip_previously_validated: bool = True,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
):
    """Validate each run's underlying `.run` object via `_validate_run`.

    Args:
        runs: Wrapped runs; the `.run` attribute of each one is validated.
        skip_previously_validated: If true, runs are first passed through
            `_filter_previously_checked_runs`.
        remapping: Optional source -> destination namespace mapping.
    """
    base_runs = [wrapped.run for wrapped in runs]

    if skip_previously_validated:
        remaining = self._filter_previously_checked_runs(
            base_runs,
            remapping=remapping,
        )
        base_runs = list(remaining)

    def _validate_run(run):
        logger.debug(f"Validating {run=}")
        self._validate_run(run, remapping=remapping)
        logger.debug(f"Finished validating {run=}")

    for_each(_validate_run, base_runs)
1140
+
1141
def _collect_failed_runs(self):
    """Yield a `WandbRun` for every distinct run listed in `RUN_ERRORS_FNAME`.

    Nothing is yielded when the error log cannot be read as a DataFrame.
    Each run is re-fetched from the source API by its entity/project/id.
    """
    df = _read_ndjson(RUN_ERRORS_FNAME)
    if df is None:
        logger.debug(f"{RUN_ERRORS_FNAME=} is empty, returning nothing")
        return

    key_cols = ["src_entity", "src_project", "dst_entity", "dst_project", "run_id"]
    unique_failed_runs = df[key_cols].unique()

    for row in unique_failed_runs.iter_rows(named=True):
        # Only the source coordinates are needed to look the run up again.
        src_entity, src_project, run_id = (
            row["src_entity"],
            row["src_project"],
            row["run_id"],
        )
        run = self.src_api.run(f"{src_entity}/{src_project}/{run_id}")
        yield WandbRun(run, **self.run_api_kwargs)
1159
+
1160
def _filter_previously_checked_artifacts(self, seqs: Iterable[ArtifactSequence]):
    """Yield artifacts from `seqs` that are not in the artifact success log.

    If `ARTIFACT_SUCCESSES_FNAME` cannot be read as a DataFrame, every
    artifact of every sequence is yielded. Otherwise an artifact is yielded
    only when no success row matches its source entity, project, name,
    version and type.
    """
    df = _read_ndjson(ARTIFACT_SUCCESSES_FNAME)
    if df is None:
        logger.info(
            f"{ARTIFACT_SUCCESSES_FNAME=} is empty, yielding all artifact sequences"
        )
        for seq in seqs:
            yield from seq.artifacts
        return

    for seq in seqs:
        for art in seq:
            try:
                logged_by = _get_run_or_dummy_from_art(art, self.src_api)
            except requests.HTTPError as e:
                logger.error(f"Failed to get run, skipping: {art=}, {e=}")
                continue

            is_history = art.type == "wandb-history"
            if is_history and isinstance(logged_by, _DummyRun):
                logger.debug(f"Skipping history artifact {art=}")
                # We can never upload valid history for a deleted run, so skip it
                continue

            entity, project, art_type = art.entity, art.project, art.type
            name, ver = _get_art_name_ver(art)

            already_ok = (
                (df["src_entity"] == entity)
                & (df["src_project"] == project)
                & (df["name"] == name)
                & (df["version"] == ver)
                & (df["type"] == art_type)
            )
            matches = df.filter(already_ok)

            # not in file, so not verified yet, don't filter out
            if len(matches) == 0:
                yield art
1198
+
1199
def _validate_artifact_sequences(
    self,
    seqs: Iterable[ArtifactSequence],
    *,
    incremental: bool = True,
    download_files_and_compare: bool = False,
    check_entries_are_downloadable: bool = True,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
):
    """Validate the artifacts of `seqs` and log each result as a JSON line.

    Artifacts with problems are appended to `ARTIFACT_ERRORS_FNAME`, the
    rest to `ARTIFACT_SUCCESSES_FNAME`.

    Args:
        seqs: Artifact sequences whose artifacts should be validated.
        incremental: If true, sequences are pre-filtered and artifacts that
            are already in the success log are skipped.
        download_files_and_compare: Forwarded to `_validate_artifact`.
        check_entries_are_downloadable: Forwarded to `_validate_artifact`.
        remapping: Optional source -> destination namespace mapping.
    """
    if not incremental:
        logger.info("Validating in non-incremental mode")
        artifacts = []
        for seq in seqs:
            artifacts.extend(seq.artifacts)
    else:
        logger.info("Validating in incremental mode")

        def filtered_sequences():
            for seq in seqs:
                if not seq.artifacts:
                    continue

                # The first artifact decides whether the sequence is kept.
                art = seq.artifacts[0]
                try:
                    logged_by = _get_run_or_dummy_from_art(art, self.src_api)
                except requests.HTTPError as e:
                    logger.error(
                        f"Validate Artifact http error: {art.entity=}, {art.project=}, {art.name=}, {e=}"
                    )
                    continue

                orphaned_history = art.type == "wandb-history" and isinstance(
                    logged_by, _DummyRun
                )
                # We can never upload valid history for a deleted run, so skip it
                if not orphaned_history:
                    yield seq

        artifacts = self._filter_previously_checked_artifacts(filtered_sequences())

    def _validate_artifact_wrapped(args):
        art, entity, project = args
        if remapping is not None:
            namespace = Namespace(entity, project)
            if namespace in remapping:
                remapped_ns = remapping[namespace]
                entity, project = remapped_ns.entity, remapped_ns.project

        logger.debug(f"Validating {art=}, {entity=}, {project=}")
        result = self._validate_artifact(
            art,
            entity,
            project,
            download_files_and_compare=download_files_and_compare,
            check_entries_are_downloadable=check_entries_are_downloadable,
        )
        logger.debug(f"Finished validating {art=}, {entity=}, {project=}")
        return result

    args = ((art, art.entity, art.project) for art in artifacts)
    art_problems = for_each(_validate_artifact_wrapped, args)

    for art, dst_entity, dst_project, problems in art_problems:
        name, ver = _get_art_name_ver(art)
        record = {
            "src_entity": art.entity,
            "src_project": art.project,
            "dst_entity": dst_entity,
            "dst_project": dst_project,
            "name": name,
            "version": ver,
            "type": art.type,
        }

        fname = ARTIFACT_SUCCESSES_FNAME
        if problems:
            record["problems"] = problems
            fname = ARTIFACT_ERRORS_FNAME

        with open(fname, "a") as f:
            f.write(json.dumps(record) + "\n")
1279
+
1280
def _collect_runs(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    limit: Optional[int] = None,
    skip_ids: Optional[List[str]] = None,
    start_date: Optional[str] = None,
    api: Optional[Api] = None,
) -> Iterable[WandbRun]:
    """Lazily yield wrapped runs from the given namespaces.

    Args:
        namespaces: Namespaces to read; defaults to `self._all_namespaces()`.
        limit: Maximum number of runs to yield across all namespaces
            (`None` means no limit).
        skip_ids: If given, added as a `name` / `$nin` filter.
        start_date: If given, added as a `createdAt` / `$gte` filter.
        api: API client to query; defaults to `self.src_api`.
    """
    api = coalesce(api, self.src_api)
    # NOTE(review): the default namespaces come from `_all_namespaces()` with
    # its own default API, even when `api` is passed here -- confirm intended.
    namespaces = coalesce(namespaces, self._all_namespaces())

    filters: Dict[str, Any] = {}
    if skip_ids is not None:
        filters["name"] = {"$nin": skip_ids}
    if start_date is not None:
        filters["createdAt"] = {"$gte": start_date}

    def _runs():
        for ns in namespaces:
            logger.debug(f"Collecting runs from {ns=}")
            yield from (
                WandbRun(run, **self.run_api_kwargs)
                for run in api.runs(ns.path, filters=filters)
            )

    yield from itertools.islice(_runs(), limit)
1306
+
1307
def _all_namespaces(
    self, *, entity: Optional[str] = None, api: Optional[Api] = None
):
    """Yield a `Namespace` for every project of `entity`.

    Args:
        entity: Entity to list projects for; defaults to `api.default_entity`.
        api: API client to query; defaults to `self.src_api`.
    """
    api = coalesce(api, self.src_api)
    entity = coalesce(entity, api.default_entity)
    yield from (Namespace(p.entity, p.name) for p in api.projects(entity))
1315
+
1316
def _collect_reports(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    limit: Optional[int] = None,
    api: Optional[Api] = None,
):
    """Lazily yield reports, loaded from their URLs, for the given namespaces.

    Logs in with the source API key and base URL before collecting.

    Args:
        namespaces: Namespaces to read; defaults to `self._all_namespaces()`.
        limit: Maximum number of reports to yield (`None` means no limit).
        api: API client to query; defaults to `self.src_api`.
    """
    api = coalesce(api, self.src_api)
    namespaces = coalesce(namespaces, self._all_namespaces())

    wandb.login(key=self.src_api_key, host=self.src_base_url)

    def reports():
        for ns in namespaces:
            yield from (
                wr.Report.from_url(r.url, api=api) for r in api.reports(ns.path)
            )

    capped = itertools.islice(reports(), limit)
    yield from capped
1334
+
1335
def _collect_artifact_sequences(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    limit: Optional[int] = None,
    api: Optional[Api] = None,
):
    """Yield artifact sequences from the given namespaces, one per identifier.

    Failures while listing artifact types or collections are logged and the
    affected namespace/type contributes nothing. `limit` is applied before
    de-duplication by `seq.identifier`.

    Args:
        namespaces: Namespaces to read; defaults to `self._all_namespaces()`.
        limit: Maximum number of sequences to take before de-duplicating.
        api: API client to query; defaults to `self.src_api`.
    """
    api = coalesce(api, self.src_api)
    namespaces = coalesce(namespaces, self._all_namespaces())

    def artifact_sequences():
        for ns in namespaces:
            logger.debug(f"Collecting artifact sequences from {ns=}")
            try:
                types = list(api.artifact_types(ns.path))
            except Exception as e:
                logger.error(f"Failed to get artifact types {e=}")
                types = []

            for t in types:
                # Skip history because it's really for run history
                if t.name == "wandb-history":
                    continue

                try:
                    collections = t.collections()
                except Exception as e:
                    logger.error(f"Failed to get artifact collections {e=}")
                    collections = []

                for c in collections:
                    if c.is_sequence():
                        yield ArtifactSequence.from_collection(c)

    # Later sequences with the same identifier replace earlier ones.
    unique_sequences = {}
    for seq in itertools.islice(artifact_sequences(), limit):
        unique_sequences[seq.identifier] = seq
    yield from unique_sequences.values()
1373
+
1374
+
1375
def _get_art_name_ver(art: Artifact) -> Tuple[str, int]:
    """Split an artifact's versioned name ``"<name>:v<N>"`` into ``(name, N)``.

    The split is anchored on the last ``":v"`` so that a base name which
    itself contains ``":v"`` is kept intact.

    Raises:
        ValueError: If the name has no ``":v"`` separator or the version part
            is not an integer.
    """
    name, ver = art.name.rsplit(":v", 1)
    return name, int(ver)
1378
+
1379
+
1380
def _make_dummy_art(name: str, _type: str, ver: int):
    """Build a placeholder artifact named `name` containing one empty file.

    The artifact is created with the placeholder type and then has its
    `_type` and `_description` overwritten. The empty file is named after
    `ver` inside `ART_DUMMY_PLACEHOLDER_PATH`.
    """
    art = Artifact(name, ART_DUMMY_PLACEHOLDER_TYPE)
    art._type = _type
    art._description = ART_SEQUENCE_DUMMY_PLACEHOLDER

    placeholder_dir = Path(ART_DUMMY_PLACEHOLDER_PATH)
    placeholder_dir.mkdir(parents=True, exist_ok=True)

    # dummy file with different name to prevent dedupe
    fname = placeholder_dir / str(ver)
    open(fname, "w").close()
    art.add_file(fname)

    return art
1395
+
1396
+
1397
def _make_groups_of_artifacts(seq: ArtifactSequence, start: int = 0):
    """Yield lists of artifacts covering every version from `start` onward.

    For each artifact in `seq`, any missing versions before it are yielded
    first as a single list of dummy artifacts, followed by a one-element list
    holding the real artifact.
    """
    next_expected = start
    for art in seq:
        name, ver = _get_art_name_ver(art)

        # If there's a gap between versions, fill with dummy artifacts
        if ver > next_expected:
            yield [
                _make_dummy_art(name, art.type, v) for v in range(next_expected, ver)
            ]

        # Then yield the actual artifact
        # Must always be a list of one artifact to guarantee ordering
        yield [art]
        next_expected = ver + 1
1410
+
1411
+
1412
def _recursive_cast_to_dict(obj):
    """Return `obj` with nested mapping-like values converted to plain dicts.

    Lists are rebuilt element by element; anything that is a dict or exposes
    an `items` attribute becomes a plain `dict` with converted values; every
    other object is returned unchanged.
    """
    if isinstance(obj, list):
        return [_recursive_cast_to_dict(element) for element in obj]

    if isinstance(obj, dict) or hasattr(obj, "items"):
        return {key: _recursive_cast_to_dict(value) for key, value in obj.items()}

    return obj
1422
+
1423
+
1424
def _almost_equal(x, y, eps=1e-6):
    """Return whether `x` and `y` are equal, allowing `eps` slack on numbers.

    - Two dicts are equal when they have the same keys and every pair of
      values is almost equal (checked recursively).
    - Two numbers are equal when they compare equal, when both are NaN, or
      when their absolute difference is below `eps`.
    - Anything else must have the same type and compare equal with `==`.
    """
    if isinstance(x, dict) and isinstance(y, dict):
        if x.keys() != y.keys():
            return False
        return all(_almost_equal(x[k], y[k], eps) for k in x)

    if isinstance(x, numbers.Number) and isinstance(y, numbers.Number):
        # Exact match first: covers matching infinities, where the
        # difference below would be NaN and never pass the threshold.
        if x == y:
            return True
        # NaN is the only value unequal to itself; a NaN on both sides is a
        # faithful copy, not a mismatch.
        if x != x and y != y:
            return True
        return abs(x - y) < eps

    if type(x) is not type(y):
        return False

    return x == y
1437
+
1438
+
1439
@dataclass
class _DummyUser:
    """Minimal stand-in for a user object; carries only a username."""

    # Empty by default.
    username: str = ""
1442
+
1443
+
1444
@dataclass
class _DummyRun:
    """Placeholder run object with inert default values.

    `_get_run_or_dummy_from_art` builds one when an artifact's logging run
    cannot be resolved.
    """

    entity: str = ""
    project: str = ""
    # Both identifiers default to the shared placeholder value.
    run_id: str = RUN_DUMMY_PLACEHOLDER
    id: str = RUN_DUMMY_PLACEHOLDER
    display_name: str = RUN_DUMMY_PLACEHOLDER
    notes: str = ""
    url: str = ""
    group: str = ""
    created_at: str = "2000-01-01"
    # Mutable defaults are created per instance.
    user: _DummyUser = field(default_factory=_DummyUser)
    tags: list = field(default_factory=list)
    summary: dict = field(default_factory=dict)
    config: dict = field(default_factory=dict)

    def files(self):
        """Return an empty list: a dummy run has no files."""
        return []
1462
+
1463
+
1464
def _read_ndjson(fname: str) -> Optional[pl.DataFrame]:
    """Read a newline-delimited JSON file into a DataFrame.

    Returns:
        The DataFrame, or `None` when the file does not exist or reading
        fails with one of the known "empty / unparseable" RuntimeError
        messages.

    Raises:
        RuntimeError: Any other RuntimeError from the reader is re-raised.
    """
    # No runs previously checked
    benign_messages = (
        "empty string is not a valid JSON value",
        "error parsing ndjson",
    )

    try:
        return pl.read_ndjson(fname)
    except FileNotFoundError:
        return None
    except RuntimeError as e:
        message = str(e)
        if any(marker in message for marker in benign_messages):
            return None
        raise e
1478
+
1479
+
1480
def _get_run_or_dummy_from_art(art: Artifact, api=None):
    """Return the run that logged `art`, or a `_DummyRun` stand-in.

    `art.logged_by()` is tried first. If it raises `ValueError` or returns
    `None`, the artifact's creator is queried through `api` and a
    `_DummyRun` is built from the creator's name, falling back to
    `RUN_DUMMY_PLACEHOLDER` when no name is available.

    Args:
        art: The artifact whose logging run is wanted.
        api: API object exposing `client.execute`; it is only used on the
            fallback path, where it must not be `None`.
    """
    run = None

    try:
        run = art.logged_by()
    except ValueError as e:
        # `Logger.warn` is a deprecated alias; `warning` is the supported name.
        logger.warning(
            f"Can't log artifact because run does't exist, {art=}, {run=}, {e=}"
        )

    if run is not None:
        return run

    query = gql(
        """
        query ArtifactCreatedBy(
            $id: ID!
        ) {
            artifact(id: $id) {
                createdBy {
                    ... on Run {
                        name
                        project {
                            name
                            entityName
                        }
                    }
                }
            }
        }
        """
    )
    response = api.client.execute(query, variable_values={"id": art.id})

    # A key can be present with a null value, in which case `.get(key, {})`
    # returns None; `or {}` keeps the chain safe.
    creator = (response.get("artifact") or {}).get("createdBy") or {}
    run_name = creator.get("name") or RUN_DUMMY_PLACEHOLDER

    return _DummyRun(
        entity=art.entity,
        project=art.project,
        run_id=run_name,
        id=run_name,
    )
1521
+
1522
+
1523
def _clear_fname(fname: str) -> None:
    """Back up `fname` to a `prev_`-prefixed sibling, then truncate `fname`.

    The backup is taken from the same path that is truncated, so the previous
    iteration's contents are what ends up in the `prev_` file. A missing
    `fname` is not an error: there is simply nothing to back up, and the file
    is created empty.
    """
    target = Path(fname)
    old_fname = str(target)
    new_fname = str(target.with_name(f"prev_{target.name}"))

    logger.debug(f"Moving {old_fname=} to {new_fname=}")
    try:
        shutil.copy2(old_fname, new_fname)
    except FileNotFoundError:
        # this is just to make a copy of the last iteration, so its ok if the src doesn't exist
        pass

    # Opening in "w" mode truncates (or creates) the file.
    with open(fname, "w"):
        pass
1536
+
1537
+
1538
def _download_art(art: Artifact, root: str) -> Optional[str]:
    """Download `art` into `root` with `click.echo` patched out.

    Returns:
        Whatever `art.download` returns, or `None` if anything raised (the
        error is logged rather than propagated).
    """
    try:
        with patch("click.echo"):
            downloaded_path = art.download(root=root, skip_cache=True)
    except Exception as e:
        logger.error(f"Error downloading artifact {art=}, {e=}")
        return None
    return downloaded_path
1544
+
1545
+
1546
def _clone_art(art: Artifact, root: Optional[str] = None):
    """Download `art` and build a new, unlogged artifact from its files.

    The new artifact copies the original's type, creation time, aliases and
    description, and has the downloaded directory added to it.

    Args:
        art: The artifact to clone.
        root: Download directory; defaults to `{SRC_ART_PATH}/{art.name}`.

    Raises:
        ValueError: If the download fails.
    """
    # Currently, we would only ever clone a src artifact to move it to dst.
    download_root = f"{SRC_ART_PATH}/{art.name}" if root is None else root

    path = _download_art(art, root=download_root)
    if path is None:
        raise ValueError(f"Problem downloading {art=}")

    name, _ = art.name.split(":v")

    # Hack: skip naming validation check for wandb-* types
    new_art = Artifact(name, ART_DUMMY_PLACEHOLDER_TYPE)
    new_art._type = art.type
    new_art._created_at = art.created_at
    new_art._aliases = art.aliases
    new_art._description = art.description

    with patch("click.echo"):
        new_art.add_dir(path)

    return new_art
1568
+
1569
+
1570
def _create_files_if_not_exists() -> None:
    """Ensure the four run/artifact result log files exist.

    Each file is opened in append mode and closed again, which creates it
    when missing and leaves existing contents untouched.
    """
    for fname in (
        ARTIFACT_ERRORS_FNAME,
        ARTIFACT_SUCCESSES_FNAME,
        RUN_ERRORS_FNAME,
        RUN_SUCCESSES_FNAME,
    ):
        logger.debug(f"Creating {fname=} if not exists")
        open(fname, "a").close()
1582
+
1583
+
1584
def _merge_dfs(dfs: List[pl.DataFrame]) -> pl.DataFrame:
    """Outer-join a list of DataFrames on `_step` into a single DataFrame.

    After each join, every column that gained a `<name>_right` twin has its
    nulls filled from that twin, and the twin is dropped. An empty list gives
    an empty DataFrame; a single-element list returns that element as is.
    """
    # Ensure there are DataFrames in the list
    if len(dfs) == 0:
        return pl.DataFrame()

    merged_df, *remaining = dfs
    for df in remaining:
        merged_df = merged_df.join(df, how="outer", on=["_step"])

        # Snapshot the clashing columns before the frame is modified below.
        present = merged_df.columns
        clashing = [c for c in present if f"{c}_right" in present]

        for col in clashing:
            right = f"{col}_right"
            new_col = merged_df[col].fill_null(merged_df[right])
            merged_df = merged_df.with_columns(new_col).drop(right)

    return merged_df