wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90):
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/importers/__init__.py +1 -4
  4. wandb/apis/importers/internals/internal.py +386 -0
  5. wandb/apis/importers/internals/protocols.py +125 -0
  6. wandb/apis/importers/internals/util.py +78 -0
  7. wandb/apis/importers/mlflow.py +125 -88
  8. wandb/apis/importers/validation.py +108 -0
  9. wandb/apis/importers/wandb.py +1604 -0
  10. wandb/apis/public/api.py +7 -10
  11. wandb/apis/public/artifacts.py +38 -0
  12. wandb/apis/public/files.py +11 -2
  13. wandb/apis/reports/v2/__init__.py +0 -19
  14. wandb/apis/reports/v2/expr_parsing.py +0 -1
  15. wandb/apis/reports/v2/interface.py +15 -18
  16. wandb/apis/reports/v2/internal.py +12 -45
  17. wandb/cli/cli.py +52 -55
  18. wandb/integration/gym/__init__.py +2 -1
  19. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  20. wandb/integration/keras/keras.py +6 -4
  21. wandb/integration/kfp/kfp_patch.py +2 -2
  22. wandb/integration/openai/fine_tuning.py +1 -2
  23. wandb/integration/ultralytics/callback.py +0 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  25. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  26. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  27. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  28. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  29. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  30. wandb/sdk/artifacts/artifact.py +75 -31
  31. wandb/sdk/artifacts/artifact_manifest.py +5 -2
  32. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  33. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
  34. wandb/sdk/artifacts/artifact_saver.py +19 -47
  35. wandb/sdk/artifacts/storage_handler.py +2 -1
  36. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
  37. wandb/sdk/artifacts/storage_policy.py +4 -1
  38. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  39. wandb/sdk/data_types/image.py +2 -2
  40. wandb/sdk/interface/interface.py +49 -13
  41. wandb/sdk/interface/interface_shared.py +17 -11
  42. wandb/sdk/internal/file_stream.py +20 -1
  43. wandb/sdk/internal/handler.py +1 -4
  44. wandb/sdk/internal/internal_api.py +3 -1
  45. wandb/sdk/internal/job_builder.py +49 -19
  46. wandb/sdk/internal/profiler.py +1 -1
  47. wandb/sdk/internal/sender.py +96 -124
  48. wandb/sdk/internal/sender_config.py +197 -0
  49. wandb/sdk/internal/settings_static.py +9 -0
  50. wandb/sdk/internal/system/system_info.py +5 -3
  51. wandb/sdk/internal/update.py +1 -1
  52. wandb/sdk/launch/_launch.py +3 -3
  53. wandb/sdk/launch/_launch_add.py +28 -29
  54. wandb/sdk/launch/_project_spec.py +148 -136
  55. wandb/sdk/launch/agent/agent.py +3 -7
  56. wandb/sdk/launch/agent/config.py +0 -27
  57. wandb/sdk/launch/builder/build.py +54 -28
  58. wandb/sdk/launch/builder/docker_builder.py +4 -15
  59. wandb/sdk/launch/builder/kaniko_builder.py +72 -45
  60. wandb/sdk/launch/create_job.py +6 -40
  61. wandb/sdk/launch/loader.py +10 -0
  62. wandb/sdk/launch/registry/anon.py +29 -0
  63. wandb/sdk/launch/registry/local_registry.py +4 -1
  64. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  65. wandb/sdk/launch/runner/local_container.py +15 -10
  66. wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
  67. wandb/sdk/launch/sweeps/scheduler.py +11 -3
  68. wandb/sdk/launch/utils.py +14 -0
  69. wandb/sdk/lib/__init__.py +2 -5
  70. wandb/sdk/lib/_settings_toposort_generated.py +4 -1
  71. wandb/sdk/lib/apikey.py +0 -5
  72. wandb/sdk/lib/config_util.py +0 -31
  73. wandb/sdk/lib/filesystem.py +11 -1
  74. wandb/sdk/lib/run_moment.py +72 -0
  75. wandb/sdk/service/service.py +7 -2
  76. wandb/sdk/service/streams.py +1 -6
  77. wandb/sdk/verify/verify.py +2 -1
  78. wandb/sdk/wandb_init.py +12 -1
  79. wandb/sdk/wandb_login.py +43 -26
  80. wandb/sdk/wandb_run.py +164 -110
  81. wandb/sdk/wandb_settings.py +58 -16
  82. wandb/testing/relay.py +5 -6
  83. wandb/util.py +50 -7
  84. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
  85. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
  86. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
  87. wandb/apis/importers/base.py +0 -400
  88. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
  89. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
wandb/__init__.py CHANGED
@@ -11,8 +11,8 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples.
11
11
 
12
12
  For reference documentation, see https://docs.wandb.com/ref/python.
13
13
  """
14
- __version__ = "0.16.3"
15
- _minimum_core_version = "0.17.0b8"
14
+ __version__ = "0.16.5"
15
+ _minimum_core_version = "0.17.0b10"
16
16
 
17
17
  # Used with pypi checks and other messages related to pip
18
18
  _wandb_module = "wandb"
wandb/agents/pyagent.py CHANGED
@@ -347,7 +347,7 @@ def pyagent(sweep_id, function, entity=None, project=None, count=None):
347
347
  count (int, optional): the number of trials to run.
348
348
  """
349
349
  if not callable(function):
350
- raise Exception("function paramter must be callable!")
350
+ raise Exception("function parameter must be callable!")
351
351
  agent = Agent(
352
352
  sweep_id,
353
353
  function=function,
@@ -1,4 +1 @@
1
- from wandb.util import get_module
2
-
3
- if get_module("mlflow"):
4
- from .mlflow import MlflowImporter, MlflowRun # noqa: F401
1
+ from .internals.util import Namespace
@@ -0,0 +1,386 @@
1
+ import json
2
+ import logging
3
+ import math
4
+ import os
5
+ import queue
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Iterable, Optional
9
+
10
+ import numpy as np
11
+ from google.protobuf.json_format import ParseDict
12
+ from tenacity import retry, stop_after_attempt, wait_random_exponential
13
+
14
+ from wandb import Artifact
15
+ from wandb.proto import wandb_internal_pb2 as pb
16
+ from wandb.proto import wandb_settings_pb2
17
+ from wandb.proto import wandb_telemetry_pb2 as telem_pb
18
+ from wandb.sdk.interface.interface import file_policy_to_enum
19
+ from wandb.sdk.interface.interface_queue import InterfaceQueue
20
+ from wandb.sdk.internal import context
21
+ from wandb.sdk.internal.sender import SendManager
22
+ from wandb.sdk.internal.settings_static import SettingsStatic
23
+ from wandb.util import coalesce, recursive_cast_dictlike_to_dict
24
+
25
+ from .protocols import ImporterRun
26
+
# Local staging area; each imported run gets its own subdirectory below this.
ROOT_DIR = "./wandb-importer"


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Plain timestamped console output by default; setting
# WANDB_IMPORTER_ENABLE_RICH_LOGGING (to any non-empty value) switches to rich
# tracebacks that also show local variables.
if not os.getenv("WANDB_IMPORTER_ENABLE_RICH_LOGGING"):
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
else:
    from rich.logging import RichHandler

    logger.addHandler(RichHandler(rich_tracebacks=True, tracebacks_show_locals=True))


# Retry policy for flaky network operations: at most 3 attempts, waiting a
# randomized exponential backoff (capped at 10 seconds) between attempts.
exp_retry = retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=10),
)
50
+
51
+
class AlternateSendManager(SendManager):
    """``SendManager`` whose ``_send_artifact`` is wrapped with ``exp_retry``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Shadow the bound method on this instance only, so every artifact send
        # made through it goes through the retry policy.
        retrying_send_artifact = exp_retry(self._send_artifact)
        self._send_artifact = retrying_send_artifact
56
+
57
+
@dataclass(frozen=True)
class SendManagerConfig:
    """Configure which parts of SendManager tooling to use.

    Every flag defaults to ``False``. With the defaults, ``RecordMaker.make_records``
    produces only the run, telemetry, and a metadata-only files record.
    """

    # Send "use artifact" records for the artifacts the run consumed.
    use_artifacts: bool = False
    # Send records for the artifacts the run logged.
    log_artifacts: bool = False
    # NOTE(review): not read by RecordMaker.make_records; confirm where it is used.
    metadata: bool = False
    # Upload the run's files (otherwise only a generated metadata file is sent).
    files: bool = False
    # Include files under media/.
    media: bool = False
    # Include files under code/.
    code: bool = False
    # Replay the run's metric history rows.
    history: bool = False
    # Send the run's summary.
    summary: bool = False
    # Replay captured terminal output lines.
    terminal_output: bool = False
71
+
72
+
@dataclass
class RecordMaker:
    """Turn an ``ImporterRun`` into the protobuf records a ``SendManager`` consumes.

    The public ``make_*`` methods are generators that yield ``pb.Record`` objects in
    the order they should be sent. Each private ``_make_*`` helper builds a single
    record, except ``_make_history_records``, which yields one record per metrics row.
    """

    run: ImporterRun
    # NOTE(review): this default is evaluated once, when the class is defined, so every
    # RecordMaker created without an explicit ``interface`` shares one InterfaceQueue.
    # Within this class it is only used to build protos/records; switch to
    # ``field(default_factory=InterfaceQueue)`` if it ever needs per-instance state.
    interface: InterfaceQueue = InterfaceQueue()

    @property
    def run_dir(self) -> str:
        """Local staging directory for this run: ``ROOT_DIR/<run_id>``.

        Accessing it creates ``<run_dir>/wandb`` (and any parents) if missing.
        """
        run_dir = f"{ROOT_DIR}/{self.run.run_id()}"
        Path(run_dir, "wandb").mkdir(parents=True, exist_ok=True)
        return run_dir

    def make_artifacts_only_records(
        self,
        artifacts: Optional[Iterable[Artifact]] = None,
        used_artifacts: Optional[Iterable[Artifact]] = None,
    ) -> Iterable[pb.Record]:
        """Only make records required to upload artifacts.

        Escape hatch for adding extra artifacts to a run. Yields the run record,
        then one record per used artifact, then one per logged artifact.
        """
        yield self._make_run_record()

        if used_artifacts:
            for art in used_artifacts:
                yield self._make_artifact_record(art, use_artifact=True)

        if artifacts:
            for art in artifacts:
                yield self._make_artifact_record(art)

    def make_records(
        self,
        config: SendManagerConfig,
    ) -> Iterable[pb.Record]:
        """Make all the records that constitute a run.

        The run, telemetry and files records are always produced. Everything else
        is gated by the matching ``config`` flag.
        """
        yield self._make_run_record()
        yield self._make_telem_record()

        include_artifacts = config.log_artifacts or config.use_artifacts
        yield self._make_files_record(
            include_artifacts, config.files, config.media, config.code
        )

        if config.use_artifacts:
            if (used_artifacts := self.run.used_artifacts()) is not None:
                for artifact in used_artifacts:
                    yield self._make_artifact_record(artifact, use_artifact=True)

        if config.log_artifacts:
            if (artifacts := self.run.artifacts()) is not None:
                for artifact in artifacts:
                    yield self._make_artifact_record(artifact)

        if config.history:
            yield from self._make_history_records()

        if config.summary:
            yield self._make_summary_record()

        if config.terminal_output:
            if (lines := self.run.logs()) is not None:
                for line in lines:
                    yield self._make_output_record(line)

    def _make_run_record(self) -> pb.Record:
        """Build the RunRecord carrying the run's identity, metadata and config."""
        run = pb.RunRecord()
        run.run_id = self.run.run_id()
        run.entity = self.run.entity()
        run.project = self.run.project()
        # Protobuf string fields reject None, so fall back to "" when unset.
        run.display_name = coalesce(self.run.display_name(), "")
        run.notes = coalesce(self.run.notes(), "")
        run.tags.extend(coalesce(self.run.tags(), []))

        # start_time is Optional in the ImporterRun protocol; FromMilliseconds(None)
        # would raise, so leave the field unset when the source has no value.
        start_time = self.run.start_time()
        if start_time is not None:
            run.start_time.FromMilliseconds(start_time)

        host = self.run.host()
        if host is not None:
            run.host = host

        runtime = self.run.runtime()
        if runtime is not None:
            run.runtime = runtime

        run_group = self.run.run_group()
        if run_group is not None:
            run.run_group = run_group

        # Work on copies so the dict returned by ``run.config()`` (and its nested
        # "_wandb" dict, if any) is not modified in place.
        config = dict(self.run.config())
        wandb_meta = dict(config.get("_wandb") or {})

        # how do I get this automatically?
        wandb_meta["code_path"] = self.run.code_path()
        wandb_meta["python_version"] = self.run.python_version()
        wandb_meta["cli_version"] = self.run.cli_version()
        config["_wandb"] = wandb_meta

        self.interface._make_config(
            data=config,
            obj=run.config,
        )  # is there a better way?
        return self.interface._make_record(run=run)

    def _make_output_record(self, line: str) -> pb.Record:
        """Wrap one line of captured terminal output as a STDOUT record."""
        output_raw = pb.OutputRawRecord()
        output_raw.output_type = pb.OutputRawRecord.OutputType.STDOUT
        output_raw.line = line
        return self.interface._make_record(output_raw=output_raw)

    def _make_summary_record(self) -> pb.Record:
        """Build the summary record from ``run.summary()`` plus ``_runtime``."""
        d: dict = {
            **self.run.summary(),
            "_runtime": self.run.runtime(),  # quirk of runtime -- it has to be here!
        }
        d = recursive_cast_dictlike_to_dict(d)
        summary = self.interface._make_summary_from_dict(d)
        return self.interface._make_record(summary=summary)

    def _make_history_records(self) -> Iterable[pb.Record]:
        """Yield one HistoryRecord per metrics row, JSON-encoding every value."""
        for metrics in self.run.metrics():
            history = pb.HistoryRecord()
            for k, v in metrics.items():
                item = history.item.add()
                item.key = k
                # Normalise NaN (a float NaN or the string "NaN") to a real float NaN,
                # so json.dumps emits the bare token ``NaN``. Other NaN encodings
                # reportedly cause the row to be dropped on re-upload.
                # float("nan") is used because the np.NaN alias was removed in
                # NumPy 2.0. The isinstance(str) guard avoids an ambiguous
                # elementwise comparison for array values.
                if (isinstance(v, float) and math.isnan(v)) or (
                    isinstance(v, str) and v == "NaN"
                ):
                    v = float("nan")

                if isinstance(v, bytes):
                    # it's a json string encoded as bytes
                    v = v.decode("utf-8")
                else:
                    v = json.dumps(v)

                item.value_json = v
            yield self.interface._make_record(history=history)

    def _make_files_record(
        self, artifacts: bool, files: bool, media: bool, code: bool
    ) -> pb.Record:
        """Build the FilesRecord listing which run files to upload and their policies.

        If ``files`` is false or the run reports no files, only a freshly written
        ``wandb-metadata.json`` is listed. Paths starting with ``artifact/``,
        ``media/`` or ``code/`` are skipped unless the matching flag is set.
        """
        run_files = self.run.files()
        if not files or run_files is None:
            # We'll always need a metadata file even if there are no other files to upload
            metadata_fname = self._make_metadata_file()
            run_files = [(metadata_fname, "end")]

        files_dir = f"{self.run_dir}/files"
        files_record = pb.FilesRecord()
        for path, policy in run_files:
            if not artifacts and path.startswith("artifact/"):
                continue
            if not media and path.startswith("media/"):
                continue
            if not code and path.startswith("code/"):
                continue

            # DirWatcher requires the path to start with media/ instead of the full path.
            # Paths that are not under files_dir (e.g. already relative "media/...")
            # are left unchanged instead of raising ValueError.
            if "media" in path:
                try:
                    path = str(Path(path).relative_to(files_dir))
                except ValueError:
                    pass
            f = files_record.files.add()
            f.path = path
            f.policy = file_policy_to_enum(policy)

        return self.interface._make_record(files=files_record)

    def _make_artifact_record(
        self, artifact: Artifact, use_artifact: bool = False
    ) -> pb.Record:
        """Build an artifact record attributed to this run.

        ``use_artifact=True`` marks it as an artifact the run used (``user_created``
        and ``use_after_commit``) rather than one it logged.
        """
        proto = self.interface._make_artifact(artifact)
        proto.run_id = str(self.run.run_id())
        proto.project = str(self.run.project())
        proto.entity = str(self.run.entity())
        proto.user_created = use_artifact
        proto.use_after_commit = use_artifact
        proto.finalize = True

        # Build a new list: extending ``artifact._aliases`` in place would permanently
        # modify the caller's Artifact and duplicate these aliases on repeat calls.
        aliases = [*artifact._aliases, "latest", "imported"]
        # dict.fromkeys drops duplicates while keeping first-seen order.
        proto.aliases.extend(dict.fromkeys(aliases))
        return self.interface._make_record(artifact=proto)

    def _make_telem_record(self) -> pb.Record:
        """Build the telemetry record (importer feature flag plus CLI/Python versions)."""
        telem = telem_pb.TelemetryRecord()

        feature = telem_pb.Feature()
        # NOTE(review): this sets importer_mlflow for every source run; confirm that is
        # intended for non-MLflow importers.
        feature.importer_mlflow = True
        telem.feature.CopyFrom(feature)

        cli_version = self.run.cli_version()
        if cli_version:
            telem.cli_version = cli_version

        python_version = self.run.python_version()
        if python_version:
            telem.python_version = python_version

        return self.interface._make_record(telemetry=telem)

    def _make_metadata_file(self) -> str:
        """Write ``<run_dir>/files/wandb-metadata.json`` and return its path.

        Missing string fields are filled with a placeholder. GPU, CPU and memory
        entries are included only when the run reports them.
        """
        missing_text = "This data was not captured"
        files_dir = f"{self.run_dir}/files"
        os.makedirs(files_dir, exist_ok=True)

        d = {}
        d["os"] = coalesce(self.run.os_version(), missing_text)
        d["python"] = coalesce(self.run.python_version(), missing_text)
        d["program"] = coalesce(self.run.program(), missing_text)
        d["cuda"] = coalesce(self.run.cuda_version(), missing_text)
        d["host"] = coalesce(self.run.host(), missing_text)
        d["username"] = coalesce(self.run.username(), missing_text)
        d["executable"] = coalesce(self.run.executable(), missing_text)

        gpus_used = self.run.gpus_used()
        if gpus_used is not None:
            d["gpu_devices"] = json.dumps(gpus_used)
            d["gpu_count"] = json.dumps(len(gpus_used))

        cpus_used = self.run.cpus_used()
        if cpus_used is not None:
            d["cpu_count"] = json.dumps(cpus_used)

        mem_used = self.run.memory_used()
        if mem_used is not None:
            d["memory"] = json.dumps({"total": mem_used})

        fname = f"{files_dir}/wandb-metadata.json"
        with open(fname, "w") as f:
            f.write(json.dumps(d))
        return fname
306
+
307
+
def _make_settings(
    root_dir: str, settings_override: Optional[Dict[str, Any]] = None
) -> SettingsStatic:
    """Build static settings for a standalone SendManager rooted at ``root_dir``.

    Keys in ``settings_override`` replace the defaults below. The merged dict is
    then parsed into a ``wandb_settings_pb2.Settings`` message.
    """
    settings: Dict[str, Any] = {
        "files_dir": os.path.join(root_dir, "files"),
        "root_dir": root_dir,
        "sync_file": os.path.join(root_dir, "txlog.wandb"),
        "resume": "false",
        "program": None,
        "ignore_globs": [],
        "disable_job_creation": True,
        "_start_time": 0,
        "_offline": None,
        "_sync": True,
        "_live_policy_rate_limit": 15,  # matches dir_watcher
        "_live_policy_wait_time": 600,  # matches dir_watcher
        "_async_upload_concurrency_limit": None,
        "_file_stream_timeout_seconds": 60,
    }
    if settings_override is not None:
        settings = {**settings, **settings_override}

    message = wandb_settings_pb2.Settings()
    ParseDict(settings, message)
    return SettingsStatic(message)
335
+
336
+
def send_run(
    run: ImporterRun,
    *,
    extra_arts: Optional[Iterable[Artifact]] = None,
    extra_used_arts: Optional[Iterable[Artifact]] = None,
    config: Optional[SendManagerConfig] = None,
    overrides: Optional[Dict[str, Any]] = None,
    settings_override: Optional[Dict[str, Any]] = None,
) -> None:
    """Replay ``run`` to the destination by feeding its records to a SendManager.

    Args:
        run: Source run to import.
        extra_arts: Extra artifacts to log. If this or ``extra_used_arts`` is
            given, only the run record and these artifact records are sent, and
            ``config`` is ignored.
        extra_used_arts: Extra artifacts to mark as used by the run.
        config: Which parts of the run to send. Defaults to ``SendManagerConfig()``,
            which sends only the run, telemetry, and a metadata-only files record.
        overrides: Maps attribute names on ``run`` to constant values. Each
            attribute is replaced by a callable returning that value, which
            mutates ``run``.
        settings_override: Settings merged over the ``_make_settings`` defaults.
    """
    if config is None:
        config = SendManagerConfig()

    # does this need to be here for pmap?
    if overrides:
        for k, v in overrides.items():
            # Bind ``v`` as a default argument: a plain ``lambda: v`` late-binds, so
            # every override would return the last value from this loop.
            # https://stackoverflow.com/questions/10802002/why-deepcopy-doesnt-create-new-references-to-lambda-function
            setattr(run, k, lambda v=v: v)

    rm = RecordMaker(run)
    root_dir = rm.run_dir

    settings = _make_settings(root_dir, settings_override)
    sm_record_q = queue.Queue()
    result_q = queue.Queue()
    interface = InterfaceQueue(record_q=sm_record_q)
    context_keeper = context.ContextKeeper()
    sm = AlternateSendManager(
        settings, sm_record_q, result_q, interface, context_keeper
    )

    if extra_arts or extra_used_arts:
        records = rm.make_artifacts_only_records(extra_arts, extra_used_arts)
    else:
        records = rm.make_records(config)

    # TODO: write records to a transaction log (e.g. via a WriteManager) so that
    # incremental uploads only need to send the missing records.
    for r in records:
        # Lazy %-formatting: the (potentially large) record repr is only built
        # when DEBUG logging is actually enabled.
        logger.debug("Sending r=%r", r)
        sm.send(r)

    sm.finish()
@@ -0,0 +1,125 @@
1
+ import logging
2
+ import sys
3
+ from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
4
+
5
+ from wandb.sdk.artifacts.artifact import Artifact
6
+
7
+ if sys.version_info >= (3, 8):
8
+ from typing import Protocol, runtime_checkable
9
+ else:
10
+ from typing_extensions import Protocol, runtime_checkable
11
+
# Alias marking a ``str`` that holds a filesystem path.
PathStr = str
# Upload policies a run file may carry (see ``ImporterRun.files``).
Policy = Literal["now", "end", "live"]

logger = logging.getLogger("import_logger")
16
+
17
+
@runtime_checkable
class ImporterRun(Protocol):
    """Read-only view of a source run that the importer replays into W&B.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj, ImporterRun)``
    only checks that these methods exist, not their signatures or return types.
    Methods annotated ``Optional[...]`` may return ``None`` when the source did not
    record the value.
    """

    def run_id(self) -> str:
        """Return the run ID."""

    def entity(self) -> str:
        """Return the entity the run belongs to."""

    def project(self) -> str:
        """Return the project the run belongs to."""

    def config(self) -> Dict[str, Any]:
        """Return the run config as a dict."""

    def summary(self) -> Dict[str, float]:
        """Return the run summary as a dict."""

    def metrics(self) -> Iterable[Dict[str, float]]:
        """Return the run's history rows.

        Either one dict per step holding every metric:

            [
                {'metric1': 1, 'metric2': 1, '_step': 0},
                {'metric1': 2, 'metric2': 4, '_step': 1},
                {'metric1': 3, 'metric2': 9, '_step': 2},
                ...
            ]

        or several dicts per step, each holding a subset of the metrics:

            [
                {'metric1': 1, '_step': 0},
                {'metric2': 1, '_step': 0},
                {'metric1': 2, '_step': 1},
                {'metric2': 4, '_step': 1},
                ...
            ]
        """

    def run_group(self) -> Optional[str]:
        """Return the run group, if any."""

    def job_type(self) -> Optional[str]:
        """Return the job type, if any."""

    def display_name(self) -> str:
        """Return the human-readable run name."""

    def notes(self) -> Optional[str]:
        """Return the run notes, if any."""

    def tags(self) -> Optional[List[str]]:
        """Return the run tags, if any."""

    def artifacts(self) -> Optional[Iterable[Artifact]]:
        """Return the artifacts the run logged."""

    def used_artifacts(self) -> Optional[Iterable[Artifact]]:
        """Return the artifacts the run used."""

    def os_version(self) -> Optional[str]:
        """Return the operating system description."""

    def python_version(self) -> Optional[str]:
        """Return the Python version the run used."""

    def cuda_version(self) -> Optional[str]:
        """Return the CUDA version, if any."""

    def program(self) -> Optional[str]:
        """Return the program that was run."""

    def host(self) -> Optional[str]:
        """Return the host the run executed on."""

    def username(self) -> Optional[str]:
        """Return the username that launched the run."""

    def executable(self) -> Optional[str]:
        """Return the interpreter executable."""

    def gpus_used(self) -> Optional[str]:
        """Return the GPUs used.

        NOTE(review): annotated as ``str``, but the importer JSON-encodes it and uses
        ``len()`` of it as ``gpu_count``, so it is likely a list of devices. Confirm.
        """

    def cpus_used(self) -> Optional[int]:
        """Return the number of CPUs."""

    def memory_used(self) -> Optional[int]:
        """Return the total memory (written as ``memory.total`` in metadata)."""

    def runtime(self) -> Optional[int]:
        """Return the run duration (units presumably seconds; confirm)."""

    def start_time(self) -> Optional[int]:
        """Return the start time in milliseconds since the Unix epoch."""

    def code_path(self) -> Optional[str]:
        """Return the path of the run's code."""

    def cli_version(self) -> Optional[str]:
        """Return the wandb CLI version that produced the run."""

    def files(self) -> Optional[Iterable[Tuple[PathStr, Policy]]]:
        """Return ``(path, upload policy)`` pairs for the run's files."""

    def logs(self) -> Optional[Iterable[str]]:
        """Return the captured terminal output, one line per item."""
@@ -0,0 +1,78 @@
1
+ import logging
2
+ import sys
3
+ import traceback
4
+ from concurrent.futures import ThreadPoolExecutor, as_completed
5
+ from dataclasses import dataclass
6
+ from typing import Iterable, Optional
7
+
8
+
@dataclass(frozen=True)
class Namespace:
    """Configure an alternate entity/project at the dst server your data will end up in."""

    entity: str
    project: str

    @classmethod
    def from_path(cls, path: str):
        """Parse an ``"entity/project"`` string into a Namespace.

        Raises:
            ValueError: if ``path`` does not contain exactly one ``/``.
        """
        parts = path.split("/")
        if len(parts) != 2:
            # Previously this surfaced as an opaque tuple-unpacking ValueError.
            raise ValueError(
                f"Expected a path of the form 'entity/project', got {path!r}"
            )
        entity, project = parts
        return cls(entity, project)

    @property
    def path(self):
        """The namespace as an ``"entity/project"`` string."""
        return f"{self.entity}/{self.project}"

    @property
    def send_manager_overrides(self):
        """Dict of the non-empty fields, keyed ``"entity"`` / ``"project"``."""
        overrides = {}
        if self.entity:
            overrides["entity"] = self.entity
        if self.project:
            overrides["project"] = self.project
        return overrides
33
+
34
+
# Same named logger that the importer's protocols module requests.
logger = logging.getLogger(name="import_logger")
36
+
37
+
def parallelize(
    func,
    iterable: Iterable,
    *args,
    max_workers: Optional[int] = None,
    raise_on_error: bool = False,
    **kwargs,
):
    """Call ``func(item, *args, **kwargs)`` for every item on a thread pool.

    Args:
        func: Callable applied to each item.
        iterable: Items to process; each is passed as ``func``'s first argument.
        *args, **kwargs: Extra arguments forwarded to every call.
        max_workers: Thread pool size (``None`` uses the executor default).
        raise_on_error: If true, the first exception collected is re-raised to the
            caller. Otherwise a failing call contributes ``None`` and the error is
            only logged at DEBUG level.

    Returns:
        List of results in completion order, not input order.
    """

    def safe_func(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            _, _, exc_traceback = sys.exc_info()
            traceback_details = traceback.extract_tb(exc_traceback)
            filename = traceback_details[-1].filename
            lineno = traceback_details[-1].lineno
            # Lazy %-formatting: args/kwargs/traceback reprs are only built when
            # DEBUG is enabled.
            logger.debug(
                "Exception: func=%r args=%r kwargs=%r e=%r filename=%r lineno=%r. "
                "traceback_details=%r",
                func,
                args,
                kwargs,
                e,
                filename,
                lineno,
                traceback_details,
            )
            if raise_on_error:
                raise
            return None

    with ThreadPoolExecutor(max_workers) as executor:
        futures = [executor.submit(safe_func, x, *args, **kwargs) for x in iterable]
        # future.result() re-raises in this thread when raise_on_error is set.
        results = [future.result() for future in as_completed(futures)]
    return results
66
+
67
+
def for_each(
    func, iterable: Iterable, parallel: bool = True, max_workers: Optional[int] = None
):
    """Apply ``func`` to every item of ``iterable`` and return the results.

    With ``parallel`` (the default), the work goes through ``parallelize``. Results
    then arrive in completion order, and failing calls yield ``None``. Otherwise the
    calls run sequentially in input order, and exceptions propagate.
    """
    if not parallel:
        return list(map(func, iterable))

    return parallelize(func, iterable, max_workers=max_workers)