wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -2
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/__init__.py +1 -4
- wandb/apis/importers/internals/internal.py +386 -0
- wandb/apis/importers/internals/protocols.py +125 -0
- wandb/apis/importers/internals/util.py +78 -0
- wandb/apis/importers/mlflow.py +125 -88
- wandb/apis/importers/validation.py +108 -0
- wandb/apis/importers/wandb.py +1604 -0
- wandb/apis/public/api.py +7 -10
- wandb/apis/public/artifacts.py +38 -0
- wandb/apis/public/files.py +11 -2
- wandb/apis/reports/v2/__init__.py +0 -19
- wandb/apis/reports/v2/expr_parsing.py +0 -1
- wandb/apis/reports/v2/interface.py +15 -18
- wandb/apis/reports/v2/internal.py +12 -45
- wandb/cli/cli.py +52 -55
- wandb/integration/gym/__init__.py +2 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +6 -4
- wandb/integration/kfp/kfp_patch.py +2 -2
- wandb/integration/openai/fine_tuning.py +1 -2
- wandb/integration/ultralytics/callback.py +0 -1
- wandb/proto/v3/wandb_internal_pb2.py +332 -312
- wandb/proto/v3/wandb_settings_pb2.py +13 -3
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +316 -312
- wandb/proto/v4/wandb_settings_pb2.py +5 -3
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/artifact.py +75 -31
- wandb/sdk/artifacts/artifact_manifest.py +5 -2
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
- wandb/sdk/artifacts/artifact_saver.py +19 -47
- wandb/sdk/artifacts/storage_handler.py +2 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
- wandb/sdk/artifacts/storage_policy.py +4 -1
- wandb/sdk/data_types/base_types/wb_value.py +1 -1
- wandb/sdk/data_types/image.py +2 -2
- wandb/sdk/interface/interface.py +49 -13
- wandb/sdk/interface/interface_shared.py +17 -11
- wandb/sdk/internal/file_stream.py +20 -1
- wandb/sdk/internal/handler.py +1 -4
- wandb/sdk/internal/internal_api.py +3 -1
- wandb/sdk/internal/job_builder.py +49 -19
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/sender.py +96 -124
- wandb/sdk/internal/sender_config.py +197 -0
- wandb/sdk/internal/settings_static.py +9 -0
- wandb/sdk/internal/system/system_info.py +5 -3
- wandb/sdk/internal/update.py +1 -1
- wandb/sdk/launch/_launch.py +3 -3
- wandb/sdk/launch/_launch_add.py +28 -29
- wandb/sdk/launch/_project_spec.py +148 -136
- wandb/sdk/launch/agent/agent.py +3 -7
- wandb/sdk/launch/agent/config.py +0 -27
- wandb/sdk/launch/builder/build.py +54 -28
- wandb/sdk/launch/builder/docker_builder.py +4 -15
- wandb/sdk/launch/builder/kaniko_builder.py +72 -45
- wandb/sdk/launch/create_job.py +6 -40
- wandb/sdk/launch/loader.py +10 -0
- wandb/sdk/launch/registry/anon.py +29 -0
- wandb/sdk/launch/registry/local_registry.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
- wandb/sdk/launch/runner/local_container.py +15 -10
- wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
- wandb/sdk/launch/sweeps/scheduler.py +11 -3
- wandb/sdk/launch/utils.py +14 -0
- wandb/sdk/lib/__init__.py +2 -5
- wandb/sdk/lib/_settings_toposort_generated.py +4 -1
- wandb/sdk/lib/apikey.py +0 -5
- wandb/sdk/lib/config_util.py +0 -31
- wandb/sdk/lib/filesystem.py +11 -1
- wandb/sdk/lib/run_moment.py +72 -0
- wandb/sdk/service/service.py +7 -2
- wandb/sdk/service/streams.py +1 -6
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +12 -1
- wandb/sdk/wandb_login.py +43 -26
- wandb/sdk/wandb_run.py +164 -110
- wandb/sdk/wandb_settings.py +58 -16
- wandb/testing/relay.py +5 -6
- wandb/util.py +50 -7
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
- wandb/apis/importers/base.py +0 -400
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
wandb/apis/public/api.py
CHANGED
@@ -88,10 +88,7 @@ class RetryingClient:
|
|
88
88
|
return self._server_info
|
89
89
|
|
90
90
|
def version_supported(self, min_version):
|
91
|
-
|
92
|
-
from packaging.version import Version as parse_version # noqa: N813
|
93
|
-
except ImportError:
|
94
|
-
from pkg_resources import parse_version
|
91
|
+
from wandb.util import parse_version
|
95
92
|
|
96
93
|
return parse_version(min_version) <= parse_version(
|
97
94
|
self.server_info["cliVersionInfo"]["max_cli_version"]
|
@@ -316,15 +313,15 @@ class Api:
|
|
316
313
|
entity: (str) Optional name of the entity to create the queue. If None, will use the configured or default entity.
|
317
314
|
prioritization_mode: (str) Optional version of prioritization to use. Either "V0" or None
|
318
315
|
config: (dict) Optional default resource configuration to be used for the queue. Use handlebars (eg. "{{var}}") to specify template variables.
|
319
|
-
template_variables (dict)
|
316
|
+
template_variables: (dict) A dictionary of template variable schemas to be used with the config. Expected format of:
|
320
317
|
{
|
321
318
|
"var-name": {
|
322
319
|
"schema": {
|
323
|
-
"type": "
|
324
|
-
"default":
|
325
|
-
"minimum":
|
326
|
-
"maximum":
|
327
|
-
"enum": [..."
|
320
|
+
"type": ("string", "number", or "integer"),
|
321
|
+
"default": (optional value),
|
322
|
+
"minimum": (optional minimum),
|
323
|
+
"maximum": (optional maximum),
|
324
|
+
"enum": [..."(options)"]
|
328
325
|
}
|
329
326
|
}
|
330
327
|
}
|
wandb/apis/public/artifacts.py
CHANGED
@@ -8,6 +8,7 @@ import wandb
|
|
8
8
|
from wandb.apis import public
|
9
9
|
from wandb.apis.normalize import normalize_exceptions
|
10
10
|
from wandb.apis.paginator import Paginator
|
11
|
+
from wandb.errors.term import termlog
|
11
12
|
|
12
13
|
if TYPE_CHECKING:
|
13
14
|
from wandb.apis.public import RetryingClient, Run
|
@@ -402,6 +403,43 @@ class ArtifactCollection:
|
|
402
403
|
self._attrs = response["project"]["artifactType"]["artifactCollection"]
|
403
404
|
return self._attrs
|
404
405
|
|
406
|
+
def change_type(self, new_type: str) -> None:
|
407
|
+
"""Change the type of the artifact collection.
|
408
|
+
|
409
|
+
Arguments:
|
410
|
+
new_type: The new collection type to use, freeform string.
|
411
|
+
"""
|
412
|
+
if not self.is_sequence():
|
413
|
+
raise ValueError("Artifact collection needs to be a sequence")
|
414
|
+
termlog(f"Changing artifact collection type of " f"{self.type} to {new_type}")
|
415
|
+
template = """
|
416
|
+
mutation MoveArtifactCollection(
|
417
|
+
$artifactSequenceID: ID!
|
418
|
+
$destinationArtifactTypeName: String!
|
419
|
+
) {
|
420
|
+
moveArtifactSequence(
|
421
|
+
input: {
|
422
|
+
artifactSequenceID: $artifactSequenceID
|
423
|
+
destinationArtifactTypeName: $destinationArtifactTypeName
|
424
|
+
}
|
425
|
+
) {
|
426
|
+
artifactCollection {
|
427
|
+
id
|
428
|
+
name
|
429
|
+
description
|
430
|
+
__typename
|
431
|
+
}
|
432
|
+
}
|
433
|
+
}
|
434
|
+
"""
|
435
|
+
variable_values = {
|
436
|
+
"artifactSequenceID": self.id,
|
437
|
+
"destinationArtifactTypeName": new_type,
|
438
|
+
}
|
439
|
+
mutation = gql(template)
|
440
|
+
self.client.execute(mutation, variable_values=variable_values)
|
441
|
+
self.type = new_type
|
442
|
+
|
405
443
|
@normalize_exceptions
|
406
444
|
def is_sequence(self) -> bool:
|
407
445
|
"""Return True if this is a sequence."""
|
wandb/apis/public/files.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Public API: files."""
|
2
2
|
import io
|
3
3
|
import os
|
4
|
+
from typing import Optional
|
4
5
|
|
5
6
|
import requests
|
6
7
|
from wandb_gql import gql
|
@@ -11,6 +12,7 @@ from wandb import util
|
|
11
12
|
from wandb.apis.attrs import Attrs
|
12
13
|
from wandb.apis.normalize import normalize_exceptions
|
13
14
|
from wandb.apis.paginator import Paginator
|
15
|
+
from wandb.apis.public.api import Api
|
14
16
|
from wandb.apis.public.const import RETRY_TIMEDELTA
|
15
17
|
from wandb.sdk.lib import retry
|
16
18
|
|
@@ -136,7 +138,11 @@ class File(Attrs):
|
|
136
138
|
retryable_exceptions=(RetryError, requests.RequestException),
|
137
139
|
)
|
138
140
|
def download(
|
139
|
-
self,
|
141
|
+
self,
|
142
|
+
root: str = ".",
|
143
|
+
replace: bool = False,
|
144
|
+
exist_ok: bool = False,
|
145
|
+
api: Optional[Api] = None,
|
140
146
|
) -> io.TextIOWrapper:
|
141
147
|
"""Downloads a file previously saved by a run from the wandb server.
|
142
148
|
|
@@ -150,6 +156,9 @@ class File(Attrs):
|
|
150
156
|
Raises:
|
151
157
|
`ValueError` if file already exists, replace=False and exist_ok=False.
|
152
158
|
"""
|
159
|
+
if api is None:
|
160
|
+
api = wandb.Api()
|
161
|
+
|
153
162
|
path = os.path.join(root, self.name)
|
154
163
|
if os.path.exists(path) and not replace:
|
155
164
|
if exist_ok:
|
@@ -159,7 +168,7 @@ class File(Attrs):
|
|
159
168
|
"File already exists, pass replace=True to overwrite or exist_ok=True to leave it as is and don't error."
|
160
169
|
)
|
161
170
|
|
162
|
-
util.download_file_from_url(path, self.url,
|
171
|
+
util.download_file_from_url(path, self.url, api.api_key)
|
163
172
|
return open(path)
|
164
173
|
|
165
174
|
@normalize_exceptions
|
@@ -18,22 +18,3 @@ from .interface import (
|
|
18
18
|
)
|
19
19
|
from .metrics import * # noqa
|
20
20
|
from .panels import * # noqa
|
21
|
-
|
22
|
-
|
23
|
-
def show_welcome_message():
|
24
|
-
if os.getenv("WANDB_REPORT_API_DISABLE_MESSAGE"):
|
25
|
-
return
|
26
|
-
|
27
|
-
termlog(
|
28
|
-
cleandoc(
|
29
|
-
"""
|
30
|
-
Thanks for trying out Report API v2!
|
31
|
-
See a tutorial and the changes here: http://wandb.me/report-api-quickstart
|
32
|
-
For bugs/feature requests, please create an issue on github: https://github.com/wandb/wandb/issues
|
33
|
-
You can disable this message by setting the env var WANDB_REPORT_API_DISABLE_MESSAGE=True
|
34
|
-
"""
|
35
|
-
)
|
36
|
-
)
|
37
|
-
|
38
|
-
|
39
|
-
show_welcome_message()
|
@@ -4,6 +4,8 @@ from datetime import datetime
|
|
4
4
|
from typing import Dict, Iterable, Optional, Tuple, Union
|
5
5
|
from typing import List as LList
|
6
6
|
|
7
|
+
from annotated_types import Annotated, Ge, Le
|
8
|
+
|
7
9
|
try:
|
8
10
|
from typing import Literal
|
9
11
|
except ImportError:
|
@@ -758,14 +760,8 @@ block_mapping = {
|
|
758
760
|
|
759
761
|
@dataclass(config=dataclass_config)
|
760
762
|
class GradientPoint(Base):
|
761
|
-
color: str
|
762
|
-
offset: float
|
763
|
-
|
764
|
-
@validator("color")
|
765
|
-
def validate_color(cls, v): # noqa: N805
|
766
|
-
if not internal.is_valid_color(v):
|
767
|
-
raise ValueError("invalid color, value should be hex, rgb, or rgba")
|
768
|
-
return v
|
763
|
+
color: Annotated[str, internal.ColorStrConstraints]
|
764
|
+
offset: Annotated[float, Ge(0), Le(100)] = 0
|
769
765
|
|
770
766
|
def to_model(self):
|
771
767
|
return internal.GradientPoint(color=self.color, offset=self.offset)
|
@@ -1419,8 +1415,7 @@ class Report(Base):
|
|
1419
1415
|
description: str = ""
|
1420
1416
|
blocks: LList[BlockTypes] = Field(default_factory=list)
|
1421
1417
|
|
1422
|
-
id: str = Field(default_factory=lambda: "", init=False, repr=False)
|
1423
|
-
|
1418
|
+
id: str = Field(default_factory=lambda: "", init=False, repr=False, kw_only=True)
|
1424
1419
|
_discussion_threads: list = Field(default_factory=list, init=False, repr=False)
|
1425
1420
|
_ref: dict = Field(default_factory=dict, init=False, repr=False)
|
1426
1421
|
_panel_settings: dict = Field(default_factory=dict, init=False, repr=False)
|
@@ -1470,20 +1465,22 @@ class Report(Base):
|
|
1470
1465
|
if blocks[-1] == internal.Paragraph():
|
1471
1466
|
blocks = blocks[:-1]
|
1472
1467
|
|
1473
|
-
|
1468
|
+
obj = cls(
|
1474
1469
|
title=model.display_name,
|
1475
1470
|
description=model.description,
|
1476
1471
|
entity=model.project.entity_name,
|
1477
1472
|
project=model.project.name,
|
1478
|
-
id=model.id,
|
1479
1473
|
blocks=[_lookup(b) for b in blocks],
|
1480
|
-
_discussion_threads=model.spec.discussion_threads,
|
1481
|
-
_panel_settings=model.spec.panel_settings,
|
1482
|
-
_ref=model.spec.ref,
|
1483
|
-
_authors=model.spec.authors,
|
1484
|
-
_created_at=model.created_at,
|
1485
|
-
_updated_at=model.updated_at,
|
1486
1474
|
)
|
1475
|
+
obj.id = model.id
|
1476
|
+
obj._discussion_threads = model.spec.discussion_threads
|
1477
|
+
obj._panel_settings = model.spec.panel_settings
|
1478
|
+
obj._ref = model.spec.ref
|
1479
|
+
obj._authors = model.spec.authors
|
1480
|
+
obj._created_at = model.created_at
|
1481
|
+
obj._updated_at = model.updated_at
|
1482
|
+
|
1483
|
+
return obj
|
1487
1484
|
|
1488
1485
|
@property
|
1489
1486
|
def url(self):
|
@@ -1,18 +1,19 @@
|
|
1
1
|
"""JSONSchema for internal types. Hopefully this is auto-generated one day!"""
|
2
2
|
import json
|
3
3
|
import random
|
4
|
-
import re
|
5
4
|
from copy import deepcopy
|
6
5
|
from datetime import datetime
|
7
6
|
from typing import Any, Dict, Optional, Tuple, Union
|
8
7
|
from typing import List as LList
|
9
8
|
|
9
|
+
from annotated_types import Annotated, Ge, Le
|
10
|
+
|
10
11
|
try:
|
11
12
|
from typing import Literal
|
12
13
|
except ImportError:
|
13
14
|
from typing_extensions import Literal
|
14
15
|
|
15
|
-
from pydantic import BaseModel, ConfigDict, Field, validator
|
16
|
+
from pydantic import BaseModel, ConfigDict, Field, StringConstraints, validator
|
16
17
|
from pydantic.alias_generators import to_camel
|
17
18
|
|
18
19
|
|
@@ -48,6 +49,13 @@ def _generate_name(length: int = 12) -> str:
|
|
48
49
|
return rand36.lower()[:length]
|
49
50
|
|
50
51
|
|
52
|
+
hex_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
|
53
|
+
rgb_pattern = r"^rgb\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*\)$"
|
54
|
+
rgba_pattern = r"^rgba\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(1|0|0?\.\d+)\s*\)$"
|
55
|
+
ColorStrConstraints = StringConstraints(
|
56
|
+
pattern=f"{hex_pattern}|{rgb_pattern}|{rgba_pattern}"
|
57
|
+
)
|
58
|
+
|
51
59
|
LinePlotStyle = Literal["line", "stacked-area", "pct-area"]
|
52
60
|
BarPlotStyle = Literal["bar", "boxplot", "violin"]
|
53
61
|
FontSize = Literal["small", "medium", "large", "auto"]
|
@@ -609,14 +617,8 @@ class LinePlot(Panel):
|
|
609
617
|
|
610
618
|
|
611
619
|
class GradientPoint(ReportAPIBaseModel):
|
612
|
-
color: str
|
613
|
-
offset: float
|
614
|
-
|
615
|
-
@validator("color")
|
616
|
-
def validate_color(cls, v): # noqa: N805
|
617
|
-
if not is_valid_color(v):
|
618
|
-
raise ValueError("invalid color, value should be hex, rgb, or rgba")
|
619
|
-
return v
|
620
|
+
color: Annotated[str, ColorStrConstraints]
|
621
|
+
offset: Annotated[float, Ge(0), Le(100)] = 0
|
620
622
|
|
621
623
|
|
622
624
|
class ScatterPlotConfig(ReportAPIBaseModel):
|
@@ -866,38 +868,3 @@ block_type_mapping = {
|
|
866
868
|
"table-of-contents": TableOfContents,
|
867
869
|
"block-quote": BlockQuote,
|
868
870
|
}
|
869
|
-
|
870
|
-
|
871
|
-
def is_valid_color(color_str: str) -> bool:
|
872
|
-
# Regular expression for hex color validation
|
873
|
-
hex_color_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
|
874
|
-
|
875
|
-
# Check if it's a valid hex color
|
876
|
-
if re.match(hex_color_pattern, color_str):
|
877
|
-
return True
|
878
|
-
|
879
|
-
# Try parsing it as an RGB or RGBA tuple
|
880
|
-
try:
|
881
|
-
# Strip 'rgb(' or 'rgba(' and the closing ')'
|
882
|
-
if color_str.startswith("rgb(") and color_str.endswith(")"):
|
883
|
-
parts = color_str[4:-1].split(",")
|
884
|
-
elif color_str.startswith("rgba(") and color_str.endswith(")"):
|
885
|
-
parts = color_str[5:-1].split(",")
|
886
|
-
else:
|
887
|
-
return False
|
888
|
-
|
889
|
-
# Convert parts to integers and validate ranges
|
890
|
-
parts = [int(p.strip()) for p in parts]
|
891
|
-
if len(parts) == 3 and all(0 <= p <= 255 for p in parts):
|
892
|
-
return True # Valid RGB
|
893
|
-
if (
|
894
|
-
len(parts) == 4
|
895
|
-
and all(0 <= p <= 255 for p in parts[:-1])
|
896
|
-
and 0 <= parts[-1] <= 1
|
897
|
-
):
|
898
|
-
return True # Valid RGBA
|
899
|
-
|
900
|
-
except ValueError:
|
901
|
-
pass
|
902
|
-
|
903
|
-
return False
|
wandb/cli/cli.py
CHANGED
@@ -433,10 +433,14 @@ def beta():
|
|
433
433
|
from wandb.util import get_core_path
|
434
434
|
|
435
435
|
if not get_core_path():
|
436
|
-
click.
|
437
|
-
|
436
|
+
click.secho(
|
437
|
+
(
|
438
|
+
"wandb beta commands require wandb-core, please install with"
|
439
|
+
" `pip install wandb-core`"
|
440
|
+
),
|
441
|
+
fg="red",
|
442
|
+
err=True,
|
438
443
|
)
|
439
|
-
sys.exit(1)
|
440
444
|
|
441
445
|
|
442
446
|
@beta.command(
|
@@ -1596,29 +1600,28 @@ def launch(
|
|
1596
1600
|
sys.exit(0)
|
1597
1601
|
else:
|
1598
1602
|
try:
|
1599
|
-
|
1600
|
-
|
1601
|
-
|
1602
|
-
|
1603
|
-
|
1604
|
-
|
1605
|
-
|
1606
|
-
|
1607
|
-
|
1608
|
-
|
1609
|
-
|
1610
|
-
|
1611
|
-
|
1612
|
-
|
1613
|
-
|
1614
|
-
|
1615
|
-
|
1616
|
-
|
1617
|
-
|
1618
|
-
|
1619
|
-
priority=priority,
|
1620
|
-
)
|
1603
|
+
_launch_add(
|
1604
|
+
api,
|
1605
|
+
uri,
|
1606
|
+
job,
|
1607
|
+
config,
|
1608
|
+
template_variables,
|
1609
|
+
project,
|
1610
|
+
entity,
|
1611
|
+
queue,
|
1612
|
+
resource,
|
1613
|
+
entry_point,
|
1614
|
+
name,
|
1615
|
+
git_version,
|
1616
|
+
docker_image,
|
1617
|
+
project_queue,
|
1618
|
+
resource_args,
|
1619
|
+
build=build,
|
1620
|
+
run_id=run_id,
|
1621
|
+
repository=repository,
|
1622
|
+
priority=priority,
|
1621
1623
|
)
|
1624
|
+
|
1622
1625
|
except Exception as e:
|
1623
1626
|
wandb._sentry.exception(e)
|
1624
1627
|
raise e
|
@@ -2342,8 +2345,29 @@ def artifact():
|
|
2342
2345
|
default=None,
|
2343
2346
|
help="Resume the last run from your current directory.",
|
2344
2347
|
)
|
2348
|
+
@click.option(
|
2349
|
+
"--skip_cache",
|
2350
|
+
is_flag=True,
|
2351
|
+
default=False,
|
2352
|
+
help="Skip caching while uploading artifact files.",
|
2353
|
+
)
|
2354
|
+
@click.option(
|
2355
|
+
"--policy",
|
2356
|
+
default="mutable",
|
2357
|
+
help="Set the storage policy while uploading artifact files.",
|
2358
|
+
)
|
2345
2359
|
@display_error
|
2346
|
-
def put(
|
2360
|
+
def put(
|
2361
|
+
path,
|
2362
|
+
name,
|
2363
|
+
description,
|
2364
|
+
type,
|
2365
|
+
alias,
|
2366
|
+
run_id,
|
2367
|
+
resume,
|
2368
|
+
skip_cache,
|
2369
|
+
policy,
|
2370
|
+
):
|
2347
2371
|
if name is None:
|
2348
2372
|
name = os.path.basename(path)
|
2349
2373
|
public_api = PublicApi()
|
@@ -2358,10 +2382,10 @@ def put(path, name, description, type, alias, run_id, resume):
|
|
2358
2382
|
artifact_path = f"{entity}/{project}/{artifact_name}:{alias[0]}"
|
2359
2383
|
if os.path.isdir(path):
|
2360
2384
|
wandb.termlog(f'Uploading directory {path} to: "{artifact_path}" ({type})')
|
2361
|
-
artifact.add_dir(path)
|
2385
|
+
artifact.add_dir(path, skip_cache=skip_cache, policy=policy)
|
2362
2386
|
elif os.path.isfile(path):
|
2363
2387
|
wandb.termlog(f'Uploading file {path} to: "{artifact_path}" ({type})')
|
2364
|
-
artifact.add_file(path)
|
2388
|
+
artifact.add_file(path, skip_cache=skip_cache, policy=policy)
|
2365
2389
|
elif "://" in path:
|
2366
2390
|
wandb.termlog(
|
2367
2391
|
f'Logging reference artifact from {path} to: "{artifact_path}" ({type})'
|
@@ -2861,30 +2885,3 @@ def verify(host):
|
|
2861
2885
|
and url_success
|
2862
2886
|
):
|
2863
2887
|
sys.exit(1)
|
2864
|
-
|
2865
|
-
|
2866
|
-
@cli.group("import", help="Commands for importing data from other systems")
|
2867
|
-
def importer():
|
2868
|
-
pass
|
2869
|
-
|
2870
|
-
|
2871
|
-
@importer.command("mlflow", help="Import from MLFlow")
|
2872
|
-
@click.option("--mlflow-tracking-uri", help="MLFlow Tracking URI")
|
2873
|
-
@click.option(
|
2874
|
-
"--target-entity", required=True, help="Override default entity to import data into"
|
2875
|
-
)
|
2876
|
-
@click.option(
|
2877
|
-
"--target-project",
|
2878
|
-
required=True,
|
2879
|
-
help="Override default project to import data into",
|
2880
|
-
)
|
2881
|
-
def mlflow(mlflow_tracking_uri, target_entity, target_project):
|
2882
|
-
from wandb.apis.importers import MlflowImporter
|
2883
|
-
|
2884
|
-
importer = MlflowImporter(mlflow_tracking_uri=mlflow_tracking_uri)
|
2885
|
-
overrides = {
|
2886
|
-
"entity": target_entity,
|
2887
|
-
"project": target_project,
|
2888
|
-
}
|
2889
|
-
|
2890
|
-
importer.import_all_parallel(overrides=overrides)
|
@@ -47,7 +47,8 @@ def monitor():
|
|
47
47
|
import gym
|
48
48
|
else:
|
49
49
|
import gymnasium as gym # type: ignore
|
50
|
-
|
50
|
+
|
51
|
+
from wandb.util import parse_version
|
51
52
|
|
52
53
|
if parse_version(gym.__version__) < parse_version("0.26.0"):
|
53
54
|
_gym_version_lt_0_26 = True
|
@@ -187,7 +187,7 @@ class WandbModelCheckpoint(callbacks.ModelCheckpoint):
|
|
187
187
|
@property
|
188
188
|
def is_old_tf_keras_version(self) -> Optional[bool]:
|
189
189
|
if self._is_old_tf_keras_version is None:
|
190
|
-
from
|
190
|
+
from wandb.util import parse_version
|
191
191
|
|
192
192
|
try:
|
193
193
|
if parse_version(tf.keras.__version__) < parse_version("2.6.0"):
|
wandb/integration/keras/keras.py
CHANGED
@@ -19,7 +19,8 @@ from wandb.util import add_import_hook
|
|
19
19
|
|
20
20
|
def _check_keras_version():
|
21
21
|
from keras import __version__ as keras_version
|
22
|
-
|
22
|
+
|
23
|
+
from wandb.util import parse_version
|
23
24
|
|
24
25
|
if parse_version(keras_version) < parse_version("2.4.0"):
|
25
26
|
wandb.termwarn(
|
@@ -29,7 +30,7 @@ def _check_keras_version():
|
|
29
30
|
|
30
31
|
def _can_compute_flops() -> bool:
|
31
32
|
"""FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1."""
|
32
|
-
from
|
33
|
+
from wandb.util import parse_version
|
33
34
|
|
34
35
|
if parse_version(tf.__version__) >= parse_version("2.0.0"):
|
35
36
|
return True
|
@@ -73,9 +74,10 @@ def is_generator_like(data):
|
|
73
74
|
|
74
75
|
|
75
76
|
def patch_tf_keras(): # noqa: C901
|
76
|
-
from pkg_resources import parse_version
|
77
77
|
from tensorflow.python.eager import context
|
78
78
|
|
79
|
+
from wandb.util import parse_version
|
80
|
+
|
79
81
|
if (
|
80
82
|
parse_version("2.6.0")
|
81
83
|
<= parse_version(tf.__version__)
|
@@ -235,7 +237,7 @@ patch_tf_keras()
|
|
235
237
|
|
236
238
|
|
237
239
|
def _get_custom_optimizer_parent_class():
|
238
|
-
from
|
240
|
+
from wandb.util import parse_version
|
239
241
|
|
240
242
|
if parse_version(tf.__version__) >= parse_version("2.9.0"):
|
241
243
|
custom_optimizer_parent_class = tf.keras.optimizers.legacy.Optimizer
|
@@ -3,8 +3,6 @@ import itertools
|
|
3
3
|
import textwrap
|
4
4
|
from typing import Callable, List, Mapping, Optional
|
5
5
|
|
6
|
-
from pkg_resources import parse_version
|
7
|
-
|
8
6
|
import wandb
|
9
7
|
|
10
8
|
try:
|
@@ -13,6 +11,8 @@ try:
|
|
13
11
|
from kfp.components._components import _create_task_factory_from_component_spec
|
14
12
|
from kfp.components._python_op import _func_to_component_spec
|
15
13
|
|
14
|
+
from wandb.util import parse_version
|
15
|
+
|
16
16
|
MIN_KFP_VERSION = "1.6.1"
|
17
17
|
|
18
18
|
if parse_version(kfp_version) < parse_version(MIN_KFP_VERSION):
|
@@ -5,13 +5,12 @@ import re
|
|
5
5
|
import time
|
6
6
|
from typing import Any, Dict, Optional, Tuple
|
7
7
|
|
8
|
-
from pkg_resources import parse_version
|
9
|
-
|
10
8
|
import wandb
|
11
9
|
from wandb import util
|
12
10
|
from wandb.data_types import Table
|
13
11
|
from wandb.sdk.lib import telemetry
|
14
12
|
from wandb.sdk.wandb_run import Run
|
13
|
+
from wandb.util import parse_version
|
15
14
|
|
16
15
|
openai = util.get_module(
|
17
16
|
name="openai",
|