wandb 0.13.11__py3-none-any.whl → 0.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. wandb/__init__.py +1 -1
  2. wandb/apis/importers/__init__.py +4 -0
  3. wandb/apis/importers/base.py +312 -0
  4. wandb/apis/importers/mlflow.py +113 -0
  5. wandb/apis/internal.py +9 -0
  6. wandb/apis/public.py +0 -2
  7. wandb/cli/cli.py +100 -72
  8. wandb/docker/__init__.py +33 -5
  9. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  10. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  11. wandb/sdk/internal/internal_api.py +85 -9
  12. wandb/sdk/launch/_project_spec.py +45 -55
  13. wandb/sdk/launch/agent/agent.py +80 -18
  14. wandb/sdk/launch/builder/build.py +16 -74
  15. wandb/sdk/launch/builder/docker_builder.py +36 -8
  16. wandb/sdk/launch/builder/kaniko_builder.py +78 -37
  17. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +68 -18
  18. wandb/sdk/launch/environment/aws_environment.py +4 -0
  19. wandb/sdk/launch/launch.py +1 -6
  20. wandb/sdk/launch/launch_add.py +0 -5
  21. wandb/sdk/launch/registry/abstract.py +12 -0
  22. wandb/sdk/launch/registry/elastic_container_registry.py +31 -1
  23. wandb/sdk/launch/registry/google_artifact_registry.py +32 -0
  24. wandb/sdk/launch/registry/local_registry.py +15 -1
  25. wandb/sdk/launch/runner/abstract.py +0 -14
  26. wandb/sdk/launch/runner/kubernetes_runner.py +25 -19
  27. wandb/sdk/launch/runner/local_container.py +7 -8
  28. wandb/sdk/launch/runner/local_process.py +0 -3
  29. wandb/sdk/launch/runner/sagemaker_runner.py +0 -3
  30. wandb/sdk/launch/runner/vertex_runner.py +0 -2
  31. wandb/sdk/launch/sweeps/scheduler.py +39 -10
  32. wandb/sdk/launch/utils.py +52 -4
  33. wandb/sdk/wandb_run.py +3 -10
  34. wandb/sync/sync.py +1 -0
  35. wandb/util.py +1 -0
  36. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/METADATA +1 -1
  37. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/RECORD +41 -38
  38. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
  39. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
  40. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
  41. {wandb-0.13.11.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
wandb/__init__.py CHANGED
@@ -11,7 +11,7 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples.
11
11
 
12
12
  For reference documentation, see https://docs.wandb.com/ref/python.
13
13
  """
14
- __version__ = "0.13.11"
14
+ __version__ = "0.14.0"
15
15
 
16
16
  # Used with pypi checks and other messages related to pip
17
17
  _wandb_module = "wandb"
@@ -0,0 +1,4 @@
1
+ from wandb.util import get_module
2
+
3
+ if get_module("mlflow"):
4
+ from .mlflow import MlflowImporter, MlflowRun # noqa: F401
@@ -0,0 +1,312 @@
1
+ import json
2
+ import platform
3
+ from abc import ABC, abstractmethod
4
+ from concurrent.futures import ProcessPoolExecutor, as_completed
5
+ from contextlib import contextmanager
6
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
7
+
8
+ from tqdm.auto import tqdm
9
+
10
+ import wandb
11
+ from wandb.proto import wandb_internal_pb2 as pb
12
+ from wandb.proto import wandb_telemetry_pb2 as telem_pb
13
+ from wandb.sdk.interface.interface import file_policy_to_enum
14
+ from wandb.sdk.interface.interface_queue import InterfaceQueue
15
+ from wandb.sdk.internal.sender import SendManager
16
+
17
# Readability aliases used in type hints for (name, path) artifact pairs.
Name = str
Path = str


def coalesce(*arg: Any) -> Any:
    """Return the first argument that is not ``None``.

    Works like C#'s ``??`` operator chained over all arguments. Returns
    ``None`` when every argument is ``None`` or when no arguments are given.
    """
    for candidate in arg:
        if candidate is not None:
            return candidate
    return None

@contextmanager
def send_manager(root_dir):
    """Yield a ``SendManager`` rooted at ``root_dir``; flush and finish it on exit.

    Whether the ``with`` body completes or raises, every record the manager
    still reports as pending is sent, and then ``finish()`` is called.
    """
    manager = SendManager.setup(root_dir, resume=False)
    try:
        yield manager
    finally:
        # Flush remaining records. This relies on SendManager's truthiness
        # (assumed to mean "records pending") and on ``next()`` returning the
        # next queued record.
        while manager:
            manager.send(next(manager))
        manager.finish()

class ImporterRun:
    """One run from an external tracking system, to be replayed into W&B.

    Subclasses override the accessor methods (``run_id``, ``entity``,
    ``config``, ``metrics``, ...) to expose the source run's data. An accessor
    that returns ``None`` means "not available". The ``_make_*`` helpers turn
    the accessors' values into W&B protobuf records for ``Importer`` to send.
    """

    def __init__(self) -> None:
        self.interface = InterfaceQueue()
        # The staging directory is named after the run id, so subclasses must
        # set any state that ``run_id()`` reads *before* calling
        # ``super().__init__()``.
        self.run_dir = f"./wandb-importer/{self.run_id()}"

    def run_id(self) -> str:
        """Return the run id, autogenerating one (once) when none is provided.

        The generated id is cached on the instance. Previously a new id was
        generated on every call, so the staging directory, run record,
        display name and artifact record each ended up with a different id.
        """
        _id = getattr(self, "_generated_run_id", None)
        if _id is None:
            _id = wandb.util.generate_id()
            self._generated_run_id = _id
            wandb.termwarn(f"`run_id` not specified. Autogenerating id: {_id}")
        return _id

    def entity(self) -> str:
        _entity = "unspecified-entity"
        wandb.termwarn(f"`entity` not specified. Defaulting to: {_entity}")
        return _entity

    def project(self) -> str:
        _project = "unspecified-project"
        wandb.termwarn(f"`project` not specified. Defaulting to: {_project}")
        return _project

    def config(self) -> Dict[str, Any]:
        return {}

    def summary(self) -> Dict[str, float]:
        return {}

    def metrics(self) -> Iterable[Dict[str, float]]:
        """Metrics for the run.

        We expect metrics in this shape:

        [
            {'metric1': 1, 'metric2': 1, '_step': 0},
            {'metric1': 2, 'metric2': 4, '_step': 1},
            {'metric1': 3, 'metric2': 9, '_step': 2},
            ...
        ]

        You can also submit metrics in this shape:
        [
            {'metric1': 1, '_step': 0},
            {'metric2': 1, '_step': 0},
            {'metric1': 2, '_step': 1},
            {'metric2': 4, '_step': 1},
            ...
        ]
        """
        return []

    # --- Optional metadata hooks ------------------------------------------
    # Subclasses override the hooks their source system can answer. The
    # default ``None`` means "unknown": the field is then skipped or filled
    # with a placeholder by the ``_make_*`` helpers below.

    def run_group(self) -> Optional[str]:
        return None

    def job_type(self) -> Optional[str]:
        return None

    def display_name(self) -> str:
        return self.run_id()

    def notes(self) -> Optional[str]:
        return None

    def tags(self) -> Optional[List[str]]:
        return None

    def artifacts(self) -> Optional[Iterable[Tuple[Name, Path]]]:
        return None

    def os_version(self) -> Optional[str]:
        return None

    def python_version(self) -> Optional[str]:
        return None

    def cuda_version(self) -> Optional[str]:
        return None

    def program(self) -> Optional[str]:
        return None

    def host(self) -> Optional[str]:
        return None

    def username(self) -> Optional[str]:
        return None

    def executable(self) -> Optional[str]:
        return None

    def gpus_used(self) -> Optional[str]:
        return None

    def cpus_used(self) -> Optional[int]:  # can we get the model?
        return None

    def memory_used(self) -> Optional[int]:
        return None

    def runtime(self) -> Optional[int]:
        return None

    def start_time(self) -> Optional[int]:
        return None

    # --- Record builders ----------------------------------------------------

    def _make_run_record(self) -> pb.Record:
        """Build the run record: identity, display name, notes, tags, group, config."""
        run = pb.RunRecord()
        run.run_id = self.run_id()
        run.entity = self.entity()
        run.project = self.project()
        # Fall back to the run id so a missing display name never assigns
        # ``None`` to a protobuf string field.
        run.display_name = coalesce(self.display_name(), run.run_id)
        run.notes = coalesce(self.notes(), "")
        run.tags.extend(coalesce(self.tags(), []))
        # start_time is not set on the run record. Runtime is sent through the
        # summary instead (see _make_summary_record).
        run_group = self.run_group()
        if run_group is not None:
            run.run_group = run_group
        self.interface._make_config(
            data=self.config(),
            obj=run.config,
        )  # is there a better way?
        return self.interface._make_record(run=run)

    def _make_summary_record(self) -> pb.Record:
        """Build the summary record from ``summary()`` plus ``_runtime``."""
        d: dict = {
            **self.summary(),
            "_runtime": self.runtime(),  # quirk of runtime -- it has to be here!
        }
        summary = self.interface._make_summary_from_dict(d)
        return self.interface._make_record(summary=summary)

    def _make_history_records(self) -> Iterable[pb.Record]:
        """Yield one history record per row returned by ``metrics()``.

        Each value is JSON-encoded into the record's ``value_json``.
        """
        for metrics in self.metrics():
            history = pb.HistoryRecord()
            for key, value in metrics.items():
                item = history.item.add()
                item.key = key
                item.value_json = json.dumps(value)
            yield self.interface._make_record(history=history)

    def _make_files_record(self, files_dict) -> pb.Record:
        """Build a files record from ``{"files": [[path, policy], ...]}``."""
        # when making the metadata file, it captures most things correctly
        # but notably it doesn't capture the start time!
        files_record = pb.FilesRecord()
        for path, policy in files_dict["files"]:
            f = files_record.files.add()
            f.path = path
            f.policy = file_policy_to_enum(policy)  # is this always "end"?
        return self.interface._make_record(files=files_record)

    def _make_metadata_files_record(self) -> pb.Record:
        """Write ``wandb-metadata.json`` and return a files record pointing at it."""
        self._make_metadata_file(self.run_dir)
        return self._make_files_record(
            {"files": [[f"{self.run_dir}/files/wandb-metadata.json", "end"]]}
        )

    def _make_artifact_record(self) -> pb.Record:
        """Bundle every ``(name, path)`` from ``artifacts()`` into one artifact record."""
        art = wandb.Artifact(
            coalesce(self.display_name(), self.run_id()), "imported-artifacts"
        )
        artifacts = self.artifacts()
        if artifacts is not None:
            for name, path in artifacts:
                art.add_file(path, name)
        proto = self.interface._make_artifact(art)
        proto.run_id = self.run_id()
        proto.project = self.project()
        proto.entity = self.entity()
        proto.user_created = False
        proto.use_after_commit = False
        proto.finalize = True
        proto.aliases.extend(["latest", "imported"])
        return self.interface._make_record(artifact=proto)

    def _make_telem_record(self) -> pb.Record:
        """Build a telemetry record for this import.

        Note: this always sets the ``importer_mlflow`` feature flag, even for
        ``ImporterRun`` subclasses that are not MLflow-based.
        """
        feature = telem_pb.Feature()
        feature.importer_mlflow = True

        telem = telem_pb.TelemetryRecord()
        telem.feature.CopyFrom(feature)
        telem.python_version = platform.python_version()  # importer's python version
        telem.cli_version = wandb.__version__
        return self.interface._make_record(telemetry=telem)

    def _make_metadata_file(self, run_dir: str) -> None:
        """Write ``<run_dir>/files/wandb-metadata.json`` from the metadata hooks.

        ``os``, ``python`` and ``program`` are always written, using a
        placeholder when unknown. The other keys are written only when their
        hook returns a value. Each hook is evaluated exactly once, and the
        ``files`` directory is created if it does not exist yet.
        """
        import os

        missing_text = "MLFlow did not capture this info."

        d: Dict[str, Any] = {}
        for key, getter in (
            ("os", self.os_version),
            ("python", self.python_version),
            ("program", self.program),
        ):
            d[key] = coalesce(getter(), missing_text)

        for key, getter in (
            ("cuda", self.cuda_version),
            ("host", self.host),
            ("username", self.username),
            ("executable", self.executable),
        ):
            value = getter()
            if value is not None:
                d[key] = value

        gpus_used = self.gpus_used()
        if gpus_used is not None:
            d["gpu_devices"] = json.dumps(gpus_used)
            d["gpu_count"] = json.dumps(len(gpus_used))
        cpus_used = self.cpus_used()
        if cpus_used is not None:
            d["cpu_count"] = json.dumps(cpus_used)
        mem_used = self.memory_used()
        if mem_used is not None:
            d["memory"] = json.dumps({"total": mem_used})

        os.makedirs(f"{run_dir}/files", exist_ok=True)
        with open(f"{run_dir}/files/wandb-metadata.json", "w") as f:
            f.write(json.dumps(d))

class Importer(ABC):
    """Base class for bulk-importing runs from another system into W&B.

    Subclasses implement ``download_all_runs``. The ``import_*`` methods then
    turn each ``ImporterRun`` into W&B records and send them.
    """

    @abstractmethod
    def download_all_runs(self) -> Iterable[ImporterRun]:
        """Return (or yield) every run to import."""
        ...

    def import_all(self, overrides: Optional[Dict[str, Any]] = None) -> None:
        """Import every run sequentially, showing a progress bar."""
        for run in tqdm(self.download_all_runs(), desc="Sending runs"):
            self.import_one(run, overrides)

    def import_all_parallel(
        self, overrides: Optional[Dict[str, Any]] = None, **pool_kwargs: Any
    ) -> None:
        """Import every run concurrently in a process pool.

        ``pool_kwargs`` are forwarded to ``ProcessPoolExecutor``. When a run
        fails to import, the error is reported with ``wandb.termerror`` and the
        remaining runs continue. Previously each future's result was never
        retrieved, so worker exceptions were silently lost.
        """
        runs = list(self.download_all_runs())
        with tqdm(total=len(runs)) as pbar:
            with ProcessPoolExecutor(**pool_kwargs) as exc:
                futures = {
                    exc.submit(self.import_one, run, overrides=overrides): run
                    for run in runs
                }
                for future in as_completed(futures):
                    run = futures[future]
                    try:
                        future.result()
                    except Exception as e:  # report and keep importing the rest
                        wandb.termerror(
                            f"Failed to import run {run.display_name()}: {e}"
                        )
                    pbar.update(1)
                    pbar.set_description(
                        f"Imported Run: {run.run_group()} {run.display_name()}"
                    )

    def import_one(
        self,
        run: ImporterRun,
        overrides: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Import a single run, optionally overriding some of its accessors.

        Each ``overrides`` key names an accessor method on ``run`` (for example
        ``"project"``). That method is replaced on this instance by a function
        returning the given value. Note: ``run.run_dir`` was already computed
        in the constructor, so overriding ``run_id`` does not rename it.
        """
        # does this need to be here for pmap?
        if overrides:
            for k, v in overrides.items():
                # Bind ``v`` as a default argument so each override keeps its
                # own value; a bare ``lambda: v`` would late-bind to the loop
                # variable and return the last value for every key.
                setattr(run, k, lambda v=v: v)
        self._import_one(run)

    def _import_one(self, run: ImporterRun) -> None:
        """Send the run, summary, metadata, history, artifact and telemetry records."""
        with send_manager(run.run_dir) as sm:
            sm.send(run._make_run_record())
            sm.send(run._make_summary_record())
            sm.send(run._make_metadata_files_record())
            for history_record in run._make_history_records():
                sm.send(history_record)
            if run.artifacts() is not None:
                sm.send(run._make_artifact_record())
            sm.send(run._make_telem_record())
@@ -0,0 +1,113 @@
1
+ from typing import Any, Dict, Iterable, Optional
2
+
3
+ from wandb.util import get_module
4
+
5
+ from .base import Importer, ImporterRun
6
+
7
+ mlflow = get_module(
8
+ "mlflow",
9
+ required="To use the MlflowImporter, please install mlflow: `pip install mlflow`",
10
+ )
11
+
12
+
13
class MlflowRun(ImporterRun):
    """Expose one MLflow run through the ``ImporterRun`` interface.

    Args:
        run: an MLflow run entity, as yielded by ``MlflowClient.search_runs``.
        mlflow_client: the ``MlflowClient`` used to fetch metric histories
            and artifact listings for ``run``.
    """

    def __init__(self, run, mlflow_client):
        # Set before ``super().__init__()``: the base constructor calls
        # ``run_id()``, which reads ``self.run``.
        self.run = run
        self.mlflow_client = mlflow_client
        super().__init__()

    def run_id(self):
        return self.run.info.run_id

    def entity(self):
        return self.run.info.user_id

    def project(self):
        return "imported-from-mlflow"

    def config(self):
        return self.run.data.params

    def summary(self):
        return self.run.data.metrics

    def metrics(self):
        """Yield history rows built by aligning metric histories by position.

        Row ``i`` holds the ``i``-th recorded value of every metric that has
        one, with ``_step = i``. MLflow's own ``step`` values are not used.
        Histories of unequal length are no longer truncated to the shortest
        one; later rows simply omit metrics that have no more values.
        """
        from itertools import zip_longest

        # get_metric_history returns a full list per key, so each history
        # is materialized; see mlflow's tracking client internals.
        histories = [
            self.mlflow_client.get_metric_history(self.run.info.run_id, k)
            for k in self.run.data.metrics.keys()
        ]
        for step, row in enumerate(zip_longest(*histories)):
            d = {m.key: m.value for m in row if m is not None}
            d["_step"] = step
            yield d

    def run_group(self):
        # this is nesting? Parent at `run.info.tags.get("mlflow.parentRunId")`
        return f"Experiment {self.run.info.experiment_id}"

    def job_type(self):
        # Is this the right approach?
        return f"User {self.run.info.user_id}"

    def display_name(self):
        return self.run.info.run_name

    def notes(self):
        return self.run.data.tags.get("mlflow.note.content")

    def tags(self):
        # Returns the user (non ``mlflow.``-prefixed) tags as a dict. The base
        # class extends the W&B tag list with it, so only the keys are used.
        return {
            k: v for k, v in self.run.data.tags.items() if not k.startswith("mlflow.")
        }

    def start_time(self):
        """Start time in whole seconds (MLflow reports milliseconds), or ``None``."""
        start = self.run.info.start_time
        return None if start is None else start // 1000

    def runtime(self):
        """Run duration in seconds, or ``None`` if either timestamp is missing.

        ``end_time`` can be ``None`` (e.g. for runs that have not finished), so
        return ``None`` rather than raising a ``TypeError``.
        """
        start = self.start_time()
        end = self.run.info.end_time
        if start is None or end is None:
            return None
        return end // 1_000 - start

    def git(self):
        ...

    def artifacts(self):
        """Yield ``(name, local_path)`` for every file artifact of the run.

        The run's artifacts are downloaded once, and only if the run has at
        least one file. Previously they were re-downloaded for every
        top-level entry. Directories are walked so every yielded path is a
        file, and paths are joined with ``os.path.join`` rather than bare
        string concatenation.
        """
        import os

        run_id = self.run.info.run_id

        def _walk(path):
            # ``path=None`` lists the artifact root, like the original call.
            for entry in self.mlflow_client.list_artifacts(run_id, path):
                if entry.is_dir:
                    yield from _walk(entry.path)
                else:
                    yield entry

        dir_path = None
        for f in _walk(None):
            if dir_path is None:
                dir_path = mlflow.artifacts.download_artifacts(run_id=run_id)
            yield (f.path, os.path.join(dir_path, f.path))

class MlflowImporter(Importer):
    """Import every run of every experiment on an MLflow tracking server.

    Args:
        mlflow_tracking_uri: tracking URI for MLflow and the client.
        mlflow_registry_uri: optional registry URI, set globally when given.
        wandb_base_url: accepted for API compatibility; currently unused.
    """

    def __init__(
        self, mlflow_tracking_uri, mlflow_registry_uri=None, wandb_base_url=None
    ) -> None:
        super().__init__()
        self.mlflow_tracking_uri = mlflow_tracking_uri

        mlflow.set_tracking_uri(self.mlflow_tracking_uri)
        if mlflow_registry_uri:
            mlflow.set_registry_uri(mlflow_registry_uri)
        self.mlflow_client = mlflow.tracking.MlflowClient(mlflow_tracking_uri)

    def import_one(
        self,
        run: ImporterRun,
        overrides: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Re-apply the tracking URI: under import_all_parallel this runs in a
        # worker process, which may not share the parent's MLflow global state.
        mlflow.set_tracking_uri(self.mlflow_tracking_uri)
        super().import_one(run, overrides)

    @staticmethod
    def _paginate(search, *args):
        """Yield every item from a paged MLflow search by following page tokens.

        Previously only the first page of experiments and runs was read, so
        anything beyond the first page was silently skipped.
        """
        page_token = None
        while True:
            page = search(*args, page_token=page_token)
            yield from page
            page_token = page.token
            if not page_token:
                return

    def download_all_runs(self) -> Iterable[MlflowRun]:
        """Yield an ``MlflowRun`` for every run in every experiment."""
        for exp in self._paginate(self.mlflow_client.search_experiments):
            for run in self._paginate(
                self.mlflow_client.search_runs, exp.experiment_id
            ):
                yield MlflowRun(run, self.mlflow_client)
wandb/apis/internal.py CHANGED
@@ -161,6 +161,9 @@ class Api:
161
161
  def get_run_state(self, *args, **kwargs):
162
162
  return self.api.get_run_state(*args, **kwargs)
163
163
 
164
+ def entity_is_team(self, *args, **kwargs):
165
+ return self.api.entity_is_team(*args, **kwargs)
166
+
164
167
  def get_project_run_queues(self, *args, **kwargs):
165
168
  return self.api.get_project_run_queues(*args, **kwargs)
166
169
 
@@ -182,6 +185,12 @@ class Api:
182
185
  def launch_agent_introspection(self, *args, **kwargs):
183
186
  return self.api.launch_agent_introspection(*args, **kwargs)
184
187
 
188
+ def fail_run_queue_item_introspection(self, *args, **kwargs):
189
+ return self.api.fail_run_queue_item_introspection(*args, **kwargs)
190
+
191
+ def fail_run_queue_item(self, *args, **kwargs):
192
+ return self.api.fail_run_queue_item(*args, **kwargs)
193
+
185
194
  def get_launch_agent(self, *args, **kwargs):
186
195
  return self.api.get_launch_agent(*args, **kwargs)
187
196
 
wandb/apis/public.py CHANGED
@@ -5562,7 +5562,6 @@ class Job:
5562
5562
  queue=None,
5563
5563
  resource="local-container",
5564
5564
  resource_args=None,
5565
- cuda=False,
5566
5565
  project_queue=None,
5567
5566
  ):
5568
5567
  from wandb.sdk.launch import launch_add
@@ -5589,6 +5588,5 @@ class Job:
5589
5588
  resource=resource,
5590
5589
  project_queue=project_queue,
5591
5590
  resource_args=resource_args,
5592
- cuda=cuda,
5593
5591
  )
5594
5592
  return queued_run