wandb 0.16.5__py3-none-any.whl → 0.17.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -3
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +55 -3
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +17 -4
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +15 -17
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +95 -22
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +4 -106
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +77 -40
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +7 -1
- wandb/proto/wandb_internal_codegen.py +3 -29
- wandb/sdk/artifacts/artifact.py +51 -20
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +18 -27
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
- wandb/sdk/artifacts/storage_policy.py +2 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +86 -38
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +5 -18
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_api.py +1 -129
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +159 -45
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +9 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +188 -241
- wandb/sdk/launch/agent/agent.py +115 -48
- wandb/sdk/launch/agent/config.py +80 -14
- wandb/sdk/launch/builder/abstract.py +69 -1
- wandb/sdk/launch/builder/build.py +156 -555
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +8 -23
- wandb/sdk/launch/builder/kaniko_builder.py +161 -159
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +68 -63
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +217 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +7 -4
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +33 -140
- wandb/sdk/lib/_settings_toposort_generated.py +1 -5
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +23 -2
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/run_moment.py +7 -1
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +19 -16
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +16 -63
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +164 -90
- wandb/sdk/wandb_settings.py +2 -48
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +11 -12
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +67 -54
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/RECORD +178 -188
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -18
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- wandb-0.16.5.dist-info/top_level.txt +0 -1
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
wandb/sdk/internal/profiler.py
CHANGED
wandb/sdk/internal/progress.py
CHANGED
@@ -81,31 +81,3 @@ class Progress:
|
|
81
81
|
return self.len
|
82
82
|
|
83
83
|
next = __next__
|
84
|
-
|
85
|
-
|
86
|
-
class AsyncProgress:
|
87
|
-
"""Wrapper around Progress, to make it async iterable.
|
88
|
-
|
89
|
-
httpx, for streaming uploads, requires the data source to be an async iterable.
|
90
|
-
If we pass in a sync iterable (like a bare `Progress` instance), httpx will
|
91
|
-
get confused, think we're trying to make a synchronous request, and raise.
|
92
|
-
So we need this wrapper class to be an async iterable but *not* a sync iterable.
|
93
|
-
"""
|
94
|
-
|
95
|
-
def __init__(self, progress: Progress) -> None:
|
96
|
-
self._progress = progress
|
97
|
-
|
98
|
-
def __aiter__(self):
|
99
|
-
return self
|
100
|
-
|
101
|
-
async def __anext__(self):
|
102
|
-
try:
|
103
|
-
return next(self._progress)
|
104
|
-
except StopIteration:
|
105
|
-
raise StopAsyncIteration
|
106
|
-
|
107
|
-
def __len__(self):
|
108
|
-
return len(self._progress)
|
109
|
-
|
110
|
-
def rewind(self) -> None:
|
111
|
-
self._progress.rewind()
|
wandb/sdk/internal/run.py
CHANGED
wandb/sdk/internal/sender.py
CHANGED
@@ -327,7 +327,6 @@ class SendManager:
|
|
327
327
|
# ignore_globs=(),
|
328
328
|
_sync=True,
|
329
329
|
disable_job_creation=False,
|
330
|
-
_async_upload_concurrency_limit=None,
|
331
330
|
_file_stream_timeout_seconds=0,
|
332
331
|
)
|
333
332
|
record_q: Queue[Record] = queue.Queue()
|
@@ -910,7 +909,7 @@ class SendManager:
|
|
910
909
|
is_wandb_init = self._run is None
|
911
910
|
|
912
911
|
# save start time of a run
|
913
|
-
self._start_time = run.start_time.ToMicroseconds() // 1e6
|
912
|
+
self._start_time = int(run.start_time.ToMicroseconds() // 1e6)
|
914
913
|
|
915
914
|
# update telemetry
|
916
915
|
if run.telemetry:
|
@@ -28,14 +28,6 @@ logger = logging.getLogger(__name__)
|
|
28
28
|
ROCM_SMI_CMD: Final[str] = shutil.which("rocm-smi") or "/usr/bin/rocm-smi"
|
29
29
|
|
30
30
|
|
31
|
-
def get_rocm_smi_stats() -> Dict[str, Any]:
|
32
|
-
command = [str(ROCM_SMI_CMD), "-a", "--json"]
|
33
|
-
output = subprocess.check_output(command, universal_newlines=True).strip()
|
34
|
-
if "No AMD GPUs specified" in output:
|
35
|
-
return {}
|
36
|
-
return json.loads(output.split("\n")[0]) # type: ignore
|
37
|
-
|
38
|
-
|
39
31
|
_StatsKeys = Literal[
|
40
32
|
"gpu",
|
41
33
|
"memoryAllocated",
|
@@ -49,6 +41,48 @@ _Stats = Dict[_StatsKeys, float]
|
|
49
41
|
_InfoDict = Dict[str, Union[int, List[Dict[str, Any]]]]
|
50
42
|
|
51
43
|
|
44
|
+
def get_rocm_smi_stats() -> Dict[str, Any]:
|
45
|
+
command = [str(ROCM_SMI_CMD), "-a", "--json"]
|
46
|
+
output = subprocess.check_output(command, universal_newlines=True).strip()
|
47
|
+
if "No AMD GPUs specified" in output:
|
48
|
+
return {}
|
49
|
+
return json.loads(output.split("\n")[0]) # type: ignore
|
50
|
+
|
51
|
+
|
52
|
+
def parse_stats(stats: Dict[str, str]) -> _Stats:
|
53
|
+
"""Parse stats from rocm-smi output."""
|
54
|
+
parsed_stats: _Stats = {}
|
55
|
+
|
56
|
+
try:
|
57
|
+
parsed_stats["gpu"] = float(stats.get("GPU use (%)")) # type: ignore
|
58
|
+
except (TypeError, ValueError):
|
59
|
+
logger.warning("Could not parse GPU usage as float")
|
60
|
+
try:
|
61
|
+
parsed_stats["memoryAllocated"] = float(stats.get("GPU memory use (%)")) # type: ignore
|
62
|
+
except (TypeError, ValueError):
|
63
|
+
logger.warning("Could not parse GPU memory allocation as float")
|
64
|
+
try:
|
65
|
+
parsed_stats["temp"] = float(stats.get("Temperature (Sensor memory) (C)")) # type: ignore
|
66
|
+
except (TypeError, ValueError):
|
67
|
+
logger.warning("Could not parse GPU temperature as float")
|
68
|
+
try:
|
69
|
+
parsed_stats["powerWatts"] = float(
|
70
|
+
stats.get("Average Graphics Package Power (W)") # type: ignore
|
71
|
+
)
|
72
|
+
except (TypeError, ValueError):
|
73
|
+
logger.warning("Could not parse GPU power as float")
|
74
|
+
try:
|
75
|
+
parsed_stats["powerPercent"] = (
|
76
|
+
float(stats.get("Average Graphics Package Power (W)")) # type: ignore
|
77
|
+
/ float(stats.get("Max Graphics Package Power (W)")) # type: ignore
|
78
|
+
* 100
|
79
|
+
)
|
80
|
+
except (TypeError, ValueError):
|
81
|
+
logger.warning("Could not parse GPU average/max power as float")
|
82
|
+
|
83
|
+
return parsed_stats
|
84
|
+
|
85
|
+
|
52
86
|
class GPUAMDStats:
|
53
87
|
"""Stats for AMD GPU devices."""
|
54
88
|
|
@@ -58,40 +92,6 @@ class GPUAMDStats:
|
|
58
92
|
def __init__(self) -> None:
|
59
93
|
self.samples = deque()
|
60
94
|
|
61
|
-
@staticmethod
|
62
|
-
def parse_stats(stats: Dict[str, str]) -> _Stats:
|
63
|
-
"""Parse stats from rocm-smi output."""
|
64
|
-
parsed_stats: _Stats = {}
|
65
|
-
|
66
|
-
try:
|
67
|
-
parsed_stats["gpu"] = float(stats.get("GPU use (%)")) # type: ignore
|
68
|
-
except (TypeError, ValueError):
|
69
|
-
logger.warning("Could not parse GPU usage as float")
|
70
|
-
try:
|
71
|
-
parsed_stats["memoryAllocated"] = float(stats.get("GPU memory use (%)")) # type: ignore
|
72
|
-
except (TypeError, ValueError):
|
73
|
-
logger.warning("Could not parse GPU memory allocation as float")
|
74
|
-
try:
|
75
|
-
parsed_stats["temp"] = float(stats.get("Temperature (Sensor memory) (C)")) # type: ignore
|
76
|
-
except (TypeError, ValueError):
|
77
|
-
logger.warning("Could not parse GPU temperature as float")
|
78
|
-
try:
|
79
|
-
parsed_stats["powerWatts"] = float(
|
80
|
-
stats.get("Average Graphics Package Power (W)") # type: ignore
|
81
|
-
)
|
82
|
-
except (TypeError, ValueError):
|
83
|
-
logger.warning("Could not parse GPU power as float")
|
84
|
-
try:
|
85
|
-
parsed_stats["powerPercent"] = (
|
86
|
-
float(stats.get("Average Graphics Package Power (W)")) # type: ignore
|
87
|
-
/ float(stats.get("Max Graphics Package Power (W)")) # type: ignore
|
88
|
-
* 100
|
89
|
-
)
|
90
|
-
except (TypeError, ValueError):
|
91
|
-
logger.warning("Could not parse GPU average/max power as float")
|
92
|
-
|
93
|
-
return parsed_stats
|
94
|
-
|
95
95
|
def sample(self) -> None:
|
96
96
|
try:
|
97
97
|
raw_stats = get_rocm_smi_stats()
|
@@ -103,7 +103,7 @@ class GPUAMDStats:
|
|
103
103
|
|
104
104
|
for card_key in card_keys:
|
105
105
|
card_stats = raw_stats[card_key]
|
106
|
-
stats =
|
106
|
+
stats = parse_stats(card_stats)
|
107
107
|
if stats:
|
108
108
|
cards.append(stats)
|
109
109
|
|
@@ -183,7 +183,7 @@ class GPUAMD:
|
|
183
183
|
|
184
184
|
can_read_rocm_smi = False
|
185
185
|
try:
|
186
|
-
if get_rocm_smi_stats():
|
186
|
+
if parse_stats(get_rocm_smi_stats()):
|
187
187
|
can_read_rocm_smi = True
|
188
188
|
except Exception:
|
189
189
|
pass
|
@@ -37,6 +37,12 @@ class _Stats(TypedDict):
|
|
37
37
|
# cpuWaitMs: float
|
38
38
|
|
39
39
|
|
40
|
+
def get_apple_gpu_path() -> pathlib.Path:
|
41
|
+
return (
|
42
|
+
pathlib.Path(sys.modules["wandb"].__path__[0]) / "bin" / "apple_gpu_stats"
|
43
|
+
).resolve()
|
44
|
+
|
45
|
+
|
40
46
|
class GPUAppleStats:
|
41
47
|
"""Apple GPU stats available on Arm Macs."""
|
42
48
|
|
@@ -49,9 +55,7 @@ class GPUAppleStats:
|
|
49
55
|
|
50
56
|
def __init__(self) -> None:
|
51
57
|
self.samples = deque()
|
52
|
-
self.binary_path = (
|
53
|
-
pathlib.Path(sys.modules["wandb"].__path__[0]) / "bin" / "apple_gpu_stats"
|
54
|
-
).resolve()
|
58
|
+
self.binary_path = get_apple_gpu_path()
|
55
59
|
|
56
60
|
def sample(self) -> None:
|
57
61
|
try:
|
@@ -63,22 +67,47 @@ class GPUAppleStats:
|
|
63
67
|
)[0]
|
64
68
|
raw_stats = json.loads(output)
|
65
69
|
|
70
|
+
temp_keys = [
|
71
|
+
"m1Gpu1",
|
72
|
+
"m1Gpu2",
|
73
|
+
"m1Gpu3",
|
74
|
+
"m1Gpu4",
|
75
|
+
"m2Gpu1",
|
76
|
+
"m2Gpu2",
|
77
|
+
"m3Gpu1",
|
78
|
+
"m3Gpu2",
|
79
|
+
"m3Gpu3",
|
80
|
+
"m3Gpu4",
|
81
|
+
"m3Gpu5",
|
82
|
+
"m3Gpu6",
|
83
|
+
"m3Gpu7",
|
84
|
+
"m3Gpu8",
|
85
|
+
]
|
86
|
+
temp, count = 0, 0
|
87
|
+
for k in temp_keys:
|
88
|
+
if raw_stats.get(k, 0) > 0:
|
89
|
+
temp += raw_stats[k]
|
90
|
+
count += 1
|
91
|
+
|
66
92
|
stats: _Stats = {
|
67
93
|
"gpu": raw_stats["utilization"],
|
68
|
-
"memoryAllocated":
|
69
|
-
|
70
|
-
|
71
|
-
|
94
|
+
"memoryAllocated": (
|
95
|
+
raw_stats["inUseSystemMemory"]
|
96
|
+
/ raw_stats["allocatedSystemMemory"]
|
97
|
+
* 100
|
98
|
+
),
|
99
|
+
"powerWatts": raw_stats["systemPower"],
|
100
|
+
"powerPercent": (raw_stats["systemPower"] / self.MAX_POWER_WATTS) * 100,
|
101
|
+
"temp": temp / count if count > 0 else 0,
|
72
102
|
# TODO: this stat could be useful eventually, it was consistently
|
73
103
|
# 0 in my experimentation and requires a frontend change
|
74
104
|
# so leaving it out for now.
|
75
105
|
# "cpuWaitMs": raw_stats["cpu_wait_ms"],
|
76
106
|
}
|
77
|
-
|
78
107
|
self.samples.append(stats)
|
79
108
|
|
80
109
|
except (OSError, ValueError, TypeError, subprocess.CalledProcessError) as e:
|
81
|
-
logger.exception(
|
110
|
+
logger.exception("GPU stats error: %s", e)
|
82
111
|
|
83
112
|
def clear(self) -> None:
|
84
113
|
self.samples.clear()
|
@@ -116,6 +145,7 @@ class GPUApple:
|
|
116
145
|
telemetry_record = telemetry.TelemetryRecord()
|
117
146
|
telemetry_record.env.m1_gpu = True
|
118
147
|
interface._publish_telemetry(telemetry_record)
|
148
|
+
self.binary_path = get_apple_gpu_path()
|
119
149
|
|
120
150
|
@classmethod
|
121
151
|
def is_available(cls) -> bool:
|
@@ -128,5 +158,20 @@ class GPUApple:
|
|
128
158
|
self.metrics_monitor.finish()
|
129
159
|
|
130
160
|
def probe(self) -> dict:
|
131
|
-
|
132
|
-
|
161
|
+
try:
|
162
|
+
command = [str(self.binary_path), "--json"]
|
163
|
+
output = (
|
164
|
+
subprocess.check_output(command, universal_newlines=True)
|
165
|
+
.strip()
|
166
|
+
.split("\n")
|
167
|
+
)[0]
|
168
|
+
raw_stats = json.loads(output)
|
169
|
+
return {
|
170
|
+
self.name: {
|
171
|
+
"type": raw_stats["name"],
|
172
|
+
"vendor": raw_stats["vendor"],
|
173
|
+
}
|
174
|
+
}
|
175
|
+
except (OSError, ValueError, TypeError, subprocess.CalledProcessError) as e:
|
176
|
+
logger.exception("GPU stats error: %s", e)
|
177
|
+
return {self.name: {"type": "arm", "vendor": "Apple"}}
|
@@ -68,8 +68,7 @@ class Asset(Protocol):
|
|
68
68
|
metrics: List[Metric]
|
69
69
|
metrics_monitor: "MetricsMonitor"
|
70
70
|
|
71
|
-
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
72
|
-
... # pragma: no cover
|
71
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None: ... # pragma: no cover
|
73
72
|
|
74
73
|
@classmethod
|
75
74
|
def is_available(cls) -> bool:
|
@@ -90,14 +89,13 @@ class Asset(Protocol):
|
|
90
89
|
|
91
90
|
|
92
91
|
class Interface(Protocol):
|
93
|
-
def publish_stats(self, stats: dict) -> None:
|
94
|
-
... # pragma: no cover
|
92
|
+
def publish_stats(self, stats: dict) -> None: ... # pragma: no cover
|
95
93
|
|
96
|
-
def _publish_telemetry(
|
97
|
-
|
94
|
+
def _publish_telemetry(
|
95
|
+
self, telemetry: "TelemetryRecord"
|
96
|
+
) -> None: ... # pragma: no cover
|
98
97
|
|
99
|
-
def publish_files(self, files_dict: "FilesDict") -> None:
|
100
|
-
... # pragma: no cover
|
98
|
+
def publish_files(self, files_dict: "FilesDict") -> None: ... # pragma: no cover
|
101
99
|
|
102
100
|
|
103
101
|
class MetricsMonitor:
|
@@ -65,13 +65,13 @@ def _setup_requests_session() -> requests.Session:
|
|
65
65
|
|
66
66
|
|
67
67
|
def _nested_dict_to_tuple(
|
68
|
-
nested_dict: Mapping[str, Mapping[str, str]]
|
68
|
+
nested_dict: Mapping[str, Mapping[str, str]],
|
69
69
|
) -> Tuple[Tuple[str, Tuple[str, str]], ...]:
|
70
70
|
return tuple((k, *v.items()) for k, v in nested_dict.items()) # type: ignore
|
71
71
|
|
72
72
|
|
73
73
|
def _tuple_to_nested_dict(
|
74
|
-
nested_tuple: Tuple[Tuple[str, Tuple[str, str]], ...]
|
74
|
+
nested_tuple: Tuple[Tuple[str, Tuple[str, str]], ...],
|
75
75
|
) -> Dict[str, Dict[str, str]]:
|
76
76
|
return {k: dict(v) for k, *v in nested_tuple}
|
77
77
|
|
@@ -197,9 +197,7 @@ class NeuronCoreStats:
|
|
197
197
|
entry["report"]
|
198
198
|
for entry in raw_stats["neuron_runtime_data"]
|
199
199
|
if self._is_matching_entry(entry)
|
200
|
-
][
|
201
|
-
0
|
202
|
-
] # there should be only one entry with the pid
|
200
|
+
][0] # there should be only one entry with the pid
|
203
201
|
|
204
202
|
neuroncores_in_use = neuron_runtime_data["neuroncore_counters"][
|
205
203
|
"neuroncores_in_use"
|
wandb/sdk/launch/__init__.py
CHANGED
@@ -1,6 +1,14 @@
|
|
1
1
|
from ._launch import launch
|
2
2
|
from ._launch_add import launch_add
|
3
3
|
from .agent.agent import LaunchAgent
|
4
|
+
from .inputs.manage import manage_config_file, manage_wandb_config
|
4
5
|
from .utils import load_wandb_config
|
5
6
|
|
6
|
-
__all__ = [
|
7
|
+
__all__ = [
|
8
|
+
"LaunchAgent",
|
9
|
+
"launch",
|
10
|
+
"launch_add",
|
11
|
+
"load_wandb_config",
|
12
|
+
"manage_config_file",
|
13
|
+
"manage_wandb_config",
|
14
|
+
]
|
wandb/sdk/launch/_launch.py
CHANGED
@@ -12,13 +12,12 @@ from wandb.apis.internal import Api
|
|
12
12
|
from . import loader
|
13
13
|
from ._project_spec import LaunchProject
|
14
14
|
from .agent import LaunchAgent
|
15
|
-
from .
|
15
|
+
from .agent.agent import construct_agent_configs
|
16
16
|
from .environment.local_environment import LocalEnvironment
|
17
17
|
from .errors import ExecutionError, LaunchError
|
18
18
|
from .runner.abstract import AbstractRun
|
19
19
|
from .utils import (
|
20
20
|
LAUNCH_CONFIG_FILE,
|
21
|
-
LAUNCH_DEFAULT_PROJECT,
|
22
21
|
PROJECT_SYNCHRONOUS,
|
23
22
|
construct_launch_spec,
|
24
23
|
validate_launch_spec_source,
|
@@ -58,33 +57,32 @@ def set_launch_logfile(logfile: str) -> None:
|
|
58
57
|
|
59
58
|
def resolve_agent_config( # noqa: C901
|
60
59
|
entity: Optional[str],
|
61
|
-
project: Optional[str],
|
62
60
|
max_jobs: Optional[int],
|
63
61
|
queues: Optional[Tuple[str]],
|
64
62
|
config: Optional[str],
|
63
|
+
verbosity: Optional[int],
|
65
64
|
) -> Tuple[Dict[str, Any], Api]:
|
66
65
|
"""Resolve the agent config.
|
67
66
|
|
68
67
|
Arguments:
|
69
68
|
api (Api): The api.
|
70
69
|
entity (str): The entity.
|
71
|
-
project (str): The project.
|
72
70
|
max_jobs (int): The max number of jobs.
|
73
71
|
queues (Tuple[str]): The queues.
|
74
72
|
config (str): The config.
|
73
|
+
verbosity (int): How verbose to print, 0 or None = default, 1 = print status every 20 seconds, 2 = also print debugging information
|
75
74
|
|
76
75
|
Returns:
|
77
76
|
Tuple[Dict[str, Any], Api]: The resolved config and api.
|
78
77
|
"""
|
79
78
|
defaults = {
|
80
|
-
"project": LAUNCH_DEFAULT_PROJECT,
|
81
79
|
"max_jobs": 1,
|
82
80
|
"max_schedulers": 1,
|
83
81
|
"queues": [],
|
84
82
|
"registry": {},
|
85
83
|
"builder": {},
|
84
|
+
"verbosity": 0,
|
86
85
|
}
|
87
|
-
user_set_project = False
|
88
86
|
resolved_config: Dict[str, Any] = defaults
|
89
87
|
config_path = config or os.path.expanduser(LAUNCH_CONFIG_FILE)
|
90
88
|
if os.path.isfile(config_path):
|
@@ -97,16 +95,11 @@ def resolve_agent_config( # noqa: C901
|
|
97
95
|
launch_config = {} # type: ignore
|
98
96
|
except yaml.YAMLError as e:
|
99
97
|
raise LaunchError(f"Invalid launch agent config: {e}")
|
100
|
-
if launch_config.get("project") is not None:
|
101
|
-
user_set_project = True
|
102
98
|
resolved_config.update(launch_config.items())
|
103
99
|
elif config is not None:
|
104
100
|
raise LaunchError(
|
105
101
|
f"Could not find use specified launch config file: {config_path}"
|
106
102
|
)
|
107
|
-
if os.environ.get("WANDB_PROJECT") is not None:
|
108
|
-
resolved_config.update({"project": os.environ.get("WANDB_PROJECT")})
|
109
|
-
user_set_project = True
|
110
103
|
if os.environ.get("WANDB_ENTITY") is not None:
|
111
104
|
resolved_config.update({"entity": os.environ.get("WANDB_ENTITY")})
|
112
105
|
if os.environ.get("WANDB_LAUNCH_MAX_JOBS") is not None:
|
@@ -114,15 +107,14 @@ def resolve_agent_config( # noqa: C901
|
|
114
107
|
{"max_jobs": int(os.environ.get("WANDB_LAUNCH_MAX_JOBS", 1))}
|
115
108
|
)
|
116
109
|
|
117
|
-
if project is not None:
|
118
|
-
resolved_config.update({"project": project})
|
119
|
-
user_set_project = True
|
120
110
|
if entity is not None:
|
121
111
|
resolved_config.update({"entity": entity})
|
122
112
|
if max_jobs is not None:
|
123
113
|
resolved_config.update({"max_jobs": int(max_jobs)})
|
124
114
|
if queues:
|
125
115
|
resolved_config.update({"queues": list(queues)})
|
116
|
+
if verbosity:
|
117
|
+
resolved_config.update({"verbosity": int(verbosity)})
|
126
118
|
# queue -> queues
|
127
119
|
if resolved_config.get("queue"):
|
128
120
|
if isinstance(resolved_config.get("queue"), str):
|
@@ -133,7 +125,7 @@ def resolve_agent_config( # noqa: C901
|
|
133
125
|
+ " (expected str). Specify multiple queues with the 'queues' key"
|
134
126
|
)
|
135
127
|
|
136
|
-
keys = ["
|
128
|
+
keys = ["entity"]
|
137
129
|
settings = {
|
138
130
|
k: resolved_config.get(k) for k in keys if resolved_config.get(k) is not None
|
139
131
|
}
|
@@ -142,10 +134,6 @@ def resolve_agent_config( # noqa: C901
|
|
142
134
|
|
143
135
|
if resolved_config.get("entity") is None:
|
144
136
|
resolved_config.update({"entity": api.default_entity})
|
145
|
-
if user_set_project:
|
146
|
-
wandb.termwarn(
|
147
|
-
"Specifying a project for the launch agent is deprecated. Please use queues found in the Launch application at https://wandb.ai/launch."
|
148
|
-
)
|
149
137
|
|
150
138
|
return resolved_config, api
|
151
139
|
|
@@ -183,7 +171,6 @@ def create_and_run_agent(
|
|
183
171
|
|
184
172
|
async def _launch(
|
185
173
|
api: Api,
|
186
|
-
uri: Optional[str] = None,
|
187
174
|
job: Optional[str] = None,
|
188
175
|
name: Optional[str] = None,
|
189
176
|
project: Optional[str] = None,
|
@@ -204,7 +191,7 @@ async def _launch(
|
|
204
191
|
if resource is None:
|
205
192
|
resource = "local-container"
|
206
193
|
launch_spec = construct_launch_spec(
|
207
|
-
|
194
|
+
None,
|
208
195
|
job,
|
209
196
|
api,
|
210
197
|
name,
|
@@ -223,7 +210,7 @@ async def _launch(
|
|
223
210
|
validate_launch_spec_source(launch_spec)
|
224
211
|
launch_project = LaunchProject.from_spec(launch_spec, api)
|
225
212
|
launch_project.fetch_and_validate_project()
|
226
|
-
entrypoint = launch_project.
|
213
|
+
entrypoint = launch_project.get_job_entry_point()
|
227
214
|
image_uri = launch_project.docker_image # Either set by user or None.
|
228
215
|
|
229
216
|
# construct runner config.
|
@@ -321,8 +308,6 @@ def launch(
|
|
321
308
|
"""
|
322
309
|
submitted_run_obj = asyncio.run(
|
323
310
|
_launch(
|
324
|
-
# TODO: fully deprecate URI path
|
325
|
-
uri=None,
|
326
311
|
job=job,
|
327
312
|
name=name,
|
328
313
|
project=project,
|
wandb/sdk/launch/_launch_add.py
CHANGED
@@ -109,7 +109,6 @@ def launch_add(
|
|
109
109
|
|
110
110
|
return _launch_add(
|
111
111
|
api,
|
112
|
-
uri,
|
113
112
|
job,
|
114
113
|
config,
|
115
114
|
template_variables,
|
@@ -134,7 +133,6 @@ def launch_add(
|
|
134
133
|
|
135
134
|
def _launch_add(
|
136
135
|
api: Api,
|
137
|
-
uri: Optional[str],
|
138
136
|
job: Optional[str],
|
139
137
|
config: Optional[Dict[str, Any]],
|
140
138
|
template_variables: Optional[dict],
|
@@ -156,7 +154,7 @@ def _launch_add(
|
|
156
154
|
priority: Optional[int] = None,
|
157
155
|
) -> "public.QueuedRun":
|
158
156
|
launch_spec = construct_launch_spec(
|
159
|
-
|
157
|
+
None,
|
160
158
|
job,
|
161
159
|
api,
|
162
160
|
name,
|