wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

Files changed (90)
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/importers/__init__.py +1 -4
  4. wandb/apis/importers/internals/internal.py +386 -0
  5. wandb/apis/importers/internals/protocols.py +125 -0
  6. wandb/apis/importers/internals/util.py +78 -0
  7. wandb/apis/importers/mlflow.py +125 -88
  8. wandb/apis/importers/validation.py +108 -0
  9. wandb/apis/importers/wandb.py +1604 -0
  10. wandb/apis/public/api.py +7 -10
  11. wandb/apis/public/artifacts.py +38 -0
  12. wandb/apis/public/files.py +11 -2
  13. wandb/apis/reports/v2/__init__.py +0 -19
  14. wandb/apis/reports/v2/expr_parsing.py +0 -1
  15. wandb/apis/reports/v2/interface.py +15 -18
  16. wandb/apis/reports/v2/internal.py +12 -45
  17. wandb/cli/cli.py +52 -55
  18. wandb/integration/gym/__init__.py +2 -1
  19. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  20. wandb/integration/keras/keras.py +6 -4
  21. wandb/integration/kfp/kfp_patch.py +2 -2
  22. wandb/integration/openai/fine_tuning.py +1 -2
  23. wandb/integration/ultralytics/callback.py +0 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  25. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  26. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  27. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  28. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  29. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  30. wandb/sdk/artifacts/artifact.py +75 -31
  31. wandb/sdk/artifacts/artifact_manifest.py +5 -2
  32. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  33. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
  34. wandb/sdk/artifacts/artifact_saver.py +19 -47
  35. wandb/sdk/artifacts/storage_handler.py +2 -1
  36. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
  37. wandb/sdk/artifacts/storage_policy.py +4 -1
  38. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  39. wandb/sdk/data_types/image.py +2 -2
  40. wandb/sdk/interface/interface.py +49 -13
  41. wandb/sdk/interface/interface_shared.py +17 -11
  42. wandb/sdk/internal/file_stream.py +20 -1
  43. wandb/sdk/internal/handler.py +1 -4
  44. wandb/sdk/internal/internal_api.py +3 -1
  45. wandb/sdk/internal/job_builder.py +49 -19
  46. wandb/sdk/internal/profiler.py +1 -1
  47. wandb/sdk/internal/sender.py +96 -124
  48. wandb/sdk/internal/sender_config.py +197 -0
  49. wandb/sdk/internal/settings_static.py +9 -0
  50. wandb/sdk/internal/system/system_info.py +5 -3
  51. wandb/sdk/internal/update.py +1 -1
  52. wandb/sdk/launch/_launch.py +3 -3
  53. wandb/sdk/launch/_launch_add.py +28 -29
  54. wandb/sdk/launch/_project_spec.py +148 -136
  55. wandb/sdk/launch/agent/agent.py +3 -7
  56. wandb/sdk/launch/agent/config.py +0 -27
  57. wandb/sdk/launch/builder/build.py +54 -28
  58. wandb/sdk/launch/builder/docker_builder.py +4 -15
  59. wandb/sdk/launch/builder/kaniko_builder.py +72 -45
  60. wandb/sdk/launch/create_job.py +6 -40
  61. wandb/sdk/launch/loader.py +10 -0
  62. wandb/sdk/launch/registry/anon.py +29 -0
  63. wandb/sdk/launch/registry/local_registry.py +4 -1
  64. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  65. wandb/sdk/launch/runner/local_container.py +15 -10
  66. wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
  67. wandb/sdk/launch/sweeps/scheduler.py +11 -3
  68. wandb/sdk/launch/utils.py +14 -0
  69. wandb/sdk/lib/__init__.py +2 -5
  70. wandb/sdk/lib/_settings_toposort_generated.py +4 -1
  71. wandb/sdk/lib/apikey.py +0 -5
  72. wandb/sdk/lib/config_util.py +0 -31
  73. wandb/sdk/lib/filesystem.py +11 -1
  74. wandb/sdk/lib/run_moment.py +72 -0
  75. wandb/sdk/service/service.py +7 -2
  76. wandb/sdk/service/streams.py +1 -6
  77. wandb/sdk/verify/verify.py +2 -1
  78. wandb/sdk/wandb_init.py +12 -1
  79. wandb/sdk/wandb_login.py +43 -26
  80. wandb/sdk/wandb_run.py +164 -110
  81. wandb/sdk/wandb_settings.py +58 -16
  82. wandb/testing/relay.py +5 -6
  83. wandb/util.py +50 -7
  84. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
  85. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
  86. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
  87. wandb/apis/importers/base.py +0 -400
  88. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
  89. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
wandb/apis/importers/wandb.py (new file)
@@ -0,0 +1,1604 @@
1
+ """Tooling for the W&B Importer."""
2
+ import itertools
3
+ import json
4
+ import logging
5
+ import numbers
6
+ import os
7
+ import re
8
+ import shutil
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime as dt
11
+ from pathlib import Path
12
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
13
+ from unittest.mock import patch
14
+
15
+ import filelock
16
+ import polars as pl
17
+ import requests
18
+ import urllib3
19
+ import yaml
20
+ from wandb_gql import gql
21
+
22
+ import wandb
23
+ import wandb.apis.reports as wr
24
+ from wandb.apis.public import ArtifactCollection, Run
25
+ from wandb.apis.public.files import File
26
+ from wandb.apis.reports import Report
27
+ from wandb.util import coalesce, remove_keys_with_none_values
28
+
29
+ from . import validation
30
+ from .internals import internal
31
+ from .internals.protocols import PathStr, Policy
32
+ from .internals.util import Namespace, for_each
33
+
34
+ Artifact = wandb.Artifact
35
+ Api = wandb.Api
36
+ Project = wandb.apis.public.Project
37
+
38
+ ARTIFACT_ERRORS_FNAME = "artifact_errors.jsonl"
39
+ ARTIFACT_SUCCESSES_FNAME = "artifact_successes.jsonl"
40
+ RUN_ERRORS_FNAME = "run_errors.jsonl"
41
+ RUN_SUCCESSES_FNAME = "run_successes.jsonl"
42
+
43
+ ART_SEQUENCE_DUMMY_PLACEHOLDER = "__ART_SEQUENCE_DUMMY_PLACEHOLDER__"
44
+ RUN_DUMMY_PLACEHOLDER = "__RUN_DUMMY_PLACEHOLDER__"
45
+ ART_DUMMY_PLACEHOLDER_PATH = "__importer_temp__"
46
+ ART_DUMMY_PLACEHOLDER_TYPE = "__temp__"
47
+
48
+ SRC_ART_PATH = "./artifacts/src"
49
+ DST_ART_PATH = "./artifacts/dst"
50
+
51
+
52
+ logger = logging.getLogger(__name__)
53
+ logger.setLevel(logging.INFO)
54
+
55
+ if os.getenv("WANDB_IMPORTER_ENABLE_RICH_LOGGING"):
56
+ from rich.logging import RichHandler
57
+
58
+ logger.addHandler(RichHandler(rich_tracebacks=True, tracebacks_show_locals=True))
59
+ else:
60
+ console_handler = logging.StreamHandler()
61
+ console_handler.setLevel(logging.INFO)
62
+
63
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
64
+ console_handler.setFormatter(formatter)
65
+
66
+ logger.addHandler(console_handler)
67
+
68
+
69
+ @dataclass
70
+ class ArtifactSequence:
71
+ artifacts: Iterable[wandb.Artifact]
72
+ entity: str
73
+ project: str
74
+ type_: str
75
+ name: str
76
+
77
+ def __iter__(self) -> Iterator:
78
+ return iter(self.artifacts)
79
+
80
+ def __repr__(self) -> str:
81
+ return f"ArtifactSequence({self.identifier})"
82
+
83
+ @property
84
+ def identifier(self) -> str:
85
+ return "/".join([self.entity, self.project, self.type_, self.name])
86
+
87
+ @classmethod
88
+ def from_collection(cls, collection: ArtifactCollection):
89
+ arts = collection.artifacts()
90
+ arts = sorted(arts, key=lambda a: int(a.version.lstrip("v")))
91
+ return ArtifactSequence(
92
+ arts,
93
+ collection.entity,
94
+ collection.project,
95
+ collection.type,
96
+ collection.name,
97
+ )
98
+
99
+
100
+ class WandbRun:
101
+ def __init__(
102
+ self,
103
+ run: Run,
104
+ *,
105
+ src_base_url: str,
106
+ src_api_key: str,
107
+ dst_base_url: str,
108
+ dst_api_key: str,
109
+ ) -> None:
110
+ self.run = run
111
+ self.api = wandb.Api(
112
+ api_key=src_api_key,
113
+ overrides={"base_url": src_base_url},
114
+ )
115
+ self.dst_api = wandb.Api(
116
+ api_key=dst_api_key,
117
+ overrides={"base_url": dst_base_url},
118
+ )
119
+
120
+ # For caching
121
+ self._files: Optional[Iterable[Tuple[str, str]]] = None
122
+ self._artifacts: Optional[Iterable[Artifact]] = None
123
+ self._used_artifacts: Optional[Iterable[Artifact]] = None
124
+ self._parquet_history_paths: Optional[Iterable[str]] = None
125
+
126
+ def __repr__(self) -> str:
127
+ s = os.path.join(self.entity(), self.project(), self.run_id())
128
+ return f"WandbRun({s})"
129
+
130
+ def run_id(self) -> str:
131
+ return self.run.id
132
+
133
+ def entity(self) -> str:
134
+ return self.run.entity
135
+
136
+ def project(self) -> str:
137
+ return self.run.project
138
+
139
+ def config(self) -> Dict[str, Any]:
140
+ return self.run.config
141
+
142
+ def summary(self) -> Dict[str, float]:
143
+ s = self.run.summary
144
+ return s
145
+
146
+ def metrics(self) -> Iterable[Dict[str, float]]:
147
+ if self._parquet_history_paths is None:
148
+ self._parquet_history_paths = list(self._get_parquet_history_paths())
149
+
150
+ if self._parquet_history_paths:
151
+ rows = self._get_rows_from_parquet_history_paths()
152
+ else:
153
+ logger.warn(
154
+ "No parquet files detected; using scan history (this may not be reliable)"
155
+ )
156
+ rows = self.run.scan_history()
157
+
158
+ for row in rows:
159
+ row = remove_keys_with_none_values(row)
160
+ yield row
161
+
162
+ def run_group(self) -> Optional[str]:
163
+ return self.run.group
164
+
165
+ def job_type(self) -> Optional[str]:
166
+ return self.run.job_type
167
+
168
+ def display_name(self) -> str:
169
+ return self.run.display_name
170
+
171
+ def notes(self) -> Optional[str]:
172
+ # Notes includes the previous notes and serves as a catch-all for things we missed or can't add back
173
+ previous_link = f"Imported from: {self.run.url}"
174
+ previous_author = f"Author: {self.run.user.username}"
175
+
176
+ header = [previous_link, previous_author]
177
+ previous_notes = self.run.notes or ""
178
+
179
+ return "\n".join(header) + "\n---\n" + previous_notes
180
+
181
+ def tags(self) -> Optional[List[str]]:
182
+ return self.run.tags
183
+
184
+ def artifacts(self) -> Optional[Iterable[Artifact]]:
185
+ if self._artifacts is None:
186
+ _artifacts = []
187
+ for art in self.run.logged_artifacts():
188
+ a = _clone_art(art)
189
+ _artifacts.append(a)
190
+ self._artifacts = _artifacts
191
+
192
+ yield from self._artifacts
193
+
194
+ def used_artifacts(self) -> Optional[Iterable[Artifact]]:
195
+ if self._used_artifacts is None:
196
+ _used_artifacts = []
197
+ for art in self.run.used_artifacts():
198
+ a = _clone_art(art)
199
+ _used_artifacts.append(a)
200
+ self._used_artifacts = _used_artifacts
201
+
202
+ yield from self._used_artifacts
203
+
204
+ def os_version(self) -> Optional[str]:
205
+ ... # pragma: no cover
206
+
207
+ def python_version(self) -> Optional[str]:
208
+ return self._metadata_file().get("python")
209
+
210
+ def cuda_version(self) -> Optional[str]:
211
+ ... # pragma: no cover
212
+
213
+ def program(self) -> Optional[str]:
214
+ ... # pragma: no cover
215
+
216
+ def host(self) -> Optional[str]:
217
+ return self._metadata_file().get("host")
218
+
219
+ def username(self) -> Optional[str]:
220
+ ... # pragma: no cover
221
+
222
+ def executable(self) -> Optional[str]:
223
+ ... # pragma: no cover
224
+
225
+ def gpus_used(self) -> Optional[str]:
226
+ ... # pragma: no cover
227
+
228
+ def cpus_used(self) -> Optional[int]: # can we get the model?
229
+ ... # pragma: no cover
230
+
231
+ def memory_used(self) -> Optional[int]:
232
+ ... # pragma: no cover
233
+
234
+ def runtime(self) -> Optional[int]:
235
+ wandb_runtime = self.run.summary.get("_wandb", {}).get("runtime")
236
+ base_runtime = self.run.summary.get("_runtime")
237
+
238
+ if (t := coalesce(wandb_runtime, base_runtime)) is None:
239
+ return t
240
+ return int(t)
241
+
242
+ def start_time(self) -> Optional[int]:
243
+ t = dt.fromisoformat(self.run.created_at).timestamp() * 1000
244
+ return int(t)
245
+
246
+ def code_path(self) -> Optional[str]:
247
+ path = self._metadata_file().get("codePath", "")
248
+ return f"code/{path}"
249
+
250
+ def cli_version(self) -> Optional[str]:
251
+ return self._config_file().get("_wandb", {}).get("value", {}).get("cli_version")
252
+
253
+ def files(self) -> Optional[Iterable[Tuple[PathStr, Policy]]]:
254
+ if self._files is None:
255
+ files_dir = f"{internal.ROOT_DIR}/{self.run_id()}/files"
256
+ _files = []
257
+ for f in self.run.files():
258
+ f: File
259
+ # These optimizations are intended to avoid rate limiting when importing many runs in parallel
260
+ # Don't carry over empty files
261
+ if f.size == 0:
262
+ continue
263
+ # Skip deadlist to avoid overloading S3
264
+ if "wandb_manifest.json.deadlist" in f.name:
265
+ continue
266
+
267
+ result = f.download(files_dir, exist_ok=True, api=self.api)
268
+ file_and_policy = (result.name, "end")
269
+ _files.append(file_and_policy)
270
+ self._files = _files
271
+
272
+ yield from self._files
273
+
274
+ def logs(self) -> Optional[Iterable[str]]:
275
+ if (fname := self._find_in_files("output.log")) is None:
276
+ return
277
+
278
+ with open(fname) as f:
279
+ yield from f.readlines()
280
+
281
+ def _metadata_file(self) -> Dict[str, Any]:
282
+ if (fname := self._find_in_files("wandb-metadata.json")) is None:
283
+ return {}
284
+
285
+ with open(fname) as f:
286
+ return json.loads(f.read())
287
+
288
+ def _config_file(self) -> Dict[str, Any]:
289
+ if (fname := self._find_in_files("config.yaml")) is None:
290
+ return {}
291
+
292
+ with open(fname) as f:
293
+ return yaml.safe_load(f) or {}
294
+
295
+ def _get_rows_from_parquet_history_paths(self) -> Iterable[Dict[str, Any]]:
296
+ # Unfortunately, it's not feasible to validate non-parquet history
297
+ if not (paths := self._get_parquet_history_paths()):
298
+ yield {}
299
+ return
300
+
301
+ # Collect and merge parquet history
302
+ dfs = [
303
+ pl.read_parquet(p) for path in paths for p in Path(path).glob("*.parquet")
304
+ ]
305
+ if "_step" in (df := _merge_dfs(dfs)):
306
+ df = df.with_columns(pl.col("_step").cast(pl.Int64))
307
+ yield from df.iter_rows(named=True)
308
+
309
+ def _get_parquet_history_paths(self) -> Iterable[str]:
310
+ if self._parquet_history_paths is None:
311
+ paths = []
312
+ # self.artifacts() returns a copy of the artifacts; use this to get raw
313
+ for art in self.run.logged_artifacts():
314
+ if art.type != "wandb-history":
315
+ continue
316
+ if (
317
+ path := _download_art(art, root=f"{SRC_ART_PATH}/{art.name}")
318
+ ) is None:
319
+ continue
320
+ paths.append(path)
321
+ self._parquet_history_paths = paths
322
+
323
+ yield from self._parquet_history_paths
324
+
325
+ def _find_in_files(self, name: str) -> Optional[str]:
326
+ if files := self.files():
327
+ for path, _ in files:
328
+ if name in path:
329
+ return path
330
+ return None
331
+
332
+
333
+ class WandbImporter:
334
+ """Transfers runs, reports, and artifact sequences between W&B instances."""
335
+
336
+ def __init__(
337
+ self,
338
+ src_base_url: str,
339
+ src_api_key: str,
340
+ dst_base_url: str,
341
+ dst_api_key: str,
342
+ *,
343
+ custom_api_kwargs: Optional[Dict[str, Any]] = None,
344
+ ) -> None:
345
+ self.src_base_url = src_base_url
346
+ self.src_api_key = src_api_key
347
+ self.dst_base_url = dst_base_url
348
+ self.dst_api_key = dst_api_key
349
+
350
+ if custom_api_kwargs is None:
351
+ custom_api_kwargs = {"timeout": 600}
352
+
353
+ self.src_api = wandb.Api(
354
+ api_key=src_api_key,
355
+ overrides={"base_url": src_base_url},
356
+ **custom_api_kwargs,
357
+ )
358
+ self.dst_api = wandb.Api(
359
+ api_key=dst_api_key,
360
+ overrides={"base_url": dst_base_url},
361
+ **custom_api_kwargs,
362
+ )
363
+
364
+ self.run_api_kwargs = {
365
+ "src_base_url": src_base_url,
366
+ "src_api_key": src_api_key,
367
+ "dst_base_url": dst_base_url,
368
+ "dst_api_key": dst_api_key,
369
+ }
370
+
371
+ def __repr__(self):
372
+ return f"<WandbImporter src={self.src_base_url}, dst={self.dst_base_url}>" # pragma: no cover
373
+
374
+ def _import_run(
375
+ self,
376
+ run: WandbRun,
377
+ *,
378
+ namespace: Optional[Namespace] = None,
379
+ config: Optional[internal.SendManagerConfig] = None,
380
+ ) -> None:
381
+ """Import one WandbRun.
382
+
383
+ Use `namespace` to specify alternate settings like where the run should be uploaded
384
+ """
385
+ if namespace is None:
386
+ namespace = Namespace(run.entity(), run.project())
387
+
388
+ if config is None:
389
+ config = internal.SendManagerConfig(
390
+ metadata=True,
391
+ files=True,
392
+ media=True,
393
+ code=True,
394
+ history=True,
395
+ summary=True,
396
+ terminal_output=True,
397
+ )
398
+
399
+ settings_override = {
400
+ "api_key": self.dst_api_key,
401
+ "base_url": self.dst_base_url,
402
+ "resume": "true",
403
+ "resumed": True,
404
+ }
405
+
406
+ # Send run with base config
407
+ logger.debug(f"Importing run, {run=}")
408
+ internal.send_run(
409
+ run,
410
+ overrides=namespace.send_manager_overrides,
411
+ settings_override=settings_override,
412
+ config=config,
413
+ )
414
+
415
+ if config.history:
416
+ # Send run again with history artifacts in case config history=True, artifacts=False
417
+ # The history artifact must come with the actual history data
418
+
419
+ logger.debug(f"Collecting history artifacts, {run=}")
420
+ history_arts = []
421
+ for art in run.run.logged_artifacts():
422
+ if art.type != "wandb-history":
423
+ continue
424
+ logger.debug(f"Collecting history artifact {art.name=}")
425
+ new_art = _clone_art(art)
426
+ history_arts.append(new_art)
427
+
428
+ logger.debug(f"Importing history artifacts, {run=}")
429
+ internal.send_run(
430
+ run,
431
+ extra_arts=history_arts,
432
+ overrides=namespace.send_manager_overrides,
433
+ settings_override=settings_override,
434
+ config=config,
435
+ )
436
+
437
+ def _delete_collection_in_dst(
438
+ self,
439
+ seq: ArtifactSequence,
440
+ namespace: Optional[Namespace] = None,
441
+ ):
442
+ """Deletes the equivalent artifact collection in destination.
443
+
444
+ Intended to clear the destination when an uploaded artifact does not pass validation.
445
+ """
446
+ entity = coalesce(namespace.entity, seq.entity)
447
+ project = coalesce(namespace.project, seq.project)
448
+ art_type = f"{entity}/{project}/{seq.type_}"
449
+ art_name = seq.name
450
+
451
+ logger.info(
452
+ f"Deleting collection {entity=}, {project=}, {art_type=}, {art_name=}"
453
+ )
454
+ try:
455
+ dst_collection = self.dst_api.artifact_collection(art_type, art_name)
456
+ except (wandb.CommError, ValueError):
457
+ logger.warn(f"Collection doesn't exist {art_type=}, {art_name=}")
458
+ return
459
+
460
+ try:
461
+ dst_collection.delete()
462
+ except (wandb.CommError, ValueError) as e:
463
+ logger.warn(f"Collection can't be deleted, {art_type=}, {art_name=}, {e=}")
464
+ return
465
+
466
+ def _import_artifact_sequence(
467
+ self,
468
+ seq: ArtifactSequence,
469
+ *,
470
+ namespace: Optional[Namespace] = None,
471
+ ) -> None:
472
+ """Import one artifact sequence.
473
+
474
+ Use `namespace` to specify alternate settings like where the artifact sequence should be uploaded
475
+ """
476
+ if not seq.artifacts:
477
+ # The artifact sequence has no versions. This usually means all artifact versions were deleted intentionally,
478
+ # but it can also happen if the sequence represents run history and that run was deleted.
479
+ logger.warn(f"Artifact {seq=} has no artifacts, skipping.")
480
+ return
481
+
482
+ if namespace is None:
483
+ namespace = Namespace(seq.entity, seq.project)
484
+
485
+ settings_override = {
486
+ "api_key": self.dst_api_key,
487
+ "base_url": self.dst_base_url,
488
+ "resume": "true",
489
+ "resumed": True,
490
+ }
491
+
492
+ send_manager_config = internal.SendManagerConfig(log_artifacts=True)
493
+
494
+ # Delete any existing artifact sequence, otherwise versions will be out of order
495
+ # Unfortunately, you can't delete only part of the sequence because versions are "remembered" even after deletion
496
+ self._delete_collection_in_dst(seq, namespace)
497
+
498
+ # Get a placeholder run for dummy artifacts we'll upload later
499
+ art = seq.artifacts[0]
500
+ run_or_dummy: Optional[Run] = _get_run_or_dummy_from_art(art, self.src_api)
501
+
502
+ # Each `group_of_artifacts` is either:
503
+ # 1. A single "real" artifact in a list; or
504
+ # 2. A list of dummy artifacts that are uploaded together.
505
+ # This guarantees the real artifacts have the correct version numbers while allowing for parallel upload of dummies.
506
+ groups_of_artifacts = list(_make_groups_of_artifacts(seq))
507
+ for i, group in enumerate(groups_of_artifacts, 1):
508
+ art = group[0]
509
+ if art.description == ART_SEQUENCE_DUMMY_PLACEHOLDER:
510
+ run = WandbRun(run_or_dummy, **self.run_api_kwargs)
511
+ else:
512
+ try:
513
+ wandb_run = art.logged_by()
514
+ except ValueError:
515
+ # The run used to exist but has since been deleted
516
+ wandb_run = None
518
+
519
+ # Could be logged by None (rare) or ValueError
520
+ if wandb_run is None:
521
+ logger.warn(
522
+ f"Run for {art.name=} does not exist (deleted?), using {run_or_dummy=}"
523
+ )
524
+ wandb_run = run_or_dummy
525
+
526
+ new_art = _clone_art(art)
527
+ group = [new_art]
528
+ run = WandbRun(wandb_run, **self.run_api_kwargs)
529
+
530
+ logger.info(
531
+ f"Uploading partial artifact {seq=}, {i}/{len(groups_of_artifacts)}"
532
+ )
533
+ internal.send_run(
534
+ run,
535
+ extra_arts=group,
536
+ overrides=namespace.send_manager_overrides,
537
+ settings_override=settings_override,
538
+ config=send_manager_config,
539
+ )
540
+ logger.info(f"Finished uploading {seq=}")
541
+
542
+ # query it back and remove placeholders
543
+ self._remove_placeholders(seq)
544
+
545
+ def _remove_placeholders(self, seq: ArtifactSequence) -> None:
546
+ try:
547
+ retry_arts_func = internal.exp_retry(self.dst_api.artifacts)
548
+ dst_arts = list(retry_arts_func(seq.type_, seq.name))
549
+ except wandb.CommError:
550
+ logger.warn(f"{seq=} does not exist in dst. Has it already been deleted?")
551
+ return
552
+ except TypeError as e:
553
+ logger.error(f"Problem getting dst versions (try again later) {e=}")
554
+ return
555
+
556
+ for art in dst_arts:
557
+ if art.description != ART_SEQUENCE_DUMMY_PLACEHOLDER:
558
+ continue
559
+ if art.type in ("wandb-history", "job"):
560
+ continue
561
+
562
+ try:
563
+ art.delete(delete_aliases=True)
564
+ except wandb.CommError as e:
565
+ if "cannot delete system managed artifact" in str(e):
566
+ logger.warn("Cannot delete system managed artifact")
567
+ else:
568
+ raise e
569
+
570
+ def _get_dst_art(
571
+ self, src_art: Artifact, entity: Optional[str] = None, project: Optional[str] = None
572
+ ) -> Artifact:
573
+ entity = coalesce(entity, src_art.entity)
574
+ project = coalesce(project, src_art.project)
575
+ name = src_art.name
576
+
577
+ return self.dst_api.artifact(f"{entity}/{project}/{name}")
578
+
579
+ def _get_run_problems(
580
+ self, src_run: Run, dst_run: Run, force_retry: bool = False
581
+ ) -> List[dict]:
582
+ problems = []
583
+
584
+ if force_retry:
585
+ problems.append("__force_retry__")
586
+
587
+ if non_matching_metadata := self._compare_run_metadata(src_run, dst_run):
588
+ problems.append("metadata:" + str(non_matching_metadata))
589
+
590
+ if non_matching_summary := self._compare_run_summary(src_run, dst_run):
591
+ problems.append("summary:" + str(non_matching_summary))
592
+
593
+ # TODO: Compare files?
594
+
595
+ return problems
596
+
597
+ def _compare_run_metadata(self, src_run: Run, dst_run: Run) -> dict:
598
+ fname = "wandb-metadata.json"
599
+ # problems = {}
600
+
601
+ src_f = src_run.file(fname)
602
+ if src_f.size == 0:
603
+ # the src was corrupted so no comparisons here will ever work
604
+ return {}
605
+
606
+ dst_f = dst_run.file(fname)
607
+ try:
608
+ contents = wandb.util.download_file_into_memory(
609
+ dst_f.url, self.dst_api.api_key
610
+ )
611
+ except urllib3.exceptions.ReadTimeoutError:
612
+ return {"Error checking": "Timeout"}
613
+ except requests.HTTPError as e:
614
+ if e.response.status_code == 404:
615
+ return {"Bad upload": f"File not found: {fname}"}
616
+ return {"http problem": f"{fname}: ({e})"}
617
+
618
+ dst_meta = wandb.wandb_sdk.lib.json_util.loads(contents)
619
+
620
+ non_matching = {}
621
+ if src_run.metadata:
622
+ for k, src_v in src_run.metadata.items():
623
+ if k not in dst_meta:
624
+ non_matching[k] = {"src": src_v, "dst": "KEY NOT FOUND"}
625
+ continue
626
+ dst_v = dst_meta[k]
627
+ if src_v != dst_v:
628
+ non_matching[k] = {"src": src_v, "dst": dst_v}
629
+
630
+ return non_matching
631
+
632
+ def _compare_run_summary(self, src_run: Run, dst_run: Run) -> dict:
633
+ non_matching = {}
634
+ for k, src_v in src_run.summary.items():
635
+ # These won't match between systems and that's ok
636
+ if isinstance(src_v, str) and src_v.startswith("wandb-client-artifact://"):
637
+ continue
638
+ if k in ("_wandb", "_runtime"):
639
+ continue
640
+
641
+ src_v = _recursive_cast_to_dict(src_v)
642
+
643
+ dst_v = dst_run.summary.get(k)
644
+ dst_v = _recursive_cast_to_dict(dst_v)
645
+
646
+ if isinstance(src_v, dict) and isinstance(dst_v, dict):
647
+ for kk, sv in src_v.items():
648
+ # These won't match between systems and that's ok
649
+ if isinstance(sv, str) and sv.startswith(
650
+ "wandb-client-artifact://"
651
+ ):
652
+ continue
653
+ dv = dst_v.get(kk)
654
+ if not _almost_equal(sv, dv):
655
+ non_matching[f"{k}-{kk}"] = {"src": sv, "dst": dv}
656
+ else:
657
+ if not _almost_equal(src_v, dst_v):
658
+ non_matching[k] = {"src": src_v, "dst": dst_v}
659
+
660
+ return non_matching
661
+
662
+ def _collect_failed_artifact_sequences(self) -> Iterable[ArtifactSequence]:
663
+ if (df := _read_ndjson(ARTIFACT_ERRORS_FNAME)) is None:
664
+ logger.debug(f"{ARTIFACT_ERRORS_FNAME=} is empty, returning nothing")
665
+ return
666
+
667
+ unique_failed_sequences = df[
668
+ ["src_entity", "src_project", "name", "type"]
669
+ ].unique()
670
+
671
+ for row in unique_failed_sequences.iter_rows(named=True):
672
+ entity = row["src_entity"]
673
+ project = row["src_project"]
674
+ name = row["name"]
675
+ _type = row["type"]
676
+
677
+ art_name = f"{entity}/{project}/{name}"
678
+ arts = self.src_api.artifacts(_type, art_name)
679
+ arts = sorted(arts, key=lambda a: int(a.version.lstrip("v")))
680
+ arts = sorted(arts, key=lambda a: a.type)
681
+
682
+ yield ArtifactSequence(arts, entity, project, _type, name)
683
+
684
+ def _cleanup_dummy_runs(
685
+ self,
686
+ *,
687
+ namespaces: Optional[Iterable[Namespace]] = None,
688
+ api: Optional[Api] = None,
689
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
690
+ ) -> None:
691
+ api = coalesce(api, self.dst_api)
692
+ namespaces = coalesce(namespaces, self._all_namespaces())
693
+
694
+ for ns in namespaces:
695
+ if remapping and ns in remapping:
696
+ ns = remapping[ns]
697
+
698
+ logger.debug(f"Cleaning up, {ns=}")
699
+ try:
700
+ runs = list(
701
+ api.runs(ns.path, filters={"displayName": RUN_DUMMY_PLACEHOLDER})
702
+ )
703
+ except ValueError as e:
704
+ if "Could not find project" in str(e):
705
+ logger.error("Could not find project, does it exist?")
706
+ continue
707
+
708
+ for run in runs:
709
+ logger.debug(f"Deleting dummy {run=}")
710
+ run.delete(delete_artifacts=False)
711
+
712
+ def _import_report(
713
+ self, report: Report, *, namespace: Optional[Namespace] = None
714
+ ) -> None:
715
+ """Import one wandb.Report.
716
+
717
+ Use `namespace` to specify alternate settings like where the report should be uploaded
718
+ """
719
+ if namespace is None:
720
+ namespace = Namespace(report.entity, report.project)
721
+
722
+ entity = coalesce(namespace.entity, report.entity)
723
+ project = coalesce(namespace.project, report.project)
724
+ name = report.name
725
+ title = report.title
726
+ description = report.description
727
+
728
+ api = self.dst_api
729
+
730
+ # We shouldn't need to upsert the project for every report
731
+ logger.debug(f"Upserting {entity=}/{project=}")
732
+ try:
733
+ api.create_project(project, entity)
734
+ except requests.exceptions.HTTPError as e:
735
+ if e.response.status_code != 409:
736
+ logger.warn(f"Issue upserting {entity=}/{project=}, {e=}")
737
+
738
+ logger.debug(f"Upserting report {entity=}, {project=}, {name=}, {title=}")
739
+ api.client.execute(
740
+ wr.report.UPSERT_VIEW,
741
+ variable_values={
742
+ "id": None, # Is there any benefit for this to be the same as default report?
743
+ "name": name,
744
+ "entityName": entity,
745
+ "projectName": project,
746
+ "description": description,
747
+ "displayName": title,
748
+ "type": "runs",
749
+ "spec": json.dumps(report.spec),
750
+ },
751
+ )
752
+
753
+ def _use_artifact_sequence(
754
+ self,
755
+ sequence: ArtifactSequence,
756
+ *,
757
+ namespace: Optional[Namespace] = None,
758
+ ):
759
+ if namespace is None:
760
+ namespace = Namespace(sequence.entity, sequence.project)
761
+
762
+ settings_override = {
763
+ "api_key": self.dst_api_key,
764
+ "base_url": self.dst_base_url,
765
+ "resume": "true",
766
+ "resumed": True,
767
+ }
768
+ logger.debug(f"Using artifact sequence with {settings_override=}, {namespace=}")
769
+
770
+ send_manager_config = internal.SendManagerConfig(use_artifacts=True)
771
+
772
+ for art in sequence:
773
+ if (used_by := art.used_by()) is None:
774
+ continue
775
+
776
+ for wandb_run in used_by:
777
+ run = WandbRun(wandb_run, **self.run_api_kwargs)
778
+
779
+ internal.send_run(
780
+ run,
781
+ overrides=namespace.send_manager_overrides,
782
+ settings_override=settings_override,
783
+ config=send_manager_config,
784
+ )
785
+
786
+ def import_runs(
787
+ self,
788
+ *,
789
+ namespaces: Optional[Iterable[Namespace]] = None,
790
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
791
+ parallel: bool = True,
792
+ incremental: bool = True,
793
+ max_workers: Optional[int] = None,
794
+ limit: Optional[int] = None,
795
+ metadata: bool = True,
796
+ files: bool = True,
797
+ media: bool = True,
798
+ code: bool = True,
799
+ history: bool = True,
800
+ summary: bool = True,
801
+ terminal_output: bool = True,
802
+ ):
803
+ logger.info("START: Import runs")
804
+
805
+ logger.info("Setting up for import")
806
+ _create_files_if_not_exists()
807
+ _clear_fname(RUN_ERRORS_FNAME)
808
+
809
+ logger.info("Collecting runs")
810
+ runs = list(self._collect_runs(namespaces=namespaces, limit=limit))
811
+
812
+ logger.info(f"Validating runs, {len(runs)=}")
813
+ self._validate_runs(
814
+ runs,
815
+ skip_previously_validated=incremental,
816
+ remapping=remapping,
817
+ )
818
+
819
+ logger.info("Collecting failed runs")
820
+ runs = list(self._collect_failed_runs())
821
+
822
+ logger.info(f"Importing runs, {len(runs)=}")
823
+
824
+ def _import_run_wrapped(run):
825
+ namespace = Namespace(run.entity(), run.project())
826
+ if remapping is not None and namespace in remapping:
827
+ namespace = remapping[namespace]
828
+
829
+ config = internal.SendManagerConfig(
830
+ metadata=metadata,
831
+ files=files,
832
+ media=media,
833
+ code=code,
834
+ history=history,
835
+ summary=summary,
836
+ terminal_output=terminal_output,
837
+ )
838
+
839
+ logger.debug(f"Importing {run=}, {namespace=}, {config=}")
840
+ self._import_run(run, namespace=namespace, config=config)
841
+ logger.debug(f"Finished importing {run=}, {namespace=}, {config=}")
842
+
843
+ for_each(_import_run_wrapped, runs, max_workers=max_workers, parallel=parallel)
844
+ logger.info("END: Importing runs")
845
+
846
+ def import_reports(
847
+ self,
848
+ *,
849
+ namespaces: Optional[Iterable[Namespace]] = None,
850
+ limit: Optional[int] = None,
851
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
852
+ ):
853
+ logger.info("START: Importing reports")
854
+
855
+ logger.info("Collecting reports")
856
+ reports = self._collect_reports(namespaces=namespaces, limit=limit)
857
+
858
+ logger.info("Importing reports")
859
+
860
+ def _import_report_wrapped(report):
861
+ namespace = Namespace(report.entity, report.project)
862
+ if remapping is not None and namespace in remapping:
863
+ namespace = remapping[namespace]
864
+
865
+ logger.debug(f"Importing {report=}, {namespace=}")
866
+ self._import_report(report, namespace=namespace)
867
+ logger.debug(f"Finished importing {report=}, {namespace=}")
868
+
869
+ for_each(_import_report_wrapped, reports)
870
+
871
+ logger.info("END: Importing reports")
872
+
873
+ def import_artifact_sequences(
874
+ self,
875
+ *,
876
+ namespaces: Optional[Iterable[Namespace]] = None,
877
+ incremental: bool = True,
878
+ max_workers: Optional[int] = None,
879
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
880
+ ):
881
+ """Import all artifact sequences from `namespaces`.
882
+
883
+ Note: There is a known bug with the AWS backend where artifacts > 2048MB will fail to upload. This seems to be related to multipart uploads, but we don't have a fix yet.
884
+ """
885
+ logger.info("START: Importing artifact sequences")
886
+ _clear_fname(ARTIFACT_ERRORS_FNAME)
887
+
888
+ logger.info("Collecting artifact sequences")
889
+ seqs = list(self._collect_artifact_sequences(namespaces=namespaces))
890
+
891
+ logger.info("Validating artifact sequences")
892
+ self._validate_artifact_sequences(
893
+ seqs,
894
+ incremental=incremental,
895
+ remapping=remapping,
896
+ )
897
+
898
+ logger.info("Collecting failed artifact sequences")
899
+ seqs = list(self._collect_failed_artifact_sequences())
900
+
901
+ logger.info(f"Importing artifact sequences, {len(seqs)=}")
902
+
903
+ def _import_artifact_sequence_wrapped(seq):
904
+ namespace = Namespace(seq.entity, seq.project)
905
+ if remapping is not None and namespace in remapping:
906
+ namespace = remapping[namespace]
907
+
908
+ logger.debug(f"Importing artifact sequence {seq=}, {namespace=}")
909
+ self._import_artifact_sequence(seq, namespace=namespace)
910
+ logger.debug(f"Finished importing artifact sequence {seq=}, {namespace=}")
911
+
912
+ for_each(_import_artifact_sequence_wrapped, seqs, max_workers=max_workers)
913
+
914
+ # it's safer to just use artifact on all seqs to make sure we don't miss anything
915
+ # For seqs that have already been used, this is a no-op.
916
+ logger.debug(f"Using artifact sequences, {len(seqs)=}")
917
+
918
+ def _use_artifact_sequence_wrapped(seq):
919
+ namespace = Namespace(seq.entity, seq.project)
920
+ if remapping is not None and namespace in remapping:
921
+ namespace = remapping[namespace]
922
+
923
+ logger.debug(f"Using artifact sequence {seq=}, {namespace=}")
924
+ self._use_artifact_sequence(seq, namespace=namespace)
925
+ logger.debug(f"Finished using artifact sequence {seq=}, {namespace=}")
926
+
927
+ for_each(_use_artifact_sequence_wrapped, seqs, max_workers=max_workers)
928
+
929
+ # Artifacts whose parent runs have been deleted should have that run deleted in the
930
+ # destination as well
931
+
932
+ logger.info("Cleaning up dummy runs")
933
+ self._cleanup_dummy_runs(
934
+ namespaces=namespaces,
935
+ remapping=remapping,
936
+ )
937
+
938
+ logger.info("END: Importing artifact sequences")
939
+
940
+ def import_all(
941
+ self,
942
+ *,
943
+ runs: bool = True,
944
+ artifacts: bool = True,
945
+ reports: bool = True,
946
+ namespaces: Optional[Iterable[Namespace]] = None,
947
+ incremental: bool = True,
948
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
949
+ ):
950
+ logger.info(f"START: Importing all, {runs=}, {artifacts=}, {reports=}")
951
+ if runs:
952
+ self.import_runs(
953
+ namespaces=namespaces,
954
+ incremental=incremental,
955
+ remapping=remapping,
956
+ )
957
+
958
+ if reports:
959
+ self.import_reports(
960
+ namespaces=namespaces,
961
+ remapping=remapping,
962
+ )
963
+
964
+ if artifacts:
965
+ self.import_artifact_sequences(
966
+ namespaces=namespaces,
967
+ incremental=incremental,
968
+ remapping=remapping,
969
+ )
970
+
971
+ logger.info("END: Importing all")
972
+
973
+ def _validate_run(
974
+ self,
975
+ src_run: Run,
976
+ *,
977
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
978
+ ) -> None:
979
+ namespace = Namespace(src_run.entity, src_run.project)
980
+ if remapping is not None and namespace in remapping:
981
+ namespace = remapping[namespace]
982
+
983
+ dst_entity = namespace.entity
984
+ dst_project = namespace.project
985
+ run_id = src_run.id
986
+
987
+ try:
988
+ dst_run = self.dst_api.run(f"{dst_entity}/{dst_project}/{run_id}")
989
+ except wandb.CommError:
990
+ problems = [f"run does not exist in dst at {dst_entity=}/{dst_project=}"]
991
+ else:
992
+ problems = self._get_run_problems(src_run, dst_run)
993
+
994
+ d = {
995
+ "src_entity": src_run.entity,
996
+ "src_project": src_run.project,
997
+ "dst_entity": dst_entity,
998
+ "dst_project": dst_project,
999
+ "run_id": run_id,
1000
+ }
1001
+ if problems:
1002
+ d["problems"] = problems
1003
+ fname = RUN_ERRORS_FNAME
1004
+ else:
1005
+ fname = RUN_SUCCESSES_FNAME
1006
+
1007
+ with filelock.FileLock("runs.lock"):
1008
+ with open(fname, "a") as f:
1009
+ f.write(json.dumps(d) + "\n")
1010
+
1011
+ def _filter_previously_checked_runs(
1012
+ self,
1013
+ runs: Iterable[Run],
1014
+ *,
1015
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
1016
+ ) -> Iterable[Run]:
1017
+ if (df := _read_ndjson(RUN_SUCCESSES_FNAME)) is None:
1018
+ logger.debug(f"{RUN_SUCCESSES_FNAME=} is empty, yielding all runs")
1019
+ yield from runs
1020
+ return
1021
+
1022
+ data = []
1023
+ for r in runs:
1024
+ namespace = Namespace(r.entity, r.project)
1025
+ if remapping is not None and namespace in remapping:
1026
+ namespace = remapping[namespace]
1027
+
1028
+ data.append(
1029
+ {
1030
+ "src_entity": r.entity,
1031
+ "src_project": r.project,
1032
+ "dst_entity": namespace.entity,
1033
+ "dst_project": namespace.project,
1034
+ "run_id": r.id,
1035
+ "data": r,
1036
+ }
1037
+ )
1038
+ df2 = pl.DataFrame(data)
1039
+ logger.debug(f"Starting with {len(runs)=} in namespaces")
1040
+
1041
+ results = df2.join(
1042
+ df,
1043
+ how="anti",
1044
+ on=["src_entity", "src_project", "dst_entity", "dst_project", "run_id"],
1045
+ )
1046
+ logger.debug(f"After filtering out already successful runs, {len(results)=}")
1047
+
1048
+ if not results.is_empty():
1049
+ results = results.filter(~results["run_id"].is_null())
1050
+ results = results.unique(
1051
+ ["src_entity", "src_project", "dst_entity", "dst_project", "run_id"]
1052
+ )
1053
+
1054
+ for r in results.iter_rows(named=True):
1055
+ yield r["data"]
1056
+
1057
+ def _validate_artifact(
1058
+ self,
1059
+ src_art: Artifact,
1060
+ dst_entity: str,
1061
+ dst_project: str,
1062
+ download_files_and_compare: bool = False,
1063
+ check_entries_are_downloadable: bool = True,
1064
+ ):
1065
+ problems = []
1066
+
1067
+ # These patterns of artifacts are special and should not be validated
1068
+ ignore_patterns = [
1069
+ r"^job-(.*?)\.py(:v\d+)?$",
1070
+ # r"^run-.*-history(?:\:v\d+)?$$",
1071
+ ]
1072
+ for pattern in ignore_patterns:
1073
+ if re.search(pattern, src_art.name):
1074
+ return (src_art, dst_entity, dst_project, problems)
1075
+
1076
+ try:
1077
+ dst_art = self._get_dst_art(src_art, dst_entity, dst_project)
1078
+ except Exception:
1079
+ problems.append("destination artifact not found")
1080
+ return (src_art, dst_entity, dst_project, problems)
1081
+
1082
+ try:
1083
+ logger.debug("Comparing artifact manifests")
1084
+ except Exception as e:
1085
+ problems.append(
1086
+ f"Problem getting problems! problem with {src_art.entity=}, {src_art.project=}, {src_art.name=} {e=}"
1087
+ )
1088
+ else:
1089
+ problems += validation._compare_artifact_manifests(src_art, dst_art)
1090
+
1091
+ if check_entries_are_downloadable:
1092
+ # validation._check_entries_are_downloadable(src_art)
1093
+ validation._check_entries_are_downloadable(dst_art)
1094
+
1095
+ if download_files_and_compare:
1096
+ logger.debug(f"Downloading {src_art=}")
1097
+ try:
1098
+ src_dir = _download_art(src_art, root=f"{SRC_ART_PATH}/{src_art.name}")
1099
+ except requests.HTTPError as e:
1100
+ problems.append(
1101
+ f"Invalid download link for src {src_art.entity=}, {src_art.project=}, {src_art.name=}, {e}"
1102
+ )
1103
+
1104
+ logger.debug(f"Downloading {dst_art=}")
1105
+ try:
1106
+ dst_dir = _download_art(dst_art, root=f"{DST_ART_PATH}/{dst_art.name}")
1107
+ except requests.HTTPError as e:
1108
+ problems.append(
1109
+ f"Invalid download link for dst {dst_art.entity=}, {dst_art.project=}, {dst_art.name=}, {e}"
1110
+ )
1111
+ else:
1112
+ logger.debug(f"Comparing artifact dirs {src_dir=}, {dst_dir=}")
1113
+ if problem := validation._compare_artifact_dirs(src_dir, dst_dir):
1114
+ problems.append(problem)
1115
+
1116
+ return (src_art, dst_entity, dst_project, problems)
1117
+
1118
+ def _validate_runs(
1119
+ self,
1120
+ runs: Iterable[WandbRun],
1121
+ *,
1122
+ skip_previously_validated: bool = True,
1123
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
1124
+ ):
1125
+ base_runs = [r.run for r in runs]
1126
+ if skip_previously_validated:
1127
+ base_runs = list(
1128
+ self._filter_previously_checked_runs(
1129
+ base_runs,
1130
+ remapping=remapping,
1131
+ )
1132
+ )
1133
+
1134
+ def _validate_run(run):
1135
+ logger.debug(f"Validating {run=}")
1136
+ self._validate_run(run, remapping=remapping)
1137
+ logger.debug(f"Finished validating {run=}")
1138
+
1139
+ for_each(_validate_run, base_runs)
1140
+
1141
+ def _collect_failed_runs(self):
1142
+ if (df := _read_ndjson(RUN_ERRORS_FNAME)) is None:
1143
+ logger.debug(f"{RUN_ERRORS_FNAME=} is empty, returning nothing")
1144
+ return
1145
+
1146
+ unique_failed_runs = df[
1147
+ ["src_entity", "src_project", "dst_entity", "dst_project", "run_id"]
1148
+ ].unique()
1149
+
1150
+ for row in unique_failed_runs.iter_rows(named=True):
1151
+ src_entity = row["src_entity"]
1152
+ src_project = row["src_project"]
1153
+ # dst_entity = row["dst_entity"]
1154
+ # dst_project = row["dst_project"]
1155
+ run_id = row["run_id"]
1156
+
1157
+ run = self.src_api.run(f"{src_entity}/{src_project}/{run_id}")
1158
+ yield WandbRun(run, **self.run_api_kwargs)
1159
+
1160
+ def _filter_previously_checked_artifacts(self, seqs: Iterable[ArtifactSequence]):
1161
+ if (df := _read_ndjson(ARTIFACT_SUCCESSES_FNAME)) is None:
1162
+ logger.info(
1163
+ f"{ARTIFACT_SUCCESSES_FNAME=} is empty, yielding all artifact sequences"
1164
+ )
1165
+ for seq in seqs:
1166
+ yield from seq.artifacts
1167
+ return
1168
+
1169
+ for seq in seqs:
1170
+ for art in seq:
1171
+ try:
1172
+ logged_by = _get_run_or_dummy_from_art(art, self.src_api)
1173
+ except requests.HTTPError as e:
1174
+ logger.error(f"Failed to get run, skipping: {art=}, {e=}")
1175
+ continue
1176
+
1177
+ if art.type == "wandb-history" and isinstance(logged_by, _DummyRun):
1178
+ logger.debug(f"Skipping history artifact {art=}")
1179
+ # We can never upload valid history for a deleted run, so skip it
1180
+ continue
1181
+
1182
+ entity = art.entity
1183
+ project = art.project
1184
+ _type = art.type
1185
+ name, ver = _get_art_name_ver(art)
1186
+
1187
+ filtered_df = df.filter(
1188
+ (df["src_entity"] == entity)
1189
+ & (df["src_project"] == project)
1190
+ & (df["name"] == name)
1191
+ & (df["version"] == ver)
1192
+ & (df["type"] == _type)
1193
+ )
1194
+
1195
+ # not in file, so not verified yet, don't filter out
1196
+ if len(filtered_df) == 0:
1197
+ yield art
1198
+
1199
+ def _validate_artifact_sequences(
1200
+ self,
1201
+ seqs: Iterable[ArtifactSequence],
1202
+ *,
1203
+ incremental: bool = True,
1204
+ download_files_and_compare: bool = False,
1205
+ check_entries_are_downloadable: bool = True,
1206
+ remapping: Optional[Dict[Namespace, Namespace]] = None,
1207
+ ):
1208
+ if incremental:
1209
+ logger.info("Validating in incremental mode")
1210
+
1211
+ def filtered_sequences():
1212
+ for seq in seqs:
1213
+ if not seq.artifacts:
1214
+ continue
1215
+
1216
+ art = seq.artifacts[0]
1217
+ try:
1218
+ logged_by = _get_run_or_dummy_from_art(art, self.src_api)
1219
+ except requests.HTTPError as e:
1220
+ logger.error(
1221
+ f"Validate Artifact http error: {art.entity=}, {art.project=}, {art.name=}, {e=}"
1222
+ )
1223
+ continue
1224
+
1225
+ if art.type == "wandb-history" and isinstance(logged_by, _DummyRun):
1226
+ # We can never upload valid history for a deleted run, so skip it
1227
+ continue
1228
+
1229
+ yield seq
1230
+
1231
+ artifacts = self._filter_previously_checked_artifacts(filtered_sequences())
1232
+ else:
1233
+ logger.info("Validating in non-incremental mode")
1234
+ artifacts = [art for seq in seqs for art in seq.artifacts]
1235
+
1236
+ def _validate_artifact_wrapped(args):
1237
+ art, entity, project = args
1238
+ if (
1239
+ remapping is not None
1240
+ and (namespace := Namespace(entity, project)) in remapping
1241
+ ):
1242
+ remapped_ns = remapping[namespace]
1243
+ entity = remapped_ns.entity
1244
+ project = remapped_ns.project
1245
+
1246
+ logger.debug(f"Validating {art=}, {entity=}, {project=}")
1247
+ result = self._validate_artifact(
1248
+ art,
1249
+ entity,
1250
+ project,
1251
+ download_files_and_compare=download_files_and_compare,
1252
+ check_entries_are_downloadable=check_entries_are_downloadable,
1253
+ )
1254
+ logger.debug(f"Finished validating {art=}, {entity=}, {project=}")
1255
+ return result
1256
+
1257
+ args = ((art, art.entity, art.project) for art in artifacts)
1258
+ art_problems = for_each(_validate_artifact_wrapped, args)
1259
+ for art, dst_entity, dst_project, problems in art_problems:
1260
+ name, ver = _get_art_name_ver(art)
1261
+ d = {
1262
+ "src_entity": art.entity,
1263
+ "src_project": art.project,
1264
+ "dst_entity": dst_entity,
1265
+ "dst_project": dst_project,
1266
+ "name": name,
1267
+ "version": ver,
1268
+ "type": art.type,
1269
+ }
1270
+
1271
+ if problems:
1272
+ d["problems"] = problems
1273
+ fname = ARTIFACT_ERRORS_FNAME
1274
+ else:
1275
+ fname = ARTIFACT_SUCCESSES_FNAME
1276
+
1277
+ with open(fname, "a") as f:
1278
+ f.write(json.dumps(d) + "\n")
1279
+
1280
+ def _collect_runs(
1281
+ self,
1282
+ *,
1283
+ namespaces: Optional[Iterable[Namespace]] = None,
1284
+ limit: Optional[int] = None,
1285
+ skip_ids: Optional[List[str]] = None,
1286
+ start_date: Optional[str] = None,
1287
+ api: Optional[Api] = None,
1288
+ ) -> Iterable[WandbRun]:
1289
+ api = coalesce(api, self.src_api)
1290
+ namespaces = coalesce(namespaces, self._all_namespaces())
1291
+
1292
+ filters: Dict[str, Any] = {}
1293
+ if skip_ids is not None:
1294
+ filters["name"] = {"$nin": skip_ids}
1295
+ if start_date is not None:
1296
+ filters["createdAt"] = {"$gte": start_date}
1297
+
1298
+ def _runs():
1299
+ for ns in namespaces:
1300
+ logger.debug(f"Collecting runs from {ns=}")
1301
+ for run in api.runs(ns.path, filters=filters):
1302
+ yield WandbRun(run, **self.run_api_kwargs)
1303
+
1304
+ runs = itertools.islice(_runs(), limit)
1305
+ yield from runs
1306
+
1307
+ def _all_namespaces(
1308
+ self, *, entity: Optional[str] = None, api: Optional[Api] = None
1309
+ ):
1310
+ api = coalesce(api, self.src_api)
1311
+ entity = coalesce(entity, api.default_entity)
1312
+ projects = api.projects(entity)
1313
+ for p in projects:
1314
+ yield Namespace(p.entity, p.name)
1315
+
1316
+ def _collect_reports(
1317
+ self,
1318
+ *,
1319
+ namespaces: Optional[Iterable[Namespace]] = None,
1320
+ limit: Optional[int] = None,
1321
+ api: Optional[Api] = None,
1322
+ ):
1323
+ api = coalesce(api, self.src_api)
1324
+ namespaces = coalesce(namespaces, self._all_namespaces())
1325
+
1326
+ wandb.login(key=self.src_api_key, host=self.src_base_url)
1327
+
1328
+ def reports():
1329
+ for ns in namespaces:
1330
+ for r in api.reports(ns.path):
1331
+ yield wr.Report.from_url(r.url, api=api)
1332
+
1333
+ yield from itertools.islice(reports(), limit)
1334
+
1335
+ def _collect_artifact_sequences(
1336
+ self,
1337
+ *,
1338
+ namespaces: Optional[Iterable[Namespace]] = None,
1339
+ limit: Optional[int] = None,
1340
+ api: Optional[Api] = None,
1341
+ ):
1342
+ api = coalesce(api, self.src_api)
1343
+ namespaces = coalesce(namespaces, self._all_namespaces())
1344
+
1345
+ def artifact_sequences():
1346
+ for ns in namespaces:
1347
+ logger.debug(f"Collecting artifact sequences from {ns=}")
1348
+ types = []
1349
+ try:
1350
+ types = [t for t in api.artifact_types(ns.path)]
1351
+ except Exception as e:
1352
+ logger.error(f"Failed to get artifact types {e=}")
1353
+
1354
+ for t in types:
1355
+ collections = []
1356
+
1357
+ # Skip history because it's really for run history
1358
+ if t.name == "wandb-history":
1359
+ continue
1360
+
1361
+ try:
1362
+ collections = t.collections()
1363
+ except Exception as e:
1364
+ logger.error(f"Failed to get artifact collections {e=}")
1365
+
1366
+ for c in collections:
1367
+ if c.is_sequence():
1368
+ yield ArtifactSequence.from_collection(c)
1369
+
1370
+ seqs = itertools.islice(artifact_sequences(), limit)
1371
+ unique_sequences = {seq.identifier: seq for seq in seqs}
1372
+ yield from unique_sequences.values()
1373
+
1374
+
1375
+ def _get_art_name_ver(art: Artifact) -> Tuple[str, int]:
1376
+ name, ver = art.name.split(":v")
1377
+ return name, int(ver)
1378
+
1379
+
1380
+ def _make_dummy_art(name: str, _type: str, ver: int):
1381
+ art = Artifact(name, ART_DUMMY_PLACEHOLDER_TYPE)
1382
+ art._type = _type
1383
+ art._description = ART_SEQUENCE_DUMMY_PLACEHOLDER
1384
+
1385
+ p = Path(ART_DUMMY_PLACEHOLDER_PATH)
1386
+ p.mkdir(parents=True, exist_ok=True)
1387
+
1388
+ # dummy file with different name to prevent dedupe
1389
+ fname = p / str(ver)
1390
+ with open(fname, "w"):
1391
+ pass
1392
+ art.add_file(fname)
1393
+
1394
+ return art
1395
+
1396
+
1397
+ def _make_groups_of_artifacts(seq: ArtifactSequence, start: int = 0):
1398
+ prev_ver = start - 1
1399
+ for art in seq:
1400
+ name, ver = _get_art_name_ver(art)
1401
+
1402
+ # If there's a gap between versions, fill with dummy artifacts
1403
+ if ver - prev_ver > 1:
1404
+ yield [_make_dummy_art(name, art.type, v) for v in range(prev_ver + 1, ver)]
1405
+
1406
+ # Then yield the actual artifact
1407
+ # Must always be a list of one artifact to guarantee ordering
1408
+ yield [art]
1409
+ prev_ver = ver
1410
+
1411
+
1412
+ def _recursive_cast_to_dict(obj):
1413
+ if isinstance(obj, list):
1414
+ return [_recursive_cast_to_dict(item) for item in obj]
1415
+ elif isinstance(obj, dict) or hasattr(obj, "items"):
1416
+ new_dict = {}
1417
+ for key, value in obj.items():
1418
+ new_dict[key] = _recursive_cast_to_dict(value)
1419
+ return new_dict
1420
+ else:
1421
+ return obj
1422
+
1423
+
1424
+ def _almost_equal(x, y, eps=1e-6):
1425
+ if isinstance(x, dict) and isinstance(y, dict):
1426
+ if x.keys() != y.keys():
1427
+ return False
1428
+ return all(_almost_equal(x[k], y[k], eps) for k in x)
1429
+
1430
+ if isinstance(x, numbers.Number) and isinstance(y, numbers.Number):
1431
+ return abs(x - y) < eps
1432
+
1433
+ if type(x) is not type(y):
1434
+ return False
1435
+
1436
+ return x == y
1437
+
1438
+
1439
+ @dataclass
1440
+ class _DummyUser:
1441
+ username: str = ""
1442
+
1443
+
1444
+ @dataclass
1445
+ class _DummyRun:
1446
+ entity: str = ""
1447
+ project: str = ""
1448
+ run_id: str = RUN_DUMMY_PLACEHOLDER
1449
+ id: str = RUN_DUMMY_PLACEHOLDER
1450
+ display_name: str = RUN_DUMMY_PLACEHOLDER
1451
+ notes: str = ""
1452
+ url: str = ""
1453
+ group: str = ""
1454
+ created_at: str = "2000-01-01"
1455
+ user: _DummyUser = field(default_factory=_DummyUser)
1456
+ tags: list = field(default_factory=list)
1457
+ summary: dict = field(default_factory=dict)
1458
+ config: dict = field(default_factory=dict)
1459
+
1460
+ def files(self):
1461
+ return []
1462
+
1463
+
1464
+ def _read_ndjson(fname: str) -> Optional[pl.DataFrame]:
1465
+ try:
1466
+ df = pl.read_ndjson(fname)
1467
+ except FileNotFoundError:
1468
+ return None
1469
+ except RuntimeError as e:
1470
+ # No runs previously checked
1471
+ if "empty string is not a valid JSON value" in str(e):
1472
+ return None
1473
+ if "error parsing ndjson" in str(e):
1474
+ return None
1475
+ raise e
1476
+
1477
+ return df
1478
+
1479
+
1480
+ def _get_run_or_dummy_from_art(art: Artifact, api=None):
1481
+ run = None
1482
+
1483
+ try:
1484
+ run = art.logged_by()
1485
+ except ValueError as e:
1486
+ logger.warn(
1487
+ f"Can't log artifact because run does't exist, {art=}, {run=}, {e=}"
1488
+ )
1489
+
1490
+ if run is not None:
1491
+ return run
1492
+
1493
+ query = gql(
1494
+ """
1495
+ query ArtifactCreatedBy(
1496
+ $id: ID!
1497
+ ) {
1498
+ artifact(id: $id) {
1499
+ createdBy {
1500
+ ... on Run {
1501
+ name
1502
+ project {
1503
+ name
1504
+ entityName
1505
+ }
1506
+ }
1507
+ }
1508
+ }
1509
+ }
1510
+ """
1511
+ )
1512
+ response = api.client.execute(query, variable_values={"id": art.id})
1513
+ creator = response.get("artifact", {}).get("createdBy", {})
1514
+ run = _DummyRun(
1515
+ entity=art.entity,
1516
+ project=art.project,
1517
+ run_id=creator.get("name", RUN_DUMMY_PLACEHOLDER),
1518
+ id=creator.get("name", RUN_DUMMY_PLACEHOLDER),
1519
+ )
1520
+ return run
1521
+
1522
+
1523
+ def _clear_fname(fname: str) -> None:
1524
+ old_fname = f"{internal.ROOT_DIR}/{fname}"
1525
+ new_fname = f"{internal.ROOT_DIR}/prev_{fname}"
1526
+
1527
+ logger.debug(f"Moving {old_fname=} to {new_fname=}")
1528
+ try:
1529
+ shutil.copy2(old_fname, new_fname)
1530
+ except FileNotFoundError:
1531
+ # this is just to make a copy of the last iteration, so it's ok if the src doesn't exist
1532
+ pass
1533
+
1534
+ with open(fname, "w"):
1535
+ pass
1536
+
1537
+
1538
+ def _download_art(art: Artifact, root: str) -> Optional[str]:
1539
+ try:
1540
+ with patch("click.echo"):
1541
+ return art.download(root=root, skip_cache=True)
1542
+ except Exception as e:
1543
+ logger.error(f"Error downloading artifact {art=}, {e=}")
1544
+
1545
+
1546
+ def _clone_art(art: Artifact, root: Optional[str] = None):
1547
+ if root is None:
1548
+ # Currently, we would only ever clone a src artifact to move it to dst.
1549
+ root = f"{SRC_ART_PATH}/{art.name}"
1550
+
1551
+ if (path := _download_art(art, root=root)) is None:
1552
+ raise ValueError(f"Problem downloading {art=}")
1553
+
1554
+ name, _ = art.name.split(":v")
1555
+
1556
+ # Hack: skip naming validation check for wandb-* types
1557
+ new_art = Artifact(name, ART_DUMMY_PLACEHOLDER_TYPE)
1558
+ new_art._type = art.type
1559
+ new_art._created_at = art.created_at
1560
+
1561
+ new_art._aliases = art.aliases
1562
+ new_art._description = art.description
1563
+
1564
+ with patch("click.echo"):
1565
+ new_art.add_dir(path)
1566
+
1567
+ return new_art
1568
+
1569
+
1570
+ def _create_files_if_not_exists() -> None:
1571
+ fnames = [
1572
+ ARTIFACT_ERRORS_FNAME,
1573
+ ARTIFACT_SUCCESSES_FNAME,
1574
+ RUN_ERRORS_FNAME,
1575
+ RUN_SUCCESSES_FNAME,
1576
+ ]
1577
+
1578
+ for fname in fnames:
1579
+ logger.debug(f"Creating {fname=} if not exists")
1580
+ with open(fname, "a"):
1581
+ pass
1582
+
1583
+
1584
+ def _merge_dfs(dfs: List[pl.DataFrame]) -> pl.DataFrame:
1585
+ # Ensure there are DataFrames in the list
1586
+ if len(dfs) == 0:
1587
+ return pl.DataFrame()
1588
+
1589
+ if len(dfs) == 1:
1590
+ return dfs[0]
1591
+
1592
+ merged_df = dfs[0]
1593
+ for df in dfs[1:]:
1594
+ merged_df = merged_df.join(df, how="outer", on=["_step"])
1595
+ col_pairs = [
1596
+ (c, f"{c}_right")
1597
+ for c in merged_df.columns
1598
+ if f"{c}_right" in merged_df.columns
1599
+ ]
1600
+ for col, right in col_pairs:
1601
+ new_col = merged_df[col].fill_null(merged_df[right])
1602
+ merged_df = merged_df.with_columns(new_col).drop(right)
1603
+
1604
+ return merged_df
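
For orientation only: a minimal usage sketch of the `WandbImporter` added by this file, pieced together from the signatures shown in the diff above. The import paths, hostnames, and entity/project names below are assumptions for illustration, not documented API.

from wandb.apis.importers.internals.util import Namespace  # path assumed from this diff
from wandb.apis.importers.wandb import WandbImporter        # path assumed from this diff

importer = WandbImporter(
    src_base_url="https://wandb.source-instance.example",   # assumption: your source host
    src_api_key="SRC_API_KEY",
    dst_base_url="https://api.wandb.ai",                     # assumption: your destination host
    dst_api_key="DST_API_KEY",
)

# Copy runs, reports, and artifact sequences for one namespace,
# remapping them into a different destination entity/project.
src = Namespace("src-entity", "src-project")
dst = Namespace("dst-entity", "dst-project")
importer.import_all(namespaces=[src], remapping={src: dst})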