wandb 0.16.5__py3-none-any.whl → 0.17.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -3
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +55 -3
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +17 -4
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +15 -17
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +95 -22
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +4 -106
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +77 -40
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +7 -1
- wandb/proto/wandb_internal_codegen.py +3 -29
- wandb/sdk/artifacts/artifact.py +51 -20
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +18 -27
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
- wandb/sdk/artifacts/storage_policy.py +2 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +86 -38
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +5 -18
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_api.py +1 -129
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +159 -45
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +9 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +188 -241
- wandb/sdk/launch/agent/agent.py +115 -48
- wandb/sdk/launch/agent/config.py +80 -14
- wandb/sdk/launch/builder/abstract.py +69 -1
- wandb/sdk/launch/builder/build.py +156 -555
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +8 -23
- wandb/sdk/launch/builder/kaniko_builder.py +161 -159
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +68 -63
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +217 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +7 -4
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +33 -140
- wandb/sdk/lib/_settings_toposort_generated.py +1 -5
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +23 -2
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/run_moment.py +7 -1
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +19 -16
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +16 -63
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +164 -90
- wandb/sdk/wandb_settings.py +2 -48
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +11 -12
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +67 -54
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/RECORD +178 -188
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -18
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- wandb-0.16.5.dist-info/top_level.txt +0 -1
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
wandb/sdk/launch/utils.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
# heavily inspired by https://github.com/mlflow/mlflow/blob/master/mlflow/projects/utils.py
|
2
1
|
import asyncio
|
3
2
|
import json
|
4
3
|
import logging
|
@@ -16,7 +15,6 @@ import wandb
|
|
16
15
|
import wandb.docker as docker
|
17
16
|
from wandb import util
|
18
17
|
from wandb.apis.internal import Api
|
19
|
-
from wandb.errors import CommError
|
20
18
|
from wandb.sdk.launch.errors import LaunchError
|
21
19
|
from wandb.sdk.launch.git_reference import GitReference
|
22
20
|
from wandb.sdk.launch.wandb_reference import WandbReference
|
@@ -32,7 +30,6 @@ FAILED_PACKAGES_REGEX = re.compile(
|
|
32
30
|
)
|
33
31
|
|
34
32
|
if TYPE_CHECKING: # pragma: no cover
|
35
|
-
from wandb.sdk.artifacts.artifact import Artifact
|
36
33
|
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
|
37
34
|
|
38
35
|
|
@@ -57,15 +54,15 @@ API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"
|
|
57
54
|
MACRO_REGEX = re.compile(r"\$\{(\w+)\}")
|
58
55
|
|
59
56
|
AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
|
60
|
-
r"(?:https://)?([\w]+)\.azurecr\.io/([\w\-]+):?(
|
57
|
+
r"^(?:https://)?([\w]+)\.azurecr\.io/(?P<repository>[\w\-]+):?(?P<tag>.*)"
|
61
58
|
)
|
62
59
|
|
63
60
|
ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
|
64
|
-
r"^(?P<account
|
61
|
+
r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\w-]+):?(?P<tag>.*)$"
|
65
62
|
)
|
66
63
|
|
67
64
|
GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
|
68
|
-
r"^(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)
|
65
|
+
r"^(?:https://)?(?P<region>[\w-]+)-docker\.pkg\.dev/(?P<project>[\w-]+)/(?P<repository>[\w-]+)/?(?P<image_name>[\w-]+)?(?P<tag>:.*)?$",
|
69
66
|
re.IGNORECASE,
|
70
67
|
)
|
71
68
|
|
@@ -316,16 +313,13 @@ def construct_launch_spec(
|
|
316
313
|
|
317
314
|
|
318
315
|
def validate_launch_spec_source(launch_spec: Dict[str, Any]) -> None:
|
319
|
-
uri = launch_spec.get("uri")
|
320
316
|
job = launch_spec.get("job")
|
321
317
|
docker_image = launch_spec.get("docker", {}).get("docker_image")
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
elif sum(map(bool, [uri, job, docker_image])) > 1:
|
328
|
-
raise LaunchError("Must specify exactly one of uri, job or image")
|
318
|
+
if bool(job) == bool(docker_image):
|
319
|
+
raise LaunchError(
|
320
|
+
"Exactly one of job or docker_image must be specified in the launch "
|
321
|
+
"spec."
|
322
|
+
)
|
329
323
|
|
330
324
|
|
331
325
|
def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
|
@@ -336,77 +330,6 @@ def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
|
|
336
330
|
return (ref.entity, ref.project, ref.run_id)
|
337
331
|
|
338
332
|
|
339
|
-
def is_bare_wandb_uri(uri: str) -> bool:
|
340
|
-
"""Check that a wandb uri is valid.
|
341
|
-
|
342
|
-
URI must be in the format
|
343
|
-
`/<entity>/<project>/runs/<run_name>[other stuff]`
|
344
|
-
or
|
345
|
-
`/<entity>/<project>/artifacts/job/<job_name>[other stuff]`.
|
346
|
-
"""
|
347
|
-
_logger.info(f"Checking if uri {uri} is bare...")
|
348
|
-
return uri.startswith("/") and WandbReference.is_uri_job_or_run(uri)
|
349
|
-
|
350
|
-
|
351
|
-
def fetch_wandb_project_run_info(
|
352
|
-
entity: str, project: str, run_name: str, api: Api
|
353
|
-
) -> Any:
|
354
|
-
_logger.info("Fetching run info...")
|
355
|
-
try:
|
356
|
-
result = api.get_run_info(entity, project, run_name)
|
357
|
-
except CommError:
|
358
|
-
result = None
|
359
|
-
if result is None:
|
360
|
-
raise LaunchError(
|
361
|
-
f"Run info is invalid or doesn't exist for {api.settings('base_url')}/{entity}/{project}/runs/{run_name}"
|
362
|
-
)
|
363
|
-
if result.get("codePath") is None:
|
364
|
-
# TODO: we don't currently expose codePath in the runInfo endpoint, this downloads
|
365
|
-
# it from wandb-metadata.json if we can.
|
366
|
-
metadata = api.download_url(
|
367
|
-
project, "wandb-metadata.json", run=run_name, entity=entity
|
368
|
-
)
|
369
|
-
if metadata is not None:
|
370
|
-
_, response = api.download_file(metadata["url"])
|
371
|
-
data = response.json()
|
372
|
-
result["codePath"] = data.get("codePath")
|
373
|
-
result["cudaVersion"] = data.get("cuda", None)
|
374
|
-
|
375
|
-
return result
|
376
|
-
|
377
|
-
|
378
|
-
def download_entry_point(
|
379
|
-
entity: str, project: str, run_name: str, api: Api, entry_point: str, dir: str
|
380
|
-
) -> bool:
|
381
|
-
metadata = api.download_url(
|
382
|
-
project, f"code/{entry_point}", run=run_name, entity=entity
|
383
|
-
)
|
384
|
-
if metadata is not None:
|
385
|
-
_, response = api.download_file(metadata["url"])
|
386
|
-
with util.fsync_open(os.path.join(dir, entry_point), "wb") as file:
|
387
|
-
for data in response.iter_content(chunk_size=1024):
|
388
|
-
file.write(data)
|
389
|
-
return True
|
390
|
-
return False
|
391
|
-
|
392
|
-
|
393
|
-
def download_wandb_python_deps(
|
394
|
-
entity: str, project: str, run_name: str, api: Api, dir: str
|
395
|
-
) -> Optional[str]:
|
396
|
-
reqs = api.download_url(project, "requirements.txt", run=run_name, entity=entity)
|
397
|
-
if reqs is not None:
|
398
|
-
_logger.info("Downloading python dependencies")
|
399
|
-
_, response = api.download_file(reqs["url"])
|
400
|
-
|
401
|
-
with util.fsync_open(
|
402
|
-
os.path.join(dir, "requirements.frozen.txt"), "wb"
|
403
|
-
) as file:
|
404
|
-
for data in response.iter_content(chunk_size=1024):
|
405
|
-
file.write(data)
|
406
|
-
return "requirements.frozen.txt"
|
407
|
-
return None
|
408
|
-
|
409
|
-
|
410
333
|
def get_local_python_deps(
|
411
334
|
dir: str, filename: str = "requirements.local.txt"
|
412
335
|
) -> Optional[str]:
|
@@ -498,19 +421,6 @@ def validate_wandb_python_deps(
|
|
498
421
|
_logger.warning("Unable to validate local python dependencies")
|
499
422
|
|
500
423
|
|
501
|
-
def fetch_project_diff(
|
502
|
-
entity: str, project: str, run_name: str, api: Api
|
503
|
-
) -> Optional[str]:
|
504
|
-
"""Fetches project diff from wandb servers."""
|
505
|
-
_logger.info("Searching for diff.patch")
|
506
|
-
patch = None
|
507
|
-
try:
|
508
|
-
(_, _, patch, _) = api.run_config(project, run_name, entity)
|
509
|
-
except CommError:
|
510
|
-
pass
|
511
|
-
return patch
|
512
|
-
|
513
|
-
|
514
424
|
def apply_patch(patch_string: str, dst_dir: str) -> None:
|
515
425
|
"""Applies a patch file to a directory."""
|
516
426
|
_logger.info("Applying diff.patch")
|
@@ -531,17 +441,6 @@ def apply_patch(patch_string: str, dst_dir: str) -> None:
|
|
531
441
|
raise wandb.Error("Failed to apply diff.patch associated with run.")
|
532
442
|
|
533
443
|
|
534
|
-
def _make_refspec_from_version(version: Optional[str]) -> List[str]:
|
535
|
-
"""Create a refspec that checks for the existence of origin/main and the version."""
|
536
|
-
if version:
|
537
|
-
return [f"+{version}"]
|
538
|
-
|
539
|
-
return [
|
540
|
-
"+refs/heads/main*:refs/remotes/origin/main*",
|
541
|
-
"+refs/heads/master*:refs/remotes/origin/master*",
|
542
|
-
]
|
543
|
-
|
544
|
-
|
545
444
|
def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[str]:
|
546
445
|
"""Clones the git repo at ``uri`` into ``dst_dir``.
|
547
446
|
|
@@ -561,13 +460,6 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> Optional[
|
|
561
460
|
return version
|
562
461
|
|
563
462
|
|
564
|
-
def merge_parameters(
|
565
|
-
higher_priority_params: Dict[str, Any], lower_priority_params: Dict[str, Any]
|
566
|
-
) -> Dict[str, Any]:
|
567
|
-
"""Merge the contents of two dicts, keeping values from higher_priority_params if there are conflicts."""
|
568
|
-
return {**lower_priority_params, **higher_priority_params}
|
569
|
-
|
570
|
-
|
571
463
|
def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
|
572
464
|
nbconvert = wandb.util.get_module(
|
573
465
|
"nbconvert", "nbformat and nbconvert are required to use launch with notebooks"
|
@@ -597,25 +489,6 @@ def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
|
|
597
489
|
return new_name
|
598
490
|
|
599
491
|
|
600
|
-
def check_and_download_code_artifacts(
|
601
|
-
entity: str, project: str, run_name: str, internal_api: Api, project_dir: str
|
602
|
-
) -> Optional["Artifact"]:
|
603
|
-
_logger.info("Checking for code artifacts")
|
604
|
-
public_api = wandb.PublicApi(
|
605
|
-
overrides={"base_url": internal_api.settings("base_url")}
|
606
|
-
)
|
607
|
-
|
608
|
-
run = public_api.run(f"{entity}/{project}/{run_name}")
|
609
|
-
run_artifacts = run.logged_artifacts()
|
610
|
-
|
611
|
-
for artifact in run_artifacts:
|
612
|
-
if hasattr(artifact, "type") and artifact.type == "code":
|
613
|
-
artifact.download(project_dir)
|
614
|
-
return artifact # type: ignore
|
615
|
-
|
616
|
-
return None
|
617
|
-
|
618
|
-
|
619
492
|
def to_camel_case(maybe_snake_str: str) -> str:
|
620
493
|
if "_" not in maybe_snake_str:
|
621
494
|
return maybe_snake_str
|
@@ -623,11 +496,6 @@ def to_camel_case(maybe_snake_str: str) -> str:
|
|
623
496
|
return "".join(x.title() if x else "_" for x in components)
|
624
497
|
|
625
498
|
|
626
|
-
def run_shell(args: List[str]) -> Tuple[str, str]:
|
627
|
-
out = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
628
|
-
return out.stdout.decode("utf-8").strip(), out.stderr.decode("utf-8").strip()
|
629
|
-
|
630
|
-
|
631
499
|
def validate_build_and_registry_configs(
|
632
500
|
build_config: Dict[str, Any], registry_config: Dict[str, Any]
|
633
501
|
) -> None:
|
@@ -846,3 +714,28 @@ def fetch_and_validate_template_variables(
|
|
846
714
|
raise LaunchError(f"Value for {key} must be of type {field_type}.")
|
847
715
|
template_variables[key] = val
|
848
716
|
return template_variables
|
717
|
+
|
718
|
+
|
719
|
+
def get_entrypoint_file(entrypoint: List[str]) -> Optional[str]:
|
720
|
+
"""Get the entrypoint file from the given command.
|
721
|
+
|
722
|
+
Args:
|
723
|
+
entrypoint (List[str]): List of command and arguments.
|
724
|
+
|
725
|
+
Returns:
|
726
|
+
Optional[str]: The entrypoint file if found, otherwise None.
|
727
|
+
"""
|
728
|
+
if not entrypoint:
|
729
|
+
return None
|
730
|
+
if entrypoint[0].endswith(".py") or entrypoint[0].endswith(".sh"):
|
731
|
+
return entrypoint[0]
|
732
|
+
if len(entrypoint) < 2:
|
733
|
+
return None
|
734
|
+
return entrypoint[1]
|
735
|
+
|
736
|
+
|
737
|
+
def get_current_python_version() -> Tuple[str, str]:
|
738
|
+
full_version = sys.version.split()[0].split(".")
|
739
|
+
major = full_version[0]
|
740
|
+
version = ".".join(full_version[:2]) if len(full_version) >= 2 else major + ".0"
|
741
|
+
return version, major
|
@@ -13,7 +13,6 @@ else:
|
|
13
13
|
_Setting = Literal[
|
14
14
|
"_args",
|
15
15
|
"_aws_lambda",
|
16
|
-
"_async_upload_concurrency_limit",
|
17
16
|
"_cli_only_mode",
|
18
17
|
"_code_path_local",
|
19
18
|
"_colab",
|
@@ -22,9 +21,9 @@ _Setting = Literal[
|
|
22
21
|
"_disable_service",
|
23
22
|
"_disable_setproctitle",
|
24
23
|
"_disable_stats",
|
24
|
+
"_disable_update_check",
|
25
25
|
"_disable_viewer",
|
26
26
|
"_disable_machine_info",
|
27
|
-
"_except_exit",
|
28
27
|
"_executable",
|
29
28
|
"_extra_http_headers",
|
30
29
|
"_file_stream_retry_max",
|
@@ -125,7 +124,6 @@ _Setting = Literal[
|
|
125
124
|
"login_timeout",
|
126
125
|
"mode",
|
127
126
|
"notebook_name",
|
128
|
-
"problem",
|
129
127
|
"program",
|
130
128
|
"program_abspath",
|
131
129
|
"program_relpath",
|
@@ -178,7 +176,6 @@ _Setting = Literal[
|
|
178
176
|
]
|
179
177
|
|
180
178
|
SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
181
|
-
"_async_upload_concurrency_limit",
|
182
179
|
"_service_wait",
|
183
180
|
"_stats_sample_rate_seconds",
|
184
181
|
"_stats_samples_to_average",
|
@@ -188,7 +185,6 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
188
185
|
"console",
|
189
186
|
"job_source",
|
190
187
|
"mode",
|
191
|
-
"problem",
|
192
188
|
"project",
|
193
189
|
"run_id",
|
194
190
|
"start_method",
|
wandb/sdk/lib/fsm.py
CHANGED
@@ -52,43 +52,39 @@ T_FsmContext_contra = TypeVar("T_FsmContext_contra", contravariant=True)
|
|
52
52
|
@runtime_checkable
|
53
53
|
class FsmStateCheck(Protocol[T_FsmInputs]):
|
54
54
|
@abstractmethod
|
55
|
-
def on_check(self, inputs: T_FsmInputs) -> None:
|
56
|
-
... # pragma: no cover
|
55
|
+
def on_check(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
|
57
56
|
|
58
57
|
|
59
58
|
@runtime_checkable
|
60
59
|
class FsmStateOutput(Protocol[T_FsmInputs]):
|
61
60
|
@abstractmethod
|
62
|
-
def on_state(self, inputs: T_FsmInputs) -> None:
|
63
|
-
... # pragma: no cover
|
61
|
+
def on_state(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
|
64
62
|
|
65
63
|
|
66
64
|
@runtime_checkable
|
67
65
|
class FsmStateEnter(Protocol[T_FsmInputs]):
|
68
66
|
@abstractmethod
|
69
|
-
def on_enter(self, inputs: T_FsmInputs) -> None:
|
70
|
-
... # pragma: no cover
|
67
|
+
def on_enter(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
|
71
68
|
|
72
69
|
|
73
70
|
@runtime_checkable
|
74
71
|
class FsmStateEnterWithContext(Protocol[T_FsmInputs, T_FsmContext_contra]):
|
75
72
|
@abstractmethod
|
76
|
-
def on_enter(
|
77
|
-
|
73
|
+
def on_enter(
|
74
|
+
self, inputs: T_FsmInputs, context: T_FsmContext_contra
|
75
|
+
) -> None: ... # pragma: no cover
|
78
76
|
|
79
77
|
|
80
78
|
@runtime_checkable
|
81
79
|
class FsmStateStay(Protocol[T_FsmInputs]):
|
82
80
|
@abstractmethod
|
83
|
-
def on_stay(self, inputs: T_FsmInputs) -> None:
|
84
|
-
... # pragma: no cover
|
81
|
+
def on_stay(self, inputs: T_FsmInputs) -> None: ... # pragma: no cover
|
85
82
|
|
86
83
|
|
87
84
|
@runtime_checkable
|
88
85
|
class FsmStateExit(Protocol[T_FsmInputs, T_FsmContext_cov]):
|
89
86
|
@abstractmethod
|
90
|
-
def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov:
|
91
|
-
... # pragma: no cover
|
87
|
+
def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov: ... # pragma: no cover
|
92
88
|
|
93
89
|
|
94
90
|
# It would be nice if python provided optional protocol members, but it doesnt as described here:
|
wandb/sdk/lib/gitlib.py
CHANGED
@@ -14,7 +14,7 @@ try:
|
|
14
14
|
Repo,
|
15
15
|
)
|
16
16
|
except ImportError:
|
17
|
-
Repo = None
|
17
|
+
Repo = None # type: ignore
|
18
18
|
|
19
19
|
if TYPE_CHECKING:
|
20
20
|
from git import Repo
|
@@ -121,7 +121,7 @@ class GitRepo:
|
|
121
121
|
# TODO: Saw a user getting a Unicode decode error when parsing refs,
|
122
122
|
# more details on implementing a real fix in [WB-4064]
|
123
123
|
try:
|
124
|
-
if len(self.repo.refs) > 0:
|
124
|
+
if len(self.repo.refs) > 0: # type: ignore[arg-type]
|
125
125
|
return self.repo.head.commit.hexsha
|
126
126
|
else:
|
127
127
|
return self.repo.git.show_ref("--head").split(" ")[0]
|
@@ -140,7 +140,7 @@ class GitRepo:
|
|
140
140
|
if not self.repo:
|
141
141
|
return None
|
142
142
|
try:
|
143
|
-
return self.repo.remotes[self.remote_name]
|
143
|
+
return self.repo.remotes[self.remote_name] # type: ignore[index]
|
144
144
|
except IndexError:
|
145
145
|
return None
|
146
146
|
|
@@ -200,7 +200,7 @@ class GitRepo:
|
|
200
200
|
possible_relatives.append(tracking_branch.commit)
|
201
201
|
|
202
202
|
if not possible_relatives:
|
203
|
-
for branch in self.repo.branches:
|
203
|
+
for branch in self.repo.branches: # type: ignore[attr-defined]
|
204
204
|
tracking_branch = branch.tracking_branch()
|
205
205
|
if tracking_branch is not None:
|
206
206
|
possible_relatives.append(tracking_branch.commit)
|
wandb/sdk/lib/import_hooks.py
CHANGED
@@ -143,7 +143,7 @@ class _ImportHookChainedLoader:
|
|
143
143
|
# None, so handle None as well. The module may not support attribute
|
144
144
|
# assignment, in which case we simply skip it. Note that we also deal
|
145
145
|
# with __loader__ not existing at all. This is to future proof things
|
146
|
-
# due to proposal to remove the
|
146
|
+
# due to proposal to remove the attribute as described in the GitHub
|
147
147
|
# issue at https://github.com/python/cpython/issues/77458. Also prior
|
148
148
|
# to Python 3.3, the __loader__ attribute was only set if a custom
|
149
149
|
# module loader was used. It isn't clear whether the attribute still
|
wandb/sdk/lib/lazyloader.py
CHANGED
wandb/sdk/lib/proto_util.py
CHANGED
@@ -12,7 +12,28 @@ if TYPE_CHECKING: # pragma: no cover
|
|
12
12
|
|
13
13
|
|
14
14
|
def dict_from_proto_list(obj_list: "RepeatedCompositeFieldContainer") -> Dict[str, Any]:
|
15
|
-
|
15
|
+
result: Dict[str, Any] = {}
|
16
|
+
|
17
|
+
for item in obj_list:
|
18
|
+
# Start from the root of the result dict
|
19
|
+
current_level = result
|
20
|
+
|
21
|
+
if len(item.nested_key) > 0:
|
22
|
+
keys = list(item.nested_key)
|
23
|
+
else:
|
24
|
+
keys = [item.key]
|
25
|
+
|
26
|
+
for key in keys[:-1]:
|
27
|
+
if key not in current_level:
|
28
|
+
current_level[key] = {}
|
29
|
+
# Move the reference deeper into the nested dictionary
|
30
|
+
current_level = current_level[key]
|
31
|
+
|
32
|
+
# Set the value at the final key location, parsing JSON from the value_json field
|
33
|
+
final_key = keys[-1]
|
34
|
+
current_level[final_key] = json.loads(item.value_json)
|
35
|
+
|
36
|
+
return result
|
16
37
|
|
17
38
|
|
18
39
|
def _result_from_record(record: "pb.Record") -> "pb.Result":
|
@@ -29,7 +50,7 @@ def _assign_end_offset(record: "pb.Record", end_offset: int) -> None:
|
|
29
50
|
|
30
51
|
|
31
52
|
def proto_encode_to_dict(
|
32
|
-
pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"]
|
53
|
+
pb_obj: Union["tpb.TelemetryRecord", "pb.MetricRecord"],
|
33
54
|
) -> Dict[int, Any]:
|
34
55
|
data: Dict[int, Any] = dict()
|
35
56
|
fields = pb_obj.ListFields()
|
wandb/sdk/lib/redirect.py
CHANGED
@@ -224,7 +224,7 @@ class TerminalEmulator:
|
|
224
224
|
def carriage_return(self):
|
225
225
|
self.cursor.x = 0
|
226
226
|
|
227
|
-
def
|
227
|
+
def cursor_position(self, line, column):
|
228
228
|
self.cursor.x = min(column, 1) - 1
|
229
229
|
self.cursor.y = min(line, 1) - 1
|
230
230
|
|
@@ -393,25 +393,30 @@ class TerminalEmulator:
|
|
393
393
|
p = (int(p[0]), 1)
|
394
394
|
else:
|
395
395
|
p = (1, 1)
|
396
|
-
self.
|
396
|
+
self.cursor_position(*p)
|
397
397
|
except Exception:
|
398
398
|
pass
|
399
399
|
|
400
400
|
def _get_line(self, n):
|
401
401
|
line = self.buffer[n]
|
402
402
|
line_len = self._get_line_len(n)
|
403
|
-
# We have to loop through each character in the line and check if foreground,
|
404
|
-
# other attributes (italics, bold, underline, etc) of the ith
|
405
|
-
# (i-1)th character. If different, the
|
406
|
-
#
|
407
|
-
#
|
408
|
-
#
|
409
|
-
|
410
|
-
#
|
411
|
-
#
|
412
|
-
|
413
|
-
#
|
414
|
-
#
|
403
|
+
# We have to loop through each character in the line and check if foreground,
|
404
|
+
# background and other attributes (italics, bold, underline, etc) of the ith
|
405
|
+
# character are different from those of the (i-1)th character. If different, the
|
406
|
+
# appropriate ascii character for switching the color/attribute should be
|
407
|
+
# appended to the output string before appending the actual character. This loop
|
408
|
+
# and subsequent checks can be expensive, especially because 99% of terminal
|
409
|
+
# output use default colors and formatting. Even in outputs that do contain
|
410
|
+
# colors and styles, its unlikely that they will change on a per character
|
411
|
+
# basis.
|
412
|
+
|
413
|
+
# So instead we create a character list without any ascii codes (`out`), and a
|
414
|
+
# list of all the foregrounds in the line (`fgs`) on which we call np.diff() and
|
415
|
+
# np.where() to find the indices where the foreground change, and insert the
|
416
|
+
# ascii characters in the output list (`out`) on those indices. All of this is
|
417
|
+
# the done only if there are more than 1 foreground color in the line in the
|
418
|
+
# first place (`if len(set(fgs)) > 1 else None`). Same logic is repeated for
|
419
|
+
# background colors and other attributes.
|
415
420
|
|
416
421
|
out = [line[i].data for i in range(line_len)]
|
417
422
|
|
wandb/sdk/lib/retry.py
CHANGED
@@ -248,8 +248,9 @@ class ExponentialBackoff(Backoff):
|
|
248
248
|
if self._timeout_at is not None and NOW_FN() > self._timeout_at:
|
249
249
|
raise exc
|
250
250
|
|
251
|
-
result, self._next_sleep =
|
252
|
-
self.
|
251
|
+
result, self._next_sleep = (
|
252
|
+
self._next_sleep,
|
253
|
+
min(self._max_sleep, self._next_sleep * (1 + random.random())),
|
253
254
|
)
|
254
255
|
|
255
256
|
return result
|
wandb/sdk/lib/run_moment.py
CHANGED
@@ -1,7 +1,13 @@
|
|
1
|
+
import sys
|
1
2
|
from dataclasses import dataclass
|
2
|
-
from typing import
|
3
|
+
from typing import Union, cast
|
3
4
|
from urllib import parse
|
4
5
|
|
6
|
+
if sys.version_info >= (3, 8):
|
7
|
+
from typing import Literal
|
8
|
+
else:
|
9
|
+
from typing_extensions import Literal
|
10
|
+
|
5
11
|
_STEP = Literal["_step"]
|
6
12
|
|
7
13
|
|
wandb/sdk/lib/tracelog.py
CHANGED
@@ -6,7 +6,7 @@ Functions:
|
|
6
6
|
log_message_send - message sent to socket
|
7
7
|
log_message_recv - message received from socket
|
8
8
|
log_message_process - message processed by thread
|
9
|
-
log_message_link - message linked to another
|
9
|
+
log_message_link - message linked to another message
|
10
10
|
log_message_assert - message encountered problem
|
11
11
|
|
12
12
|
"""
|
wandb/sdk/service/service.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
Backend server process can be connected to using tcp sockets transport.
|
4
4
|
"""
|
5
|
+
|
5
6
|
import datetime
|
6
7
|
import os
|
7
8
|
import pathlib
|
@@ -14,8 +15,8 @@ import time
|
|
14
15
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
15
16
|
|
16
17
|
from wandb import _sentry, termlog
|
17
|
-
from wandb.env import
|
18
|
-
from wandb.errors import Error
|
18
|
+
from wandb.env import core_debug, core_error_reporting_enabled, is_require_core
|
19
|
+
from wandb.errors import Error, WandbCoreNotAvailableError
|
19
20
|
from wandb.sdk.lib.wburls import wburls
|
20
21
|
from wandb.util import get_core_path, get_module
|
21
22
|
|
@@ -109,7 +110,8 @@ class _Service:
|
|
109
110
|
f"The wandb service process exited with {proc.returncode}. "
|
110
111
|
"Ensure that `sys.executable` is a valid python interpreter. "
|
111
112
|
"You can override it with the `_executable` setting "
|
112
|
-
"or with the `WANDB__EXECUTABLE` environment variable."
|
113
|
+
"or with the `WANDB__EXECUTABLE` environment variable."
|
114
|
+
f"\n{context}",
|
113
115
|
context=context,
|
114
116
|
)
|
115
117
|
if not os.path.isfile(fname):
|
@@ -161,28 +163,29 @@ class _Service:
|
|
161
163
|
exec_cmd_list += ["coverage", "run", "-m"]
|
162
164
|
|
163
165
|
service_args = []
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
166
|
+
|
167
|
+
if is_require_core():
|
168
|
+
try:
|
169
|
+
core_path = get_core_path()
|
170
|
+
except WandbCoreNotAvailableError as e:
|
171
|
+
_sentry.reraise(e)
|
172
|
+
|
171
173
|
service_args.extend([core_path])
|
172
|
-
|
174
|
+
|
175
|
+
if not core_error_reporting_enabled(default="True"):
|
173
176
|
service_args.append("--no-observability")
|
174
|
-
|
177
|
+
|
178
|
+
if core_debug(default="False"):
|
175
179
|
service_args.append("--debug")
|
180
|
+
|
176
181
|
trace_filename = os.environ.get("_WANDB_TRACE")
|
177
182
|
if trace_filename is not None:
|
178
183
|
service_args.extend(["--trace", trace_filename])
|
179
184
|
|
180
185
|
exec_cmd_list = []
|
181
|
-
# TODO: remove this after the wandb-core GA release
|
182
|
-
wandb_core = get_module("wandb_core")
|
183
186
|
termlog(
|
184
|
-
|
185
|
-
f"Please refer to {wburls.get('wandb_core')} for more information.",
|
187
|
+
"Using wandb-core as the SDK backend."
|
188
|
+
f" Please refer to {wburls.get('wandb_core')} for more information.",
|
186
189
|
repeat=False,
|
187
190
|
)
|
188
191
|
else:
|
wandb/sdk/verify/verify.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
"""Utilities for wandb verify."""
|
2
|
+
|
2
3
|
import getpass
|
3
4
|
import os
|
4
5
|
import time
|
@@ -20,7 +21,7 @@ PROJECT_NAME = "verify"
|
|
20
21
|
GET_RUN_MAX_TIME = 10
|
21
22
|
MIN_RETRYS = 3
|
22
23
|
CHECKMARK = "\u2705"
|
23
|
-
RED_X = "\
|
24
|
+
RED_X = "\u274c"
|
24
25
|
ID_PREFIX = runid.generate_id()
|
25
26
|
|
26
27
|
|