wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -2
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/__init__.py +1 -4
- wandb/apis/importers/internals/internal.py +386 -0
- wandb/apis/importers/internals/protocols.py +125 -0
- wandb/apis/importers/internals/util.py +78 -0
- wandb/apis/importers/mlflow.py +125 -88
- wandb/apis/importers/validation.py +108 -0
- wandb/apis/importers/wandb.py +1604 -0
- wandb/apis/public/api.py +7 -10
- wandb/apis/public/artifacts.py +38 -0
- wandb/apis/public/files.py +11 -2
- wandb/apis/reports/v2/__init__.py +0 -19
- wandb/apis/reports/v2/expr_parsing.py +0 -1
- wandb/apis/reports/v2/interface.py +15 -18
- wandb/apis/reports/v2/internal.py +12 -45
- wandb/cli/cli.py +52 -55
- wandb/integration/gym/__init__.py +2 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +6 -4
- wandb/integration/kfp/kfp_patch.py +2 -2
- wandb/integration/openai/fine_tuning.py +1 -2
- wandb/integration/ultralytics/callback.py +0 -1
- wandb/proto/v3/wandb_internal_pb2.py +332 -312
- wandb/proto/v3/wandb_settings_pb2.py +13 -3
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +316 -312
- wandb/proto/v4/wandb_settings_pb2.py +5 -3
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/artifact.py +75 -31
- wandb/sdk/artifacts/artifact_manifest.py +5 -2
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
- wandb/sdk/artifacts/artifact_saver.py +19 -47
- wandb/sdk/artifacts/storage_handler.py +2 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
- wandb/sdk/artifacts/storage_policy.py +4 -1
- wandb/sdk/data_types/base_types/wb_value.py +1 -1
- wandb/sdk/data_types/image.py +2 -2
- wandb/sdk/interface/interface.py +49 -13
- wandb/sdk/interface/interface_shared.py +17 -11
- wandb/sdk/internal/file_stream.py +20 -1
- wandb/sdk/internal/handler.py +1 -4
- wandb/sdk/internal/internal_api.py +3 -1
- wandb/sdk/internal/job_builder.py +49 -19
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/sender.py +96 -124
- wandb/sdk/internal/sender_config.py +197 -0
- wandb/sdk/internal/settings_static.py +9 -0
- wandb/sdk/internal/system/system_info.py +5 -3
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/_launch.py +3 -3
- wandb/sdk/launch/_launch_add.py +28 -29
- wandb/sdk/launch/_project_spec.py +148 -136
- wandb/sdk/launch/agent/agent.py +3 -7
- wandb/sdk/launch/agent/config.py +0 -27
- wandb/sdk/launch/builder/build.py +54 -28
- wandb/sdk/launch/builder/docker_builder.py +4 -15
- wandb/sdk/launch/builder/kaniko_builder.py +72 -45
- wandb/sdk/launch/create_job.py +6 -40
- wandb/sdk/launch/loader.py +10 -0
- wandb/sdk/launch/registry/anon.py +29 -0
- wandb/sdk/launch/registry/local_registry.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
- wandb/sdk/launch/runner/local_container.py +15 -10
- wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
- wandb/sdk/launch/sweeps/scheduler.py +11 -3
- wandb/sdk/launch/utils.py +14 -0
- wandb/sdk/lib/__init__.py +2 -5
- wandb/sdk/lib/_settings_toposort_generated.py +4 -1
- wandb/sdk/lib/apikey.py +0 -5
- wandb/sdk/lib/config_util.py +0 -31
- wandb/sdk/lib/filesystem.py +11 -1
- wandb/sdk/lib/run_moment.py +72 -0
- wandb/sdk/service/service.py +7 -2
- wandb/sdk/service/streams.py +1 -6
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +12 -1
- wandb/sdk/wandb_login.py +43 -26
- wandb/sdk/wandb_run.py +164 -110
- wandb/sdk/wandb_settings.py +58 -16
- wandb/testing/relay.py +5 -6
- wandb/util.py +50 -7
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
- wandb/apis/importers/base.py +0 -400
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
wandb/apis/public/api.py
CHANGED
@@ -88,10 +88,7 @@ class RetryingClient:
|
|
88
88
|
return self._server_info
|
89
89
|
|
90
90
|
def version_supported(self, min_version):
|
91
|
-
|
92
|
-
from packaging.version import Version as parse_version # noqa: N813
|
93
|
-
except ImportError:
|
94
|
-
from pkg_resources import parse_version
|
91
|
+
from wandb.util import parse_version
|
95
92
|
|
96
93
|
return parse_version(min_version) <= parse_version(
|
97
94
|
self.server_info["cliVersionInfo"]["max_cli_version"]
|
@@ -316,15 +313,15 @@ class Api:
|
|
316
313
|
entity: (str) Optional name of the entity to create the queue. If None, will use the configured or default entity.
|
317
314
|
prioritization_mode: (str) Optional version of prioritization to use. Either "V0" or None
|
318
315
|
config: (dict) Optional default resource configuration to be used for the queue. Use handlebars (eg. "{{var}}") to specify template variables.
|
319
|
-
template_variables (dict)
|
316
|
+
template_variables: (dict) A dictionary of template variable schemas to be used with the config. Expected format of:
|
320
317
|
{
|
321
318
|
"var-name": {
|
322
319
|
"schema": {
|
323
|
-
"type": "
|
324
|
-
"default":
|
325
|
-
"minimum":
|
326
|
-
"maximum":
|
327
|
-
"enum": [..."
|
320
|
+
"type": ("string", "number", or "integer"),
|
321
|
+
"default": (optional value),
|
322
|
+
"minimum": (optional minimum),
|
323
|
+
"maximum": (optional maximum),
|
324
|
+
"enum": [..."(options)"]
|
328
325
|
}
|
329
326
|
}
|
330
327
|
}
|
wandb/apis/public/artifacts.py
CHANGED
@@ -8,6 +8,7 @@ import wandb
|
|
8
8
|
from wandb.apis import public
|
9
9
|
from wandb.apis.normalize import normalize_exceptions
|
10
10
|
from wandb.apis.paginator import Paginator
|
11
|
+
from wandb.errors.term import termlog
|
11
12
|
|
12
13
|
if TYPE_CHECKING:
|
13
14
|
from wandb.apis.public import RetryingClient, Run
|
@@ -402,6 +403,43 @@ class ArtifactCollection:
|
|
402
403
|
self._attrs = response["project"]["artifactType"]["artifactCollection"]
|
403
404
|
return self._attrs
|
404
405
|
|
406
|
+
def change_type(self, new_type: str) -> None:
|
407
|
+
"""Change the type of the artifact collection.
|
408
|
+
|
409
|
+
Arguments:
|
410
|
+
new_type: The new collection type to use, freeform string.
|
411
|
+
"""
|
412
|
+
if not self.is_sequence():
|
413
|
+
raise ValueError("Artifact collection needs to be a sequence")
|
414
|
+
termlog(f"Changing artifact collection type of " f"{self.type} to {new_type}")
|
415
|
+
template = """
|
416
|
+
mutation MoveArtifactCollection(
|
417
|
+
$artifactSequenceID: ID!
|
418
|
+
$destinationArtifactTypeName: String!
|
419
|
+
) {
|
420
|
+
moveArtifactSequence(
|
421
|
+
input: {
|
422
|
+
artifactSequenceID: $artifactSequenceID
|
423
|
+
destinationArtifactTypeName: $destinationArtifactTypeName
|
424
|
+
}
|
425
|
+
) {
|
426
|
+
artifactCollection {
|
427
|
+
id
|
428
|
+
name
|
429
|
+
description
|
430
|
+
__typename
|
431
|
+
}
|
432
|
+
}
|
433
|
+
}
|
434
|
+
"""
|
435
|
+
variable_values = {
|
436
|
+
"artifactSequenceID": self.id,
|
437
|
+
"destinationArtifactTypeName": new_type,
|
438
|
+
}
|
439
|
+
mutation = gql(template)
|
440
|
+
self.client.execute(mutation, variable_values=variable_values)
|
441
|
+
self.type = new_type
|
442
|
+
|
405
443
|
@normalize_exceptions
|
406
444
|
def is_sequence(self) -> bool:
|
407
445
|
"""Return True if this is a sequence."""
|
wandb/apis/public/files.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Public API: files."""
|
2
2
|
import io
|
3
3
|
import os
|
4
|
+
from typing import Optional
|
4
5
|
|
5
6
|
import requests
|
6
7
|
from wandb_gql import gql
|
@@ -11,6 +12,7 @@ from wandb import util
|
|
11
12
|
from wandb.apis.attrs import Attrs
|
12
13
|
from wandb.apis.normalize import normalize_exceptions
|
13
14
|
from wandb.apis.paginator import Paginator
|
15
|
+
from wandb.apis.public.api import Api
|
14
16
|
from wandb.apis.public.const import RETRY_TIMEDELTA
|
15
17
|
from wandb.sdk.lib import retry
|
16
18
|
|
@@ -136,7 +138,11 @@ class File(Attrs):
|
|
136
138
|
retryable_exceptions=(RetryError, requests.RequestException),
|
137
139
|
)
|
138
140
|
def download(
|
139
|
-
self,
|
141
|
+
self,
|
142
|
+
root: str = ".",
|
143
|
+
replace: bool = False,
|
144
|
+
exist_ok: bool = False,
|
145
|
+
api: Optional[Api] = None,
|
140
146
|
) -> io.TextIOWrapper:
|
141
147
|
"""Downloads a file previously saved by a run from the wandb server.
|
142
148
|
|
@@ -150,6 +156,9 @@ class File(Attrs):
|
|
150
156
|
Raises:
|
151
157
|
`ValueError` if file already exists, replace=False and exist_ok=False.
|
152
158
|
"""
|
159
|
+
if api is None:
|
160
|
+
api = wandb.Api()
|
161
|
+
|
153
162
|
path = os.path.join(root, self.name)
|
154
163
|
if os.path.exists(path) and not replace:
|
155
164
|
if exist_ok:
|
@@ -159,7 +168,7 @@ class File(Attrs):
|
|
159
168
|
"File already exists, pass replace=True to overwrite or exist_ok=True to leave it as is and don't error."
|
160
169
|
)
|
161
170
|
|
162
|
-
util.download_file_from_url(path, self.url,
|
171
|
+
util.download_file_from_url(path, self.url, api.api_key)
|
163
172
|
return open(path)
|
164
173
|
|
165
174
|
@normalize_exceptions
|
@@ -18,22 +18,3 @@ from .interface import (
|
|
18
18
|
)
|
19
19
|
from .metrics import * # noqa
|
20
20
|
from .panels import * # noqa
|
21
|
-
|
22
|
-
|
23
|
-
def show_welcome_message():
|
24
|
-
if os.getenv("WANDB_REPORT_API_DISABLE_MESSAGE"):
|
25
|
-
return
|
26
|
-
|
27
|
-
termlog(
|
28
|
-
cleandoc(
|
29
|
-
"""
|
30
|
-
Thanks for trying out Report API v2!
|
31
|
-
See a tutorial and the changes here: http://wandb.me/report-api-quickstart
|
32
|
-
For bugs/feature requests, please create an issue on github: https://github.com/wandb/wandb/issues
|
33
|
-
You can disable this message by setting the env var WANDB_REPORT_API_DISABLE_MESSAGE=True
|
34
|
-
"""
|
35
|
-
)
|
36
|
-
)
|
37
|
-
|
38
|
-
|
39
|
-
show_welcome_message()
|
@@ -4,6 +4,8 @@ from datetime import datetime
|
|
4
4
|
from typing import Dict, Iterable, Optional, Tuple, Union
|
5
5
|
from typing import List as LList
|
6
6
|
|
7
|
+
from annotated_types import Annotated, Ge, Le
|
8
|
+
|
7
9
|
try:
|
8
10
|
from typing import Literal
|
9
11
|
except ImportError:
|
@@ -758,14 +760,8 @@ block_mapping = {
|
|
758
760
|
|
759
761
|
@dataclass(config=dataclass_config)
|
760
762
|
class GradientPoint(Base):
|
761
|
-
color: str
|
762
|
-
offset: float
|
763
|
-
|
764
|
-
@validator("color")
|
765
|
-
def validate_color(cls, v): # noqa: N805
|
766
|
-
if not internal.is_valid_color(v):
|
767
|
-
raise ValueError("invalid color, value should be hex, rgb, or rgba")
|
768
|
-
return v
|
763
|
+
color: Annotated[str, internal.ColorStrConstraints]
|
764
|
+
offset: Annotated[float, Ge(0), Le(100)] = 0
|
769
765
|
|
770
766
|
def to_model(self):
|
771
767
|
return internal.GradientPoint(color=self.color, offset=self.offset)
|
@@ -1419,8 +1415,7 @@ class Report(Base):
|
|
1419
1415
|
description: str = ""
|
1420
1416
|
blocks: LList[BlockTypes] = Field(default_factory=list)
|
1421
1417
|
|
1422
|
-
id: str = Field(default_factory=lambda: "", init=False, repr=False)
|
1423
|
-
|
1418
|
+
id: str = Field(default_factory=lambda: "", init=False, repr=False, kw_only=True)
|
1424
1419
|
_discussion_threads: list = Field(default_factory=list, init=False, repr=False)
|
1425
1420
|
_ref: dict = Field(default_factory=dict, init=False, repr=False)
|
1426
1421
|
_panel_settings: dict = Field(default_factory=dict, init=False, repr=False)
|
@@ -1470,20 +1465,22 @@ class Report(Base):
|
|
1470
1465
|
if blocks[-1] == internal.Paragraph():
|
1471
1466
|
blocks = blocks[:-1]
|
1472
1467
|
|
1473
|
-
|
1468
|
+
obj = cls(
|
1474
1469
|
title=model.display_name,
|
1475
1470
|
description=model.description,
|
1476
1471
|
entity=model.project.entity_name,
|
1477
1472
|
project=model.project.name,
|
1478
|
-
id=model.id,
|
1479
1473
|
blocks=[_lookup(b) for b in blocks],
|
1480
|
-
_discussion_threads=model.spec.discussion_threads,
|
1481
|
-
_panel_settings=model.spec.panel_settings,
|
1482
|
-
_ref=model.spec.ref,
|
1483
|
-
_authors=model.spec.authors,
|
1484
|
-
_created_at=model.created_at,
|
1485
|
-
_updated_at=model.updated_at,
|
1486
1474
|
)
|
1475
|
+
obj.id = model.id
|
1476
|
+
obj._discussion_threads = model.spec.discussion_threads
|
1477
|
+
obj._panel_settings = model.spec.panel_settings
|
1478
|
+
obj._ref = model.spec.ref
|
1479
|
+
obj._authors = model.spec.authors
|
1480
|
+
obj._created_at = model.created_at
|
1481
|
+
obj._updated_at = model.updated_at
|
1482
|
+
|
1483
|
+
return obj
|
1487
1484
|
|
1488
1485
|
@property
|
1489
1486
|
def url(self):
|
@@ -1,18 +1,19 @@
|
|
1
1
|
"""JSONSchema for internal types. Hopefully this is auto-generated one day!"""
|
2
2
|
import json
|
3
3
|
import random
|
4
|
-
import re
|
5
4
|
from copy import deepcopy
|
6
5
|
from datetime import datetime
|
7
6
|
from typing import Any, Dict, Optional, Tuple, Union
|
8
7
|
from typing import List as LList
|
9
8
|
|
9
|
+
from annotated_types import Annotated, Ge, Le
|
10
|
+
|
10
11
|
try:
|
11
12
|
from typing import Literal
|
12
13
|
except ImportError:
|
13
14
|
from typing_extensions import Literal
|
14
15
|
|
15
|
-
from pydantic import BaseModel, ConfigDict, Field, validator
|
16
|
+
from pydantic import BaseModel, ConfigDict, Field, StringConstraints, validator
|
16
17
|
from pydantic.alias_generators import to_camel
|
17
18
|
|
18
19
|
|
@@ -48,6 +49,13 @@ def _generate_name(length: int = 12) -> str:
|
|
48
49
|
return rand36.lower()[:length]
|
49
50
|
|
50
51
|
|
52
|
+
hex_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
|
53
|
+
rgb_pattern = r"^rgb\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*\)$"
|
54
|
+
rgba_pattern = r"^rgba\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(1|0|0?\.\d+)\s*\)$"
|
55
|
+
ColorStrConstraints = StringConstraints(
|
56
|
+
pattern=f"{hex_pattern}|{rgb_pattern}|{rgba_pattern}"
|
57
|
+
)
|
58
|
+
|
51
59
|
LinePlotStyle = Literal["line", "stacked-area", "pct-area"]
|
52
60
|
BarPlotStyle = Literal["bar", "boxplot", "violin"]
|
53
61
|
FontSize = Literal["small", "medium", "large", "auto"]
|
@@ -609,14 +617,8 @@ class LinePlot(Panel):
|
|
609
617
|
|
610
618
|
|
611
619
|
class GradientPoint(ReportAPIBaseModel):
|
612
|
-
color: str
|
613
|
-
offset: float
|
614
|
-
|
615
|
-
@validator("color")
|
616
|
-
def validate_color(cls, v): # noqa: N805
|
617
|
-
if not is_valid_color(v):
|
618
|
-
raise ValueError("invalid color, value should be hex, rgb, or rgba")
|
619
|
-
return v
|
620
|
+
color: Annotated[str, ColorStrConstraints]
|
621
|
+
offset: Annotated[float, Ge(0), Le(100)] = 0
|
620
622
|
|
621
623
|
|
622
624
|
class ScatterPlotConfig(ReportAPIBaseModel):
|
@@ -866,38 +868,3 @@ block_type_mapping = {
|
|
866
868
|
"table-of-contents": TableOfContents,
|
867
869
|
"block-quote": BlockQuote,
|
868
870
|
}
|
869
|
-
|
870
|
-
|
871
|
-
def is_valid_color(color_str: str) -> bool:
|
872
|
-
# Regular expression for hex color validation
|
873
|
-
hex_color_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
|
874
|
-
|
875
|
-
# Check if it's a valid hex color
|
876
|
-
if re.match(hex_color_pattern, color_str):
|
877
|
-
return True
|
878
|
-
|
879
|
-
# Try parsing it as an RGB or RGBA tuple
|
880
|
-
try:
|
881
|
-
# Strip 'rgb(' or 'rgba(' and the closing ')'
|
882
|
-
if color_str.startswith("rgb(") and color_str.endswith(")"):
|
883
|
-
parts = color_str[4:-1].split(",")
|
884
|
-
elif color_str.startswith("rgba(") and color_str.endswith(")"):
|
885
|
-
parts = color_str[5:-1].split(",")
|
886
|
-
else:
|
887
|
-
return False
|
888
|
-
|
889
|
-
# Convert parts to integers and validate ranges
|
890
|
-
parts = [int(p.strip()) for p in parts]
|
891
|
-
if len(parts) == 3 and all(0 <= p <= 255 for p in parts):
|
892
|
-
return True # Valid RGB
|
893
|
-
if (
|
894
|
-
len(parts) == 4
|
895
|
-
and all(0 <= p <= 255 for p in parts[:-1])
|
896
|
-
and 0 <= parts[-1] <= 1
|
897
|
-
):
|
898
|
-
return True # Valid RGBA
|
899
|
-
|
900
|
-
except ValueError:
|
901
|
-
pass
|
902
|
-
|
903
|
-
return False
|
wandb/cli/cli.py
CHANGED
@@ -433,10 +433,14 @@ def beta():
|
|
433
433
|
from wandb.util import get_core_path
|
434
434
|
|
435
435
|
if not get_core_path():
|
436
|
-
click.
|
437
|
-
|
436
|
+
click.secho(
|
437
|
+
(
|
438
|
+
"wandb beta commands require wandb-core, please install with"
|
439
|
+
" `pip install wandb-core`"
|
440
|
+
),
|
441
|
+
fg="red",
|
442
|
+
err=True,
|
438
443
|
)
|
439
|
-
sys.exit(1)
|
440
444
|
|
441
445
|
|
442
446
|
@beta.command(
|
@@ -1596,29 +1600,28 @@ def launch(
|
|
1596
1600
|
sys.exit(0)
|
1597
1601
|
else:
|
1598
1602
|
try:
|
1599
|
-
|
1600
|
-
|
1601
|
-
|
1602
|
-
|
1603
|
-
|
1604
|
-
|
1605
|
-
|
1606
|
-
|
1607
|
-
|
1608
|
-
|
1609
|
-
|
1610
|
-
|
1611
|
-
|
1612
|
-
|
1613
|
-
|
1614
|
-
|
1615
|
-
|
1616
|
-
|
1617
|
-
|
1618
|
-
|
1619
|
-
priority=priority,
|
1620
|
-
)
|
1603
|
+
_launch_add(
|
1604
|
+
api,
|
1605
|
+
uri,
|
1606
|
+
job,
|
1607
|
+
config,
|
1608
|
+
template_variables,
|
1609
|
+
project,
|
1610
|
+
entity,
|
1611
|
+
queue,
|
1612
|
+
resource,
|
1613
|
+
entry_point,
|
1614
|
+
name,
|
1615
|
+
git_version,
|
1616
|
+
docker_image,
|
1617
|
+
project_queue,
|
1618
|
+
resource_args,
|
1619
|
+
build=build,
|
1620
|
+
run_id=run_id,
|
1621
|
+
repository=repository,
|
1622
|
+
priority=priority,
|
1621
1623
|
)
|
1624
|
+
|
1622
1625
|
except Exception as e:
|
1623
1626
|
wandb._sentry.exception(e)
|
1624
1627
|
raise e
|
@@ -2342,8 +2345,29 @@ def artifact():
|
|
2342
2345
|
default=None,
|
2343
2346
|
help="Resume the last run from your current directory.",
|
2344
2347
|
)
|
2348
|
+
@click.option(
|
2349
|
+
"--skip_cache",
|
2350
|
+
is_flag=True,
|
2351
|
+
default=False,
|
2352
|
+
help="Skip caching while uploading artifact files.",
|
2353
|
+
)
|
2354
|
+
@click.option(
|
2355
|
+
"--policy",
|
2356
|
+
default="mutable",
|
2357
|
+
help="Set the storage policy while uploading artifact files.",
|
2358
|
+
)
|
2345
2359
|
@display_error
|
2346
|
-
def put(
|
2360
|
+
def put(
|
2361
|
+
path,
|
2362
|
+
name,
|
2363
|
+
description,
|
2364
|
+
type,
|
2365
|
+
alias,
|
2366
|
+
run_id,
|
2367
|
+
resume,
|
2368
|
+
skip_cache,
|
2369
|
+
policy,
|
2370
|
+
):
|
2347
2371
|
if name is None:
|
2348
2372
|
name = os.path.basename(path)
|
2349
2373
|
public_api = PublicApi()
|
@@ -2358,10 +2382,10 @@ def put(path, name, description, type, alias, run_id, resume):
|
|
2358
2382
|
artifact_path = f"{entity}/{project}/{artifact_name}:{alias[0]}"
|
2359
2383
|
if os.path.isdir(path):
|
2360
2384
|
wandb.termlog(f'Uploading directory {path} to: "{artifact_path}" ({type})')
|
2361
|
-
artifact.add_dir(path)
|
2385
|
+
artifact.add_dir(path, skip_cache=skip_cache, policy=policy)
|
2362
2386
|
elif os.path.isfile(path):
|
2363
2387
|
wandb.termlog(f'Uploading file {path} to: "{artifact_path}" ({type})')
|
2364
|
-
artifact.add_file(path)
|
2388
|
+
artifact.add_file(path, skip_cache=skip_cache, policy=policy)
|
2365
2389
|
elif "://" in path:
|
2366
2390
|
wandb.termlog(
|
2367
2391
|
f'Logging reference artifact from {path} to: "{artifact_path}" ({type})'
|
@@ -2861,30 +2885,3 @@ def verify(host):
|
|
2861
2885
|
and url_success
|
2862
2886
|
):
|
2863
2887
|
sys.exit(1)
|
2864
|
-
|
2865
|
-
|
2866
|
-
@cli.group("import", help="Commands for importing data from other systems")
|
2867
|
-
def importer():
|
2868
|
-
pass
|
2869
|
-
|
2870
|
-
|
2871
|
-
@importer.command("mlflow", help="Import from MLFlow")
|
2872
|
-
@click.option("--mlflow-tracking-uri", help="MLFlow Tracking URI")
|
2873
|
-
@click.option(
|
2874
|
-
"--target-entity", required=True, help="Override default entity to import data into"
|
2875
|
-
)
|
2876
|
-
@click.option(
|
2877
|
-
"--target-project",
|
2878
|
-
required=True,
|
2879
|
-
help="Override default project to import data into",
|
2880
|
-
)
|
2881
|
-
def mlflow(mlflow_tracking_uri, target_entity, target_project):
|
2882
|
-
from wandb.apis.importers import MlflowImporter
|
2883
|
-
|
2884
|
-
importer = MlflowImporter(mlflow_tracking_uri=mlflow_tracking_uri)
|
2885
|
-
overrides = {
|
2886
|
-
"entity": target_entity,
|
2887
|
-
"project": target_project,
|
2888
|
-
}
|
2889
|
-
|
2890
|
-
importer.import_all_parallel(overrides=overrides)
|
@@ -47,7 +47,8 @@ def monitor():
|
|
47
47
|
import gym
|
48
48
|
else:
|
49
49
|
import gymnasium as gym # type: ignore
|
50
|
-
|
50
|
+
|
51
|
+
from wandb.util import parse_version
|
51
52
|
|
52
53
|
if parse_version(gym.__version__) < parse_version("0.26.0"):
|
53
54
|
_gym_version_lt_0_26 = True
|
@@ -187,7 +187,7 @@ class WandbModelCheckpoint(callbacks.ModelCheckpoint):
|
|
187
187
|
@property
|
188
188
|
def is_old_tf_keras_version(self) -> Optional[bool]:
|
189
189
|
if self._is_old_tf_keras_version is None:
|
190
|
-
from
|
190
|
+
from wandb.util import parse_version
|
191
191
|
|
192
192
|
try:
|
193
193
|
if parse_version(tf.keras.__version__) < parse_version("2.6.0"):
|
wandb/integration/keras/keras.py
CHANGED
@@ -19,7 +19,8 @@ from wandb.util import add_import_hook
|
|
19
19
|
|
20
20
|
def _check_keras_version():
|
21
21
|
from keras import __version__ as keras_version
|
22
|
-
|
22
|
+
|
23
|
+
from wandb.util import parse_version
|
23
24
|
|
24
25
|
if parse_version(keras_version) < parse_version("2.4.0"):
|
25
26
|
wandb.termwarn(
|
@@ -29,7 +30,7 @@ def _check_keras_version():
|
|
29
30
|
|
30
31
|
def _can_compute_flops() -> bool:
|
31
32
|
"""FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1."""
|
32
|
-
from
|
33
|
+
from wandb.util import parse_version
|
33
34
|
|
34
35
|
if parse_version(tf.__version__) >= parse_version("2.0.0"):
|
35
36
|
return True
|
@@ -73,9 +74,10 @@ def is_generator_like(data):
|
|
73
74
|
|
74
75
|
|
75
76
|
def patch_tf_keras(): # noqa: C901
|
76
|
-
from pkg_resources import parse_version
|
77
77
|
from tensorflow.python.eager import context
|
78
78
|
|
79
|
+
from wandb.util import parse_version
|
80
|
+
|
79
81
|
if (
|
80
82
|
parse_version("2.6.0")
|
81
83
|
<= parse_version(tf.__version__)
|
@@ -235,7 +237,7 @@ patch_tf_keras()
|
|
235
237
|
|
236
238
|
|
237
239
|
def _get_custom_optimizer_parent_class():
|
238
|
-
from
|
240
|
+
from wandb.util import parse_version
|
239
241
|
|
240
242
|
if parse_version(tf.__version__) >= parse_version("2.9.0"):
|
241
243
|
custom_optimizer_parent_class = tf.keras.optimizers.legacy.Optimizer
|
@@ -3,8 +3,6 @@ import itertools
|
|
3
3
|
import textwrap
|
4
4
|
from typing import Callable, List, Mapping, Optional
|
5
5
|
|
6
|
-
from pkg_resources import parse_version
|
7
|
-
|
8
6
|
import wandb
|
9
7
|
|
10
8
|
try:
|
@@ -13,6 +11,8 @@ try:
|
|
13
11
|
from kfp.components._components import _create_task_factory_from_component_spec
|
14
12
|
from kfp.components._python_op import _func_to_component_spec
|
15
13
|
|
14
|
+
from wandb.util import parse_version
|
15
|
+
|
16
16
|
MIN_KFP_VERSION = "1.6.1"
|
17
17
|
|
18
18
|
if parse_version(kfp_version) < parse_version(MIN_KFP_VERSION):
|
@@ -5,13 +5,12 @@ import re
|
|
5
5
|
import time
|
6
6
|
from typing import Any, Dict, Optional, Tuple
|
7
7
|
|
8
|
-
from pkg_resources import parse_version
|
9
|
-
|
10
8
|
import wandb
|
11
9
|
from wandb import util
|
12
10
|
from wandb.data_types import Table
|
13
11
|
from wandb.sdk.lib import telemetry
|
14
12
|
from wandb.sdk.wandb_run import Run
|
13
|
+
from wandb.util import parse_version
|
15
14
|
|
16
15
|
openai = util.get_module(
|
17
16
|
name="openai",
|