wandb 0.16.3__py3-none-any.whl → 0.16.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/importers/__init__.py +1 -4
  4. wandb/apis/importers/internals/internal.py +386 -0
  5. wandb/apis/importers/internals/protocols.py +125 -0
  6. wandb/apis/importers/internals/util.py +78 -0
  7. wandb/apis/importers/mlflow.py +125 -88
  8. wandb/apis/importers/validation.py +108 -0
  9. wandb/apis/importers/wandb.py +1604 -0
  10. wandb/apis/public/api.py +7 -10
  11. wandb/apis/public/artifacts.py +38 -0
  12. wandb/apis/public/files.py +11 -2
  13. wandb/apis/reports/v2/__init__.py +0 -19
  14. wandb/apis/reports/v2/expr_parsing.py +0 -1
  15. wandb/apis/reports/v2/interface.py +15 -18
  16. wandb/apis/reports/v2/internal.py +12 -45
  17. wandb/cli/cli.py +52 -55
  18. wandb/integration/gym/__init__.py +2 -1
  19. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  20. wandb/integration/keras/keras.py +6 -4
  21. wandb/integration/kfp/kfp_patch.py +2 -2
  22. wandb/integration/openai/fine_tuning.py +1 -2
  23. wandb/integration/ultralytics/callback.py +0 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  25. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  26. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  27. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  28. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  29. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  30. wandb/sdk/artifacts/artifact.py +75 -31
  31. wandb/sdk/artifacts/artifact_manifest.py +5 -2
  32. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  33. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +8 -2
  34. wandb/sdk/artifacts/artifact_saver.py +19 -47
  35. wandb/sdk/artifacts/storage_handler.py +2 -1
  36. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +22 -9
  37. wandb/sdk/artifacts/storage_policy.py +4 -1
  38. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  39. wandb/sdk/data_types/image.py +2 -2
  40. wandb/sdk/interface/interface.py +49 -13
  41. wandb/sdk/interface/interface_shared.py +17 -11
  42. wandb/sdk/internal/file_stream.py +20 -1
  43. wandb/sdk/internal/handler.py +1 -4
  44. wandb/sdk/internal/internal_api.py +3 -1
  45. wandb/sdk/internal/job_builder.py +49 -19
  46. wandb/sdk/internal/profiler.py +1 -1
  47. wandb/sdk/internal/sender.py +96 -124
  48. wandb/sdk/internal/sender_config.py +197 -0
  49. wandb/sdk/internal/settings_static.py +9 -0
  50. wandb/sdk/internal/system/system_info.py +5 -3
  51. wandb/sdk/internal/update.py +1 -1
  52. wandb/sdk/launch/_launch.py +3 -3
  53. wandb/sdk/launch/_launch_add.py +28 -29
  54. wandb/sdk/launch/_project_spec.py +148 -136
  55. wandb/sdk/launch/agent/agent.py +3 -7
  56. wandb/sdk/launch/agent/config.py +0 -27
  57. wandb/sdk/launch/builder/build.py +54 -28
  58. wandb/sdk/launch/builder/docker_builder.py +4 -15
  59. wandb/sdk/launch/builder/kaniko_builder.py +72 -45
  60. wandb/sdk/launch/create_job.py +6 -40
  61. wandb/sdk/launch/loader.py +10 -0
  62. wandb/sdk/launch/registry/anon.py +29 -0
  63. wandb/sdk/launch/registry/local_registry.py +4 -1
  64. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  65. wandb/sdk/launch/runner/local_container.py +15 -10
  66. wandb/sdk/launch/runner/sagemaker_runner.py +1 -1
  67. wandb/sdk/launch/sweeps/scheduler.py +11 -3
  68. wandb/sdk/launch/utils.py +14 -0
  69. wandb/sdk/lib/__init__.py +2 -5
  70. wandb/sdk/lib/_settings_toposort_generated.py +4 -1
  71. wandb/sdk/lib/apikey.py +0 -5
  72. wandb/sdk/lib/config_util.py +0 -31
  73. wandb/sdk/lib/filesystem.py +11 -1
  74. wandb/sdk/lib/run_moment.py +72 -0
  75. wandb/sdk/service/service.py +7 -2
  76. wandb/sdk/service/streams.py +1 -6
  77. wandb/sdk/verify/verify.py +2 -1
  78. wandb/sdk/wandb_init.py +12 -1
  79. wandb/sdk/wandb_login.py +43 -26
  80. wandb/sdk/wandb_run.py +164 -110
  81. wandb/sdk/wandb_settings.py +58 -16
  82. wandb/testing/relay.py +5 -6
  83. wandb/util.py +50 -7
  84. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/METADATA +8 -1
  85. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/RECORD +89 -82
  86. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/WHEEL +1 -1
  87. wandb/apis/importers/base.py +0 -400
  88. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/LICENSE +0 -0
  89. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.16.3.dist-info → wandb-0.16.5.dist-info}/top_level.txt +0 -0
wandb/apis/public/api.py CHANGED
@@ -88,10 +88,7 @@ class RetryingClient:
88
88
  return self._server_info
89
89
 
90
90
  def version_supported(self, min_version):
91
- try:
92
- from packaging.version import Version as parse_version # noqa: N813
93
- except ImportError:
94
- from pkg_resources import parse_version
91
+ from wandb.util import parse_version
95
92
 
96
93
  return parse_version(min_version) <= parse_version(
97
94
  self.server_info["cliVersionInfo"]["max_cli_version"]
@@ -316,15 +313,15 @@ class Api:
316
313
  entity: (str) Optional name of the entity to create the queue. If None, will use the configured or default entity.
317
314
  prioritization_mode: (str) Optional version of prioritization to use. Either "V0" or None
318
315
  config: (dict) Optional default resource configuration to be used for the queue. Use handlebars (eg. "{{var}}") to specify template variables.
319
- template_variables (dict): A dictionary of template variable schemas to be used with the config. Expected format of:
316
+ template_variables: (dict) A dictionary of template variable schemas to be used with the config. Expected format of:
320
317
  {
321
318
  "var-name": {
322
319
  "schema": {
323
- "type": "<string | number | integer>",
324
- "default": <optional value>,
325
- "minimum": <optional minimum>,
326
- "maximum": <optional maximum>,
327
- "enum": [..."<options>"]
320
+ "type": ("string", "number", or "integer"),
321
+ "default": (optional value),
322
+ "minimum": (optional minimum),
323
+ "maximum": (optional maximum),
324
+ "enum": [..."(options)"]
328
325
  }
329
326
  }
330
327
  }
wandb/apis/public/artifacts.py CHANGED
@@ -8,6 +8,7 @@ import wandb
8
8
  from wandb.apis import public
9
9
  from wandb.apis.normalize import normalize_exceptions
10
10
  from wandb.apis.paginator import Paginator
11
+ from wandb.errors.term import termlog
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from wandb.apis.public import RetryingClient, Run
@@ -402,6 +403,43 @@ class ArtifactCollection:
402
403
  self._attrs = response["project"]["artifactType"]["artifactCollection"]
403
404
  return self._attrs
404
405
 
406
+ def change_type(self, new_type: str) -> None:
407
+ """Change the type of the artifact collection.
408
+
409
+ Arguments:
410
+ new_type: The new collection type to use, freeform string.
411
+ """
412
+ if not self.is_sequence():
413
+ raise ValueError("Artifact collection needs to be a sequence")
414
+ termlog(f"Changing artifact collection type of " f"{self.type} to {new_type}")
415
+ template = """
416
+ mutation MoveArtifactCollection(
417
+ $artifactSequenceID: ID!
418
+ $destinationArtifactTypeName: String!
419
+ ) {
420
+ moveArtifactSequence(
421
+ input: {
422
+ artifactSequenceID: $artifactSequenceID
423
+ destinationArtifactTypeName: $destinationArtifactTypeName
424
+ }
425
+ ) {
426
+ artifactCollection {
427
+ id
428
+ name
429
+ description
430
+ __typename
431
+ }
432
+ }
433
+ }
434
+ """
435
+ variable_values = {
436
+ "artifactSequenceID": self.id,
437
+ "destinationArtifactTypeName": new_type,
438
+ }
439
+ mutation = gql(template)
440
+ self.client.execute(mutation, variable_values=variable_values)
441
+ self.type = new_type
442
+
405
443
  @normalize_exceptions
406
444
  def is_sequence(self) -> bool:
407
445
  """Return True if this is a sequence."""
wandb/apis/public/files.py CHANGED
@@ -1,6 +1,7 @@
1
1
  """Public API: files."""
2
2
  import io
3
3
  import os
4
+ from typing import Optional
4
5
 
5
6
  import requests
6
7
  from wandb_gql import gql
@@ -11,6 +12,7 @@ from wandb import util
11
12
  from wandb.apis.attrs import Attrs
12
13
  from wandb.apis.normalize import normalize_exceptions
13
14
  from wandb.apis.paginator import Paginator
15
+ from wandb.apis.public.api import Api
14
16
  from wandb.apis.public.const import RETRY_TIMEDELTA
15
17
  from wandb.sdk.lib import retry
16
18
 
@@ -136,7 +138,11 @@ class File(Attrs):
136
138
  retryable_exceptions=(RetryError, requests.RequestException),
137
139
  )
138
140
  def download(
139
- self, root: str = ".", replace: bool = False, exist_ok: bool = False
141
+ self,
142
+ root: str = ".",
143
+ replace: bool = False,
144
+ exist_ok: bool = False,
145
+ api: Optional[Api] = None,
140
146
  ) -> io.TextIOWrapper:
141
147
  """Downloads a file previously saved by a run from the wandb server.
142
148
 
@@ -150,6 +156,9 @@ class File(Attrs):
150
156
  Raises:
151
157
  `ValueError` if file already exists, replace=False and exist_ok=False.
152
158
  """
159
+ if api is None:
160
+ api = wandb.Api()
161
+
153
162
  path = os.path.join(root, self.name)
154
163
  if os.path.exists(path) and not replace:
155
164
  if exist_ok:
@@ -159,7 +168,7 @@ class File(Attrs):
159
168
  "File already exists, pass replace=True to overwrite or exist_ok=True to leave it as is and don't error."
160
169
  )
161
170
 
162
- util.download_file_from_url(path, self.url, wandb.Api().api_key)
171
+ util.download_file_from_url(path, self.url, api.api_key)
163
172
  return open(path)
164
173
 
165
174
  @normalize_exceptions
wandb/apis/reports/v2/__init__.py CHANGED
@@ -18,22 +18,3 @@ from .interface import (
18
18
  )
19
19
  from .metrics import * # noqa
20
20
  from .panels import * # noqa
21
-
22
-
23
- def show_welcome_message():
24
- if os.getenv("WANDB_REPORT_API_DISABLE_MESSAGE"):
25
- return
26
-
27
- termlog(
28
- cleandoc(
29
- """
30
- Thanks for trying out Report API v2!
31
- See a tutorial and the changes here: http://wandb.me/report-api-quickstart
32
- For bugs/feature requests, please create an issue on github: https://github.com/wandb/wandb/issues
33
- You can disable this message by setting the env var WANDB_REPORT_API_DISABLE_MESSAGE=True
34
- """
35
- )
36
- )
37
-
38
-
39
- show_welcome_message()
wandb/apis/reports/v2/expr_parsing.py CHANGED
@@ -37,7 +37,6 @@ fe_name_map_reversed = {v: k for k, v in fe_name_map.items()}
37
37
 
38
38
 
39
39
  def expr_to_filters(expr: str) -> Filters:
40
- print(f"expr: {expr=}")
41
40
  if not expr:
42
41
  filters = []
43
42
  else:
wandb/apis/reports/v2/interface.py CHANGED
@@ -4,6 +4,8 @@ from datetime import datetime
4
4
  from typing import Dict, Iterable, Optional, Tuple, Union
5
5
  from typing import List as LList
6
6
 
7
+ from annotated_types import Annotated, Ge, Le
8
+
7
9
  try:
8
10
  from typing import Literal
9
11
  except ImportError:
@@ -758,14 +760,8 @@ block_mapping = {
758
760
 
759
761
  @dataclass(config=dataclass_config)
760
762
  class GradientPoint(Base):
761
- color: str
762
- offset: float = Field(0, ge=0, le=100)
763
-
764
- @validator("color")
765
- def validate_color(cls, v): # noqa: N805
766
- if not internal.is_valid_color(v):
767
- raise ValueError("invalid color, value should be hex, rgb, or rgba")
768
- return v
763
+ color: Annotated[str, internal.ColorStrConstraints]
764
+ offset: Annotated[float, Ge(0), Le(100)] = 0
769
765
 
770
766
  def to_model(self):
771
767
  return internal.GradientPoint(color=self.color, offset=self.offset)
@@ -1419,8 +1415,7 @@ class Report(Base):
1419
1415
  description: str = ""
1420
1416
  blocks: LList[BlockTypes] = Field(default_factory=list)
1421
1417
 
1422
- id: str = Field(default_factory=lambda: "", init=False, repr=False)
1423
-
1418
+ id: str = Field(default_factory=lambda: "", init=False, repr=False, kw_only=True)
1424
1419
  _discussion_threads: list = Field(default_factory=list, init=False, repr=False)
1425
1420
  _ref: dict = Field(default_factory=dict, init=False, repr=False)
1426
1421
  _panel_settings: dict = Field(default_factory=dict, init=False, repr=False)
@@ -1470,20 +1465,22 @@ class Report(Base):
1470
1465
  if blocks[-1] == internal.Paragraph():
1471
1466
  blocks = blocks[:-1]
1472
1467
 
1473
- return cls(
1468
+ obj = cls(
1474
1469
  title=model.display_name,
1475
1470
  description=model.description,
1476
1471
  entity=model.project.entity_name,
1477
1472
  project=model.project.name,
1478
- id=model.id,
1479
1473
  blocks=[_lookup(b) for b in blocks],
1480
- _discussion_threads=model.spec.discussion_threads,
1481
- _panel_settings=model.spec.panel_settings,
1482
- _ref=model.spec.ref,
1483
- _authors=model.spec.authors,
1484
- _created_at=model.created_at,
1485
- _updated_at=model.updated_at,
1486
1474
  )
1475
+ obj.id = model.id
1476
+ obj._discussion_threads = model.spec.discussion_threads
1477
+ obj._panel_settings = model.spec.panel_settings
1478
+ obj._ref = model.spec.ref
1479
+ obj._authors = model.spec.authors
1480
+ obj._created_at = model.created_at
1481
+ obj._updated_at = model.updated_at
1482
+
1483
+ return obj
1487
1484
 
1488
1485
  @property
1489
1486
  def url(self):
wandb/apis/reports/v2/internal.py CHANGED
@@ -1,18 +1,19 @@
1
1
  """JSONSchema for internal types. Hopefully this is auto-generated one day!"""
2
2
  import json
3
3
  import random
4
- import re
5
4
  from copy import deepcopy
6
5
  from datetime import datetime
7
6
  from typing import Any, Dict, Optional, Tuple, Union
8
7
  from typing import List as LList
9
8
 
9
+ from annotated_types import Annotated, Ge, Le
10
+
10
11
  try:
11
12
  from typing import Literal
12
13
  except ImportError:
13
14
  from typing_extensions import Literal
14
15
 
15
- from pydantic import BaseModel, ConfigDict, Field, validator
16
+ from pydantic import BaseModel, ConfigDict, Field, StringConstraints, validator
16
17
  from pydantic.alias_generators import to_camel
17
18
 
18
19
 
@@ -48,6 +49,13 @@ def _generate_name(length: int = 12) -> str:
48
49
  return rand36.lower()[:length]
49
50
 
50
51
 
52
+ hex_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
53
+ rgb_pattern = r"^rgb\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*\)$"
54
+ rgba_pattern = r"^rgba\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(1|0|0?\.\d+)\s*\)$"
55
+ ColorStrConstraints = StringConstraints(
56
+ pattern=f"{hex_pattern}|{rgb_pattern}|{rgba_pattern}"
57
+ )
58
+
51
59
  LinePlotStyle = Literal["line", "stacked-area", "pct-area"]
52
60
  BarPlotStyle = Literal["bar", "boxplot", "violin"]
53
61
  FontSize = Literal["small", "medium", "large", "auto"]
@@ -609,14 +617,8 @@ class LinePlot(Panel):
609
617
 
610
618
 
611
619
  class GradientPoint(ReportAPIBaseModel):
612
- color: str
613
- offset: float = Field(0, ge=0, le=100)
614
-
615
- @validator("color")
616
- def validate_color(cls, v): # noqa: N805
617
- if not is_valid_color(v):
618
- raise ValueError("invalid color, value should be hex, rgb, or rgba")
619
- return v
620
+ color: Annotated[str, ColorStrConstraints]
621
+ offset: Annotated[float, Ge(0), Le(100)] = 0
620
622
 
621
623
 
622
624
  class ScatterPlotConfig(ReportAPIBaseModel):
@@ -866,38 +868,3 @@ block_type_mapping = {
866
868
  "table-of-contents": TableOfContents,
867
869
  "block-quote": BlockQuote,
868
870
  }
869
-
870
-
871
- def is_valid_color(color_str: str) -> bool:
872
- # Regular expression for hex color validation
873
- hex_color_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
874
-
875
- # Check if it's a valid hex color
876
- if re.match(hex_color_pattern, color_str):
877
- return True
878
-
879
- # Try parsing it as an RGB or RGBA tuple
880
- try:
881
- # Strip 'rgb(' or 'rgba(' and the closing ')'
882
- if color_str.startswith("rgb(") and color_str.endswith(")"):
883
- parts = color_str[4:-1].split(",")
884
- elif color_str.startswith("rgba(") and color_str.endswith(")"):
885
- parts = color_str[5:-1].split(",")
886
- else:
887
- return False
888
-
889
- # Convert parts to integers and validate ranges
890
- parts = [int(p.strip()) for p in parts]
891
- if len(parts) == 3 and all(0 <= p <= 255 for p in parts):
892
- return True # Valid RGB
893
- if (
894
- len(parts) == 4
895
- and all(0 <= p <= 255 for p in parts[:-1])
896
- and 0 <= parts[-1] <= 1
897
- ):
898
- return True # Valid RGBA
899
-
900
- except ValueError:
901
- pass
902
-
903
- return False
wandb/cli/cli.py CHANGED
@@ -433,10 +433,14 @@ def beta():
433
433
  from wandb.util import get_core_path
434
434
 
435
435
  if not get_core_path():
436
- click.echo(
437
- "wandb beta commands require wandb-core, please install with `pip install wandb-core`"
436
+ click.secho(
437
+ (
438
+ "wandb beta commands require wandb-core, please install with"
439
+ " `pip install wandb-core`"
440
+ ),
441
+ fg="red",
442
+ err=True,
438
443
  )
439
- sys.exit(1)
440
444
 
441
445
 
442
446
  @beta.command(
@@ -1596,29 +1600,28 @@ def launch(
1596
1600
  sys.exit(0)
1597
1601
  else:
1598
1602
  try:
1599
- asyncio.run(
1600
- _launch_add(
1601
- api,
1602
- uri,
1603
- job,
1604
- config,
1605
- template_variables,
1606
- project,
1607
- entity,
1608
- queue,
1609
- resource,
1610
- entry_point,
1611
- name,
1612
- git_version,
1613
- docker_image,
1614
- project_queue,
1615
- resource_args,
1616
- build=build,
1617
- run_id=run_id,
1618
- repository=repository,
1619
- priority=priority,
1620
- )
1603
+ _launch_add(
1604
+ api,
1605
+ uri,
1606
+ job,
1607
+ config,
1608
+ template_variables,
1609
+ project,
1610
+ entity,
1611
+ queue,
1612
+ resource,
1613
+ entry_point,
1614
+ name,
1615
+ git_version,
1616
+ docker_image,
1617
+ project_queue,
1618
+ resource_args,
1619
+ build=build,
1620
+ run_id=run_id,
1621
+ repository=repository,
1622
+ priority=priority,
1621
1623
  )
1624
+
1622
1625
  except Exception as e:
1623
1626
  wandb._sentry.exception(e)
1624
1627
  raise e
@@ -2342,8 +2345,29 @@ def artifact():
2342
2345
  default=None,
2343
2346
  help="Resume the last run from your current directory.",
2344
2347
  )
2348
+ @click.option(
2349
+ "--skip_cache",
2350
+ is_flag=True,
2351
+ default=False,
2352
+ help="Skip caching while uploading artifact files.",
2353
+ )
2354
+ @click.option(
2355
+ "--policy",
2356
+ default="mutable",
2357
+ help="Set the storage policy while uploading artifact files.",
2358
+ )
2345
2359
  @display_error
2346
- def put(path, name, description, type, alias, run_id, resume):
2360
+ def put(
2361
+ path,
2362
+ name,
2363
+ description,
2364
+ type,
2365
+ alias,
2366
+ run_id,
2367
+ resume,
2368
+ skip_cache,
2369
+ policy,
2370
+ ):
2347
2371
  if name is None:
2348
2372
  name = os.path.basename(path)
2349
2373
  public_api = PublicApi()
@@ -2358,10 +2382,10 @@ def put(path, name, description, type, alias, run_id, resume):
2358
2382
  artifact_path = f"{entity}/{project}/{artifact_name}:{alias[0]}"
2359
2383
  if os.path.isdir(path):
2360
2384
  wandb.termlog(f'Uploading directory {path} to: "{artifact_path}" ({type})')
2361
- artifact.add_dir(path)
2385
+ artifact.add_dir(path, skip_cache=skip_cache, policy=policy)
2362
2386
  elif os.path.isfile(path):
2363
2387
  wandb.termlog(f'Uploading file {path} to: "{artifact_path}" ({type})')
2364
- artifact.add_file(path)
2388
+ artifact.add_file(path, skip_cache=skip_cache, policy=policy)
2365
2389
  elif "://" in path:
2366
2390
  wandb.termlog(
2367
2391
  f'Logging reference artifact from {path} to: "{artifact_path}" ({type})'
@@ -2861,30 +2885,3 @@ def verify(host):
2861
2885
  and url_success
2862
2886
  ):
2863
2887
  sys.exit(1)
2864
-
2865
-
2866
- @cli.group("import", help="Commands for importing data from other systems")
2867
- def importer():
2868
- pass
2869
-
2870
-
2871
- @importer.command("mlflow", help="Import from MLFlow")
2872
- @click.option("--mlflow-tracking-uri", help="MLFlow Tracking URI")
2873
- @click.option(
2874
- "--target-entity", required=True, help="Override default entity to import data into"
2875
- )
2876
- @click.option(
2877
- "--target-project",
2878
- required=True,
2879
- help="Override default project to import data into",
2880
- )
2881
- def mlflow(mlflow_tracking_uri, target_entity, target_project):
2882
- from wandb.apis.importers import MlflowImporter
2883
-
2884
- importer = MlflowImporter(mlflow_tracking_uri=mlflow_tracking_uri)
2885
- overrides = {
2886
- "entity": target_entity,
2887
- "project": target_project,
2888
- }
2889
-
2890
- importer.import_all_parallel(overrides=overrides)
wandb/integration/gym/__init__.py CHANGED
@@ -47,7 +47,8 @@ def monitor():
47
47
  import gym
48
48
  else:
49
49
  import gymnasium as gym # type: ignore
50
- from pkg_resources import parse_version
50
+
51
+ from wandb.util import parse_version
51
52
 
52
53
  if parse_version(gym.__version__) < parse_version("0.26.0"):
53
54
  _gym_version_lt_0_26 = True
wandb/integration/keras/callbacks/model_checkpoint.py CHANGED
@@ -187,7 +187,7 @@ class WandbModelCheckpoint(callbacks.ModelCheckpoint):
187
187
  @property
188
188
  def is_old_tf_keras_version(self) -> Optional[bool]:
189
189
  if self._is_old_tf_keras_version is None:
190
- from pkg_resources import parse_version
190
+ from wandb.util import parse_version
191
191
 
192
192
  try:
193
193
  if parse_version(tf.keras.__version__) < parse_version("2.6.0"):
wandb/integration/keras/keras.py CHANGED
@@ -19,7 +19,8 @@ from wandb.util import add_import_hook
19
19
 
20
20
  def _check_keras_version():
21
21
  from keras import __version__ as keras_version
22
- from pkg_resources import parse_version
22
+
23
+ from wandb.util import parse_version
23
24
 
24
25
  if parse_version(keras_version) < parse_version("2.4.0"):
25
26
  wandb.termwarn(
@@ -29,7 +30,7 @@ def _check_keras_version():
29
30
 
30
31
  def _can_compute_flops() -> bool:
31
32
  """FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1."""
32
- from pkg_resources import parse_version
33
+ from wandb.util import parse_version
33
34
 
34
35
  if parse_version(tf.__version__) >= parse_version("2.0.0"):
35
36
  return True
@@ -73,9 +74,10 @@ def is_generator_like(data):
73
74
 
74
75
 
75
76
  def patch_tf_keras(): # noqa: C901
76
- from pkg_resources import parse_version
77
77
  from tensorflow.python.eager import context
78
78
 
79
+ from wandb.util import parse_version
80
+
79
81
  if (
80
82
  parse_version("2.6.0")
81
83
  <= parse_version(tf.__version__)
@@ -235,7 +237,7 @@ patch_tf_keras()
235
237
 
236
238
 
237
239
  def _get_custom_optimizer_parent_class():
238
- from pkg_resources import parse_version
240
+ from wandb.util import parse_version
239
241
 
240
242
  if parse_version(tf.__version__) >= parse_version("2.9.0"):
241
243
  custom_optimizer_parent_class = tf.keras.optimizers.legacy.Optimizer
wandb/integration/kfp/kfp_patch.py CHANGED
@@ -3,8 +3,6 @@ import itertools
3
3
  import textwrap
4
4
  from typing import Callable, List, Mapping, Optional
5
5
 
6
- from pkg_resources import parse_version
7
-
8
6
  import wandb
9
7
 
10
8
  try:
@@ -13,6 +11,8 @@ try:
13
11
  from kfp.components._components import _create_task_factory_from_component_spec
14
12
  from kfp.components._python_op import _func_to_component_spec
15
13
 
14
+ from wandb.util import parse_version
15
+
16
16
  MIN_KFP_VERSION = "1.6.1"
17
17
 
18
18
  if parse_version(kfp_version) < parse_version(MIN_KFP_VERSION):
wandb/integration/openai/fine_tuning.py CHANGED
@@ -5,13 +5,12 @@ import re
5
5
  import time
6
6
  from typing import Any, Dict, Optional, Tuple
7
7
 
8
- from pkg_resources import parse_version
9
-
10
8
  import wandb
11
9
  from wandb import util
12
10
  from wandb.data_types import Table
13
11
  from wandb.sdk.lib import telemetry
14
12
  from wandb.sdk.wandb_run import Run
13
+ from wandb.util import parse_version
15
14
 
16
15
  openai = util.get_module(
17
16
  name="openai",
wandb/integration/ultralytics/callback.py CHANGED
@@ -321,7 +321,6 @@ class WandBUltralyticsCallback:
321
321
  )
322
322
  if self.enable_model_checkpointing:
323
323
  self._save_model(trainer)
324
- self.model.to("cpu")
325
324
  trainer.model.to(self.device)
326
325
 
327
326
  def on_train_end(self, trainer: TRAINER_TYPE):