wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (90) hide show
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/importers/__init__.py +1 -4
  4. wandb/apis/importers/internals/internal.py +386 -0
  5. wandb/apis/importers/internals/protocols.py +125 -0
  6. wandb/apis/importers/internals/util.py +78 -0
  7. wandb/apis/importers/mlflow.py +125 -88
  8. wandb/apis/importers/validation.py +108 -0
  9. wandb/apis/importers/wandb.py +1604 -0
  10. wandb/apis/public/api.py +7 -10
  11. wandb/apis/public/artifacts.py +38 -0
  12. wandb/apis/public/files.py +11 -2
  13. wandb/apis/reports/v2/__init__.py +0 -19
  14. wandb/apis/reports/v2/expr_parsing.py +0 -1
  15. wandb/apis/reports/v2/interface.py +15 -18
  16. wandb/apis/reports/v2/internal.py +12 -45
  17. wandb/cli/cli.py +52 -55
  18. wandb/integration/gym/__init__.py +2 -1
  19. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  20. wandb/integration/keras/keras.py +6 -4
  21. wandb/integration/kfp/kfp_patch.py +2 -2
  22. wandb/integration/openai/fine_tuning.py +1 -2
  23. wandb/integration/ultralytics/callback.py +0 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  25. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  26. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  27. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  28. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  29. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  30. wandb/sdk/artifacts/artifact.py +75 -31
  31. wandb/sdk/artifacts/artifact_manifest.py +5 -2
  32. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  33. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
  34. wandb/sdk/artifacts/artifact_saver.py +19 -47
  35. wandb/sdk/artifacts/storage_handler.py +2 -1
  36. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
  37. wandb/sdk/artifacts/storage_policy.py +4 -1
  38. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  39. wandb/sdk/data_types/image.py +2 -2
  40. wandb/sdk/interface/interface.py +49 -13
  41. wandb/sdk/interface/interface_shared.py +17 -11
  42. wandb/sdk/internal/file_stream.py +20 -1
  43. wandb/sdk/internal/handler.py +1 -4
  44. wandb/sdk/internal/internal_api.py +3 -1
  45. wandb/sdk/internal/job_builder.py +49 -19
  46. wandb/sdk/internal/profiler.py +1 -1
  47. wandb/sdk/internal/sender.py +96 -124
  48. wandb/sdk/internal/sender_config.py +197 -0
  49. wandb/sdk/internal/settings_static.py +9 -0
  50. wandb/sdk/internal/system/system_info.py +5 -3
  51. wandb/sdk/internal/update.py +1 -1
  52. wandb/sdk/launch/_launch.py +3 -3
  53. wandb/sdk/launch/_launch_add.py +28 -29
  54. wandb/sdk/launch/_project_spec.py +148 -136
  55. wandb/sdk/launch/agent/agent.py +3 -7
  56. wandb/sdk/launch/agent/config.py +0 -27
  57. wandb/sdk/launch/builder/build.py +54 -28
  58. wandb/sdk/launch/builder/docker_builder.py +4 -15
  59. wandb/sdk/launch/builder/kaniko_builder.py +72 -45
  60. wandb/sdk/launch/create_job.py +6 -40
  61. wandb/sdk/launch/loader.py +10 -0
  62. wandb/sdk/launch/registry/anon.py +29 -0
  63. wandb/sdk/launch/registry/local_registry.py +4 -1
  64. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  65. wandb/sdk/launch/runner/local_container.py +15 -10
  66. wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
  67. wandb/sdk/launch/sweeps/scheduler.py +11 -3
  68. wandb/sdk/launch/utils.py +14 -0
  69. wandb/sdk/lib/__init__.py +2 -5
  70. wandb/sdk/lib/_settings_toposort_generated.py +4 -1
  71. wandb/sdk/lib/apikey.py +0 -5
  72. wandb/sdk/lib/config_util.py +0 -31
  73. wandb/sdk/lib/filesystem.py +11 -1
  74. wandb/sdk/lib/run_moment.py +72 -0
  75. wandb/sdk/service/service.py +7 -2
  76. wandb/sdk/service/streams.py +1 -6
  77. wandb/sdk/verify/verify.py +2 -1
  78. wandb/sdk/wandb_init.py +12 -1
  79. wandb/sdk/wandb_login.py +43 -26
  80. wandb/sdk/wandb_run.py +164 -110
  81. wandb/sdk/wandb_settings.py +58 -16
  82. wandb/testing/relay.py +5 -6
  83. wandb/util.py +50 -7
  84. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
  85. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
  86. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
  87. wandb/apis/importers/base.py +0 -400
  88. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
  89. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
wandb/apis/public/api.py CHANGED
@@ -88,10 +88,7 @@ class RetryingClient:
88
88
  return self._server_info
89
89
 
90
90
  def version_supported(self, min_version):
91
- try:
92
- from packaging.version import Version as parse_version # noqa: N813
93
- except ImportError:
94
- from pkg_resources import parse_version
91
+ from wandb.util import parse_version
95
92
 
96
93
  return parse_version(min_version) <= parse_version(
97
94
  self.server_info["cliVersionInfo"]["max_cli_version"]
@@ -316,15 +313,15 @@ class Api:
316
313
  entity: (str) Optional name of the entity to create the queue. If None, will use the configured or default entity.
317
314
  prioritization_mode: (str) Optional version of prioritization to use. Either "V0" or None
318
315
  config: (dict) Optional default resource configuration to be used for the queue. Use handlebars (eg. "{{var}}") to specify template variables.
319
- template_variables (dict): A dictionary of template variable schemas to be used with the config. Expected format of:
316
+ template_variables: (dict) A dictionary of template variable schemas to be used with the config. Expected format of:
320
317
  {
321
318
  "var-name": {
322
319
  "schema": {
323
- "type": "<string | number | integer>",
324
- "default": <optional value>,
325
- "minimum": <optional minimum>,
326
- "maximum": <optional maximum>,
327
- "enum": [..."<options>"]
320
+ "type": ("string", "number", or "integer"),
321
+ "default": (optional value),
322
+ "minimum": (optional minimum),
323
+ "maximum": (optional maximum),
324
+ "enum": [..."(options)"]
328
325
  }
329
326
  }
330
327
  }
@@ -8,6 +8,7 @@ import wandb
8
8
  from wandb.apis import public
9
9
  from wandb.apis.normalize import normalize_exceptions
10
10
  from wandb.apis.paginator import Paginator
11
+ from wandb.errors.term import termlog
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from wandb.apis.public import RetryingClient, Run
@@ -402,6 +403,43 @@ class ArtifactCollection:
402
403
  self._attrs = response["project"]["artifactType"]["artifactCollection"]
403
404
  return self._attrs
404
405
 
406
+ def change_type(self, new_type: str) -> None:
407
+ """Change the type of the artifact collection.
408
+
409
+ Arguments:
410
+ new_type: The new collection type to use, freeform string.
411
+ """
412
+ if not self.is_sequence():
413
+ raise ValueError("Artifact collection needs to be a sequence")
414
+ termlog(f"Changing artifact collection type of " f"{self.type} to {new_type}")
415
+ template = """
416
+ mutation MoveArtifactCollection(
417
+ $artifactSequenceID: ID!
418
+ $destinationArtifactTypeName: String!
419
+ ) {
420
+ moveArtifactSequence(
421
+ input: {
422
+ artifactSequenceID: $artifactSequenceID
423
+ destinationArtifactTypeName: $destinationArtifactTypeName
424
+ }
425
+ ) {
426
+ artifactCollection {
427
+ id
428
+ name
429
+ description
430
+ __typename
431
+ }
432
+ }
433
+ }
434
+ """
435
+ variable_values = {
436
+ "artifactSequenceID": self.id,
437
+ "destinationArtifactTypeName": new_type,
438
+ }
439
+ mutation = gql(template)
440
+ self.client.execute(mutation, variable_values=variable_values)
441
+ self.type = new_type
442
+
405
443
  @normalize_exceptions
406
444
  def is_sequence(self) -> bool:
407
445
  """Return True if this is a sequence."""
@@ -1,6 +1,7 @@
1
1
  """Public API: files."""
2
2
  import io
3
3
  import os
4
+ from typing import Optional
4
5
 
5
6
  import requests
6
7
  from wandb_gql import gql
@@ -11,6 +12,7 @@ from wandb import util
11
12
  from wandb.apis.attrs import Attrs
12
13
  from wandb.apis.normalize import normalize_exceptions
13
14
  from wandb.apis.paginator import Paginator
15
+ from wandb.apis.public.api import Api
14
16
  from wandb.apis.public.const import RETRY_TIMEDELTA
15
17
  from wandb.sdk.lib import retry
16
18
 
@@ -136,7 +138,11 @@ class File(Attrs):
136
138
  retryable_exceptions=(RetryError, requests.RequestException),
137
139
  )
138
140
  def download(
139
- self, root: str = ".", replace: bool = False, exist_ok: bool = False
141
+ self,
142
+ root: str = ".",
143
+ replace: bool = False,
144
+ exist_ok: bool = False,
145
+ api: Optional[Api] = None,
140
146
  ) -> io.TextIOWrapper:
141
147
  """Downloads a file previously saved by a run from the wandb server.
142
148
 
@@ -150,6 +156,9 @@ class File(Attrs):
150
156
  Raises:
151
157
  `ValueError` if file already exists, replace=False and exist_ok=False.
152
158
  """
159
+ if api is None:
160
+ api = wandb.Api()
161
+
153
162
  path = os.path.join(root, self.name)
154
163
  if os.path.exists(path) and not replace:
155
164
  if exist_ok:
@@ -159,7 +168,7 @@ class File(Attrs):
159
168
  "File already exists, pass replace=True to overwrite or exist_ok=True to leave it as is and don't error."
160
169
  )
161
170
 
162
- util.download_file_from_url(path, self.url, wandb.Api().api_key)
171
+ util.download_file_from_url(path, self.url, api.api_key)
163
172
  return open(path)
164
173
 
165
174
  @normalize_exceptions
@@ -18,22 +18,3 @@ from .interface import (
18
18
  )
19
19
  from .metrics import * # noqa
20
20
  from .panels import * # noqa
21
-
22
-
23
- def show_welcome_message():
24
- if os.getenv("WANDB_REPORT_API_DISABLE_MESSAGE"):
25
- return
26
-
27
- termlog(
28
- cleandoc(
29
- """
30
- Thanks for trying out Report API v2!
31
- See a tutorial and the changes here: http://wandb.me/report-api-quickstart
32
- For bugs/feature requests, please create an issue on github: https://github.com/wandb/wandb/issues
33
- You can disable this message by setting the env var WANDB_REPORT_API_DISABLE_MESSAGE=True
34
- """
35
- )
36
- )
37
-
38
-
39
- show_welcome_message()
@@ -37,7 +37,6 @@ fe_name_map_reversed = {v: k for k, v in fe_name_map.items()}
37
37
 
38
38
 
39
39
  def expr_to_filters(expr: str) -> Filters:
40
- print(f"expr: {expr=}")
41
40
  if not expr:
42
41
  filters = []
43
42
  else:
@@ -4,6 +4,8 @@ from datetime import datetime
4
4
  from typing import Dict, Iterable, Optional, Tuple, Union
5
5
  from typing import List as LList
6
6
 
7
+ from annotated_types import Annotated, Ge, Le
8
+
7
9
  try:
8
10
  from typing import Literal
9
11
  except ImportError:
@@ -758,14 +760,8 @@ block_mapping = {
758
760
 
759
761
  @dataclass(config=dataclass_config)
760
762
  class GradientPoint(Base):
761
- color: str
762
- offset: float = Field(0, ge=0, le=100)
763
-
764
- @validator("color")
765
- def validate_color(cls, v): # noqa: N805
766
- if not internal.is_valid_color(v):
767
- raise ValueError("invalid color, value should be hex, rgb, or rgba")
768
- return v
763
+ color: Annotated[str, internal.ColorStrConstraints]
764
+ offset: Annotated[float, Ge(0), Le(100)] = 0
769
765
 
770
766
  def to_model(self):
771
767
  return internal.GradientPoint(color=self.color, offset=self.offset)
@@ -1419,8 +1415,7 @@ class Report(Base):
1419
1415
  description: str = ""
1420
1416
  blocks: LList[BlockTypes] = Field(default_factory=list)
1421
1417
 
1422
- id: str = Field(default_factory=lambda: "", init=False, repr=False)
1423
-
1418
+ id: str = Field(default_factory=lambda: "", init=False, repr=False, kw_only=True)
1424
1419
  _discussion_threads: list = Field(default_factory=list, init=False, repr=False)
1425
1420
  _ref: dict = Field(default_factory=dict, init=False, repr=False)
1426
1421
  _panel_settings: dict = Field(default_factory=dict, init=False, repr=False)
@@ -1470,20 +1465,22 @@ class Report(Base):
1470
1465
  if blocks[-1] == internal.Paragraph():
1471
1466
  blocks = blocks[:-1]
1472
1467
 
1473
- return cls(
1468
+ obj = cls(
1474
1469
  title=model.display_name,
1475
1470
  description=model.description,
1476
1471
  entity=model.project.entity_name,
1477
1472
  project=model.project.name,
1478
- id=model.id,
1479
1473
  blocks=[_lookup(b) for b in blocks],
1480
- _discussion_threads=model.spec.discussion_threads,
1481
- _panel_settings=model.spec.panel_settings,
1482
- _ref=model.spec.ref,
1483
- _authors=model.spec.authors,
1484
- _created_at=model.created_at,
1485
- _updated_at=model.updated_at,
1486
1474
  )
1475
+ obj.id = model.id
1476
+ obj._discussion_threads = model.spec.discussion_threads
1477
+ obj._panel_settings = model.spec.panel_settings
1478
+ obj._ref = model.spec.ref
1479
+ obj._authors = model.spec.authors
1480
+ obj._created_at = model.created_at
1481
+ obj._updated_at = model.updated_at
1482
+
1483
+ return obj
1487
1484
 
1488
1485
  @property
1489
1486
  def url(self):
@@ -1,18 +1,19 @@
1
1
  """JSONSchema for internal types. Hopefully this is auto-generated one day!"""
2
2
  import json
3
3
  import random
4
- import re
5
4
  from copy import deepcopy
6
5
  from datetime import datetime
7
6
  from typing import Any, Dict, Optional, Tuple, Union
8
7
  from typing import List as LList
9
8
 
9
+ from annotated_types import Annotated, Ge, Le
10
+
10
11
  try:
11
12
  from typing import Literal
12
13
  except ImportError:
13
14
  from typing_extensions import Literal
14
15
 
15
- from pydantic import BaseModel, ConfigDict, Field, validator
16
+ from pydantic import BaseModel, ConfigDict, Field, StringConstraints, validator
16
17
  from pydantic.alias_generators import to_camel
17
18
 
18
19
 
@@ -48,6 +49,13 @@ def _generate_name(length: int = 12) -> str:
48
49
  return rand36.lower()[:length]
49
50
 
50
51
 
52
+ hex_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
53
+ rgb_pattern = r"^rgb\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*\)$"
54
+ rgba_pattern = r"^rgba\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(1|0|0?\.\d+)\s*\)$"
55
+ ColorStrConstraints = StringConstraints(
56
+ pattern=f"{hex_pattern}|{rgb_pattern}|{rgba_pattern}"
57
+ )
58
+
51
59
  LinePlotStyle = Literal["line", "stacked-area", "pct-area"]
52
60
  BarPlotStyle = Literal["bar", "boxplot", "violin"]
53
61
  FontSize = Literal["small", "medium", "large", "auto"]
@@ -609,14 +617,8 @@ class LinePlot(Panel):
609
617
 
610
618
 
611
619
  class GradientPoint(ReportAPIBaseModel):
612
- color: str
613
- offset: float = Field(0, ge=0, le=100)
614
-
615
- @validator("color")
616
- def validate_color(cls, v): # noqa: N805
617
- if not is_valid_color(v):
618
- raise ValueError("invalid color, value should be hex, rgb, or rgba")
619
- return v
620
+ color: Annotated[str, ColorStrConstraints]
621
+ offset: Annotated[float, Ge(0), Le(100)] = 0
620
622
 
621
623
 
622
624
  class ScatterPlotConfig(ReportAPIBaseModel):
@@ -866,38 +868,3 @@ block_type_mapping = {
866
868
  "table-of-contents": TableOfContents,
867
869
  "block-quote": BlockQuote,
868
870
  }
869
-
870
-
871
- def is_valid_color(color_str: str) -> bool:
872
- # Regular expression for hex color validation
873
- hex_color_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
874
-
875
- # Check if it's a valid hex color
876
- if re.match(hex_color_pattern, color_str):
877
- return True
878
-
879
- # Try parsing it as an RGB or RGBA tuple
880
- try:
881
- # Strip 'rgb(' or 'rgba(' and the closing ')'
882
- if color_str.startswith("rgb(") and color_str.endswith(")"):
883
- parts = color_str[4:-1].split(",")
884
- elif color_str.startswith("rgba(") and color_str.endswith(")"):
885
- parts = color_str[5:-1].split(",")
886
- else:
887
- return False
888
-
889
- # Convert parts to integers and validate ranges
890
- parts = [int(p.strip()) for p in parts]
891
- if len(parts) == 3 and all(0 <= p <= 255 for p in parts):
892
- return True # Valid RGB
893
- if (
894
- len(parts) == 4
895
- and all(0 <= p <= 255 for p in parts[:-1])
896
- and 0 <= parts[-1] <= 1
897
- ):
898
- return True # Valid RGBA
899
-
900
- except ValueError:
901
- pass
902
-
903
- return False
wandb/cli/cli.py CHANGED
@@ -433,10 +433,14 @@ def beta():
433
433
  from wandb.util import get_core_path
434
434
 
435
435
  if not get_core_path():
436
- click.echo(
437
- "wandb beta commands require wandb-core, please install with `pip install wandb-core`"
436
+ click.secho(
437
+ (
438
+ "wandb beta commands require wandb-core, please install with"
439
+ " `pip install wandb-core`"
440
+ ),
441
+ fg="red",
442
+ err=True,
438
443
  )
439
- sys.exit(1)
440
444
 
441
445
 
442
446
  @beta.command(
@@ -1596,29 +1600,28 @@ def launch(
1596
1600
  sys.exit(0)
1597
1601
  else:
1598
1602
  try:
1599
- asyncio.run(
1600
- _launch_add(
1601
- api,
1602
- uri,
1603
- job,
1604
- config,
1605
- template_variables,
1606
- project,
1607
- entity,
1608
- queue,
1609
- resource,
1610
- entry_point,
1611
- name,
1612
- git_version,
1613
- docker_image,
1614
- project_queue,
1615
- resource_args,
1616
- build=build,
1617
- run_id=run_id,
1618
- repository=repository,
1619
- priority=priority,
1620
- )
1603
+ _launch_add(
1604
+ api,
1605
+ uri,
1606
+ job,
1607
+ config,
1608
+ template_variables,
1609
+ project,
1610
+ entity,
1611
+ queue,
1612
+ resource,
1613
+ entry_point,
1614
+ name,
1615
+ git_version,
1616
+ docker_image,
1617
+ project_queue,
1618
+ resource_args,
1619
+ build=build,
1620
+ run_id=run_id,
1621
+ repository=repository,
1622
+ priority=priority,
1621
1623
  )
1624
+
1622
1625
  except Exception as e:
1623
1626
  wandb._sentry.exception(e)
1624
1627
  raise e
@@ -2342,8 +2345,29 @@ def artifact():
2342
2345
  default=None,
2343
2346
  help="Resume the last run from your current directory.",
2344
2347
  )
2348
+ @click.option(
2349
+ "--skip_cache",
2350
+ is_flag=True,
2351
+ default=False,
2352
+ help="Skip caching while uploading artifact files.",
2353
+ )
2354
+ @click.option(
2355
+ "--policy",
2356
+ default="mutable",
2357
+ help="Set the storage policy while uploading artifact files.",
2358
+ )
2345
2359
  @display_error
2346
- def put(path, name, description, type, alias, run_id, resume):
2360
+ def put(
2361
+ path,
2362
+ name,
2363
+ description,
2364
+ type,
2365
+ alias,
2366
+ run_id,
2367
+ resume,
2368
+ skip_cache,
2369
+ policy,
2370
+ ):
2347
2371
  if name is None:
2348
2372
  name = os.path.basename(path)
2349
2373
  public_api = PublicApi()
@@ -2358,10 +2382,10 @@ def put(path, name, description, type, alias, run_id, resume):
2358
2382
  artifact_path = f"{entity}/{project}/{artifact_name}:{alias[0]}"
2359
2383
  if os.path.isdir(path):
2360
2384
  wandb.termlog(f'Uploading directory {path} to: "{artifact_path}" ({type})')
2361
- artifact.add_dir(path)
2385
+ artifact.add_dir(path, skip_cache=skip_cache, policy=policy)
2362
2386
  elif os.path.isfile(path):
2363
2387
  wandb.termlog(f'Uploading file {path} to: "{artifact_path}" ({type})')
2364
- artifact.add_file(path)
2388
+ artifact.add_file(path, skip_cache=skip_cache, policy=policy)
2365
2389
  elif "://" in path:
2366
2390
  wandb.termlog(
2367
2391
  f'Logging reference artifact from {path} to: "{artifact_path}" ({type})'
@@ -2861,30 +2885,3 @@ def verify(host):
2861
2885
  and url_success
2862
2886
  ):
2863
2887
  sys.exit(1)
2864
-
2865
-
2866
- @cli.group("import", help="Commands for importing data from other systems")
2867
- def importer():
2868
- pass
2869
-
2870
-
2871
- @importer.command("mlflow", help="Import from MLFlow")
2872
- @click.option("--mlflow-tracking-uri", help="MLFlow Tracking URI")
2873
- @click.option(
2874
- "--target-entity", required=True, help="Override default entity to import data into"
2875
- )
2876
- @click.option(
2877
- "--target-project",
2878
- required=True,
2879
- help="Override default project to import data into",
2880
- )
2881
- def mlflow(mlflow_tracking_uri, target_entity, target_project):
2882
- from wandb.apis.importers import MlflowImporter
2883
-
2884
- importer = MlflowImporter(mlflow_tracking_uri=mlflow_tracking_uri)
2885
- overrides = {
2886
- "entity": target_entity,
2887
- "project": target_project,
2888
- }
2889
-
2890
- importer.import_all_parallel(overrides=overrides)
@@ -47,7 +47,8 @@ def monitor():
47
47
  import gym
48
48
  else:
49
49
  import gymnasium as gym # type: ignore
50
- from pkg_resources import parse_version
50
+
51
+ from wandb.util import parse_version
51
52
 
52
53
  if parse_version(gym.__version__) < parse_version("0.26.0"):
53
54
  _gym_version_lt_0_26 = True
@@ -187,7 +187,7 @@ class WandbModelCheckpoint(callbacks.ModelCheckpoint):
187
187
  @property
188
188
  def is_old_tf_keras_version(self) -> Optional[bool]:
189
189
  if self._is_old_tf_keras_version is None:
190
- from pkg_resources import parse_version
190
+ from wandb.util import parse_version
191
191
 
192
192
  try:
193
193
  if parse_version(tf.keras.__version__) < parse_version("2.6.0"):
@@ -19,7 +19,8 @@ from wandb.util import add_import_hook
19
19
 
20
20
  def _check_keras_version():
21
21
  from keras import __version__ as keras_version
22
- from pkg_resources import parse_version
22
+
23
+ from wandb.util import parse_version
23
24
 
24
25
  if parse_version(keras_version) < parse_version("2.4.0"):
25
26
  wandb.termwarn(
@@ -29,7 +30,7 @@ def _check_keras_version():
29
30
 
30
31
  def _can_compute_flops() -> bool:
31
32
  """FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1."""
32
- from pkg_resources import parse_version
33
+ from wandb.util import parse_version
33
34
 
34
35
  if parse_version(tf.__version__) >= parse_version("2.0.0"):
35
36
  return True
@@ -73,9 +74,10 @@ def is_generator_like(data):
73
74
 
74
75
 
75
76
  def patch_tf_keras(): # noqa: C901
76
- from pkg_resources import parse_version
77
77
  from tensorflow.python.eager import context
78
78
 
79
+ from wandb.util import parse_version
80
+
79
81
  if (
80
82
  parse_version("2.6.0")
81
83
  <= parse_version(tf.__version__)
@@ -235,7 +237,7 @@ patch_tf_keras()
235
237
 
236
238
 
237
239
  def _get_custom_optimizer_parent_class():
238
- from pkg_resources import parse_version
240
+ from wandb.util import parse_version
239
241
 
240
242
  if parse_version(tf.__version__) >= parse_version("2.9.0"):
241
243
  custom_optimizer_parent_class = tf.keras.optimizers.legacy.Optimizer
@@ -3,8 +3,6 @@ import itertools
3
3
  import textwrap
4
4
  from typing import Callable, List, Mapping, Optional
5
5
 
6
- from pkg_resources import parse_version
7
-
8
6
  import wandb
9
7
 
10
8
  try:
@@ -13,6 +11,8 @@ try:
13
11
  from kfp.components._components import _create_task_factory_from_component_spec
14
12
  from kfp.components._python_op import _func_to_component_spec
15
13
 
14
+ from wandb.util import parse_version
15
+
16
16
  MIN_KFP_VERSION = "1.6.1"
17
17
 
18
18
  if parse_version(kfp_version) < parse_version(MIN_KFP_VERSION):
@@ -5,13 +5,12 @@ import re
5
5
  import time
6
6
  from typing import Any, Dict, Optional, Tuple
7
7
 
8
- from pkg_resources import parse_version
9
-
10
8
  import wandb
11
9
  from wandb import util
12
10
  from wandb.data_types import Table
13
11
  from wandb.sdk.lib import telemetry
14
12
  from wandb.sdk.wandb_run import Run
13
+ from wandb.util import parse_version
15
14
 
16
15
  openai = util.get_module(
17
16
  name="openai",
@@ -321,7 +321,6 @@ class WandBUltralyticsCallback:
321
321
  )
322
322
  if self.enable_model_checkpointing:
323
323
  self._save_model(trainer)
324
- self.model.to("cpu")
325
324
  trainer.model.to(self.device)
326
325
 
327
326
  def on_train_end(self, trainer: TRAINER_TYPE):