wandb-0.16.4-py3-none-any.whl → wandb-0.16.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/public/api.py +6 -6
  4. wandb/apis/reports/v2/interface.py +4 -8
  5. wandb/apis/reports/v2/internal.py +12 -45
  6. wandb/cli/cli.py +29 -5
  7. wandb/integration/openai/fine_tuning.py +74 -37
  8. wandb/integration/ultralytics/callback.py +0 -1
  9. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  10. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  13. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  14. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  15. wandb/sdk/artifacts/artifact.py +92 -26
  16. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  17. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  18. wandb/sdk/artifacts/artifact_saver.py +16 -36
  19. wandb/sdk/artifacts/storage_handler.py +2 -1
  20. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +13 -5
  21. wandb/sdk/interface/interface.py +60 -15
  22. wandb/sdk/interface/interface_shared.py +13 -7
  23. wandb/sdk/internal/file_stream.py +19 -0
  24. wandb/sdk/internal/handler.py +1 -4
  25. wandb/sdk/internal/internal_api.py +2 -0
  26. wandb/sdk/internal/job_builder.py +45 -17
  27. wandb/sdk/internal/sender.py +53 -28
  28. wandb/sdk/internal/settings_static.py +9 -0
  29. wandb/sdk/internal/system/system_info.py +4 -1
  30. wandb/sdk/launch/_launch.py +5 -0
  31. wandb/sdk/launch/_project_spec.py +5 -20
  32. wandb/sdk/launch/agent/agent.py +80 -37
  33. wandb/sdk/launch/agent/config.py +8 -0
  34. wandb/sdk/launch/builder/kaniko_builder.py +149 -134
  35. wandb/sdk/launch/create_job.py +44 -48
  36. wandb/sdk/launch/runner/kubernetes_monitor.py +3 -1
  37. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  38. wandb/sdk/launch/sweeps/scheduler.py +3 -1
  39. wandb/sdk/launch/utils.py +23 -5
  40. wandb/sdk/lib/__init__.py +2 -5
  41. wandb/sdk/lib/_settings_toposort_generated.py +2 -0
  42. wandb/sdk/lib/filesystem.py +11 -1
  43. wandb/sdk/lib/run_moment.py +78 -0
  44. wandb/sdk/service/streams.py +1 -6
  45. wandb/sdk/wandb_init.py +12 -7
  46. wandb/sdk/wandb_login.py +43 -26
  47. wandb/sdk/wandb_run.py +179 -94
  48. wandb/sdk/wandb_settings.py +55 -16
  49. wandb/testing/relay.py +5 -6
  50. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/METADATA +1 -1
  51. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/RECORD +55 -54
  52. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/WHEEL +1 -1
  53. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/LICENSE +0 -0
  54. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/entry_points.txt +0 -0
  55. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/top_level.txt +0 -0
wandb/__init__.py CHANGED
@@ -11,8 +11,8 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples.
11
11
 
12
12
  For reference documentation, see https://docs.wandb.com/ref/python.
13
13
  """
14
- __version__ = "0.16.4"
15
- _minimum_core_version = "0.17.0b9"
14
+ __version__ = "0.16.6"
15
+ _minimum_core_version = "0.17.0b10"
16
16
 
17
17
  # Used with pypi checks and other messages related to pip
18
18
  _wandb_module = "wandb"
wandb/agents/pyagent.py CHANGED
@@ -347,7 +347,7 @@ def pyagent(sweep_id, function, entity=None, project=None, count=None):
347
347
  count (int, optional): the number of trials to run.
348
348
  """
349
349
  if not callable(function):
350
- raise Exception("function paramter must be callable!")
350
+ raise Exception("function parameter must be callable!")
351
351
  agent = Agent(
352
352
  sweep_id,
353
353
  function=function,
wandb/apis/public/api.py CHANGED
@@ -313,15 +313,15 @@ class Api:
313
313
  entity: (str) Optional name of the entity to create the queue. If None, will use the configured or default entity.
314
314
  prioritization_mode: (str) Optional version of prioritization to use. Either "V0" or None
315
315
  config: (dict) Optional default resource configuration to be used for the queue. Use handlebars (eg. "{{var}}") to specify template variables.
316
- template_variables (dict): A dictionary of template variable schemas to be used with the config. Expected format of:
316
+ template_variables: (dict) A dictionary of template variable schemas to be used with the config. Expected format of:
317
317
  {
318
318
  "var-name": {
319
319
  "schema": {
320
- "type": "<string | number | integer>",
321
- "default": <optional value>,
322
- "minimum": <optional minimum>,
323
- "maximum": <optional maximum>,
324
- "enum": [..."<options>"]
320
+ "type": ("string", "number", or "integer"),
321
+ "default": (optional value),
322
+ "minimum": (optional minimum),
323
+ "maximum": (optional maximum),
324
+ "enum": [..."(options)"]
325
325
  }
326
326
  }
327
327
  }
wandb/apis/reports/v2/interface.py CHANGED
@@ -4,6 +4,8 @@ from datetime import datetime
4
4
  from typing import Dict, Iterable, Optional, Tuple, Union
5
5
  from typing import List as LList
6
6
 
7
+ from annotated_types import Annotated, Ge, Le
8
+
7
9
  try:
8
10
  from typing import Literal
9
11
  except ImportError:
@@ -758,14 +760,8 @@ block_mapping = {
758
760
 
759
761
  @dataclass(config=dataclass_config)
760
762
  class GradientPoint(Base):
761
- color: str
762
- offset: float = Field(0, ge=0, le=100)
763
-
764
- @validator("color")
765
- def validate_color(cls, v): # noqa: N805
766
- if not internal.is_valid_color(v):
767
- raise ValueError("invalid color, value should be hex, rgb, or rgba")
768
- return v
763
+ color: Annotated[str, internal.ColorStrConstraints]
764
+ offset: Annotated[float, Ge(0), Le(100)] = 0
769
765
 
770
766
  def to_model(self):
771
767
  return internal.GradientPoint(color=self.color, offset=self.offset)
wandb/apis/reports/v2/internal.py CHANGED
@@ -1,18 +1,19 @@
1
1
  """JSONSchema for internal types. Hopefully this is auto-generated one day!"""
2
2
  import json
3
3
  import random
4
- import re
5
4
  from copy import deepcopy
6
5
  from datetime import datetime
7
6
  from typing import Any, Dict, Optional, Tuple, Union
8
7
  from typing import List as LList
9
8
 
9
+ from annotated_types import Annotated, Ge, Le
10
+
10
11
  try:
11
12
  from typing import Literal
12
13
  except ImportError:
13
14
  from typing_extensions import Literal
14
15
 
15
- from pydantic import BaseModel, ConfigDict, Field, validator
16
+ from pydantic import BaseModel, ConfigDict, Field, StringConstraints, validator
16
17
  from pydantic.alias_generators import to_camel
17
18
 
18
19
 
@@ -48,6 +49,13 @@ def _generate_name(length: int = 12) -> str:
48
49
  return rand36.lower()[:length]
49
50
 
50
51
 
52
+ hex_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
53
+ rgb_pattern = r"^rgb\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*\)$"
54
+ rgba_pattern = r"^rgba\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(1|0|0?\.\d+)\s*\)$"
55
+ ColorStrConstraints = StringConstraints(
56
+ pattern=f"{hex_pattern}|{rgb_pattern}|{rgba_pattern}"
57
+ )
58
+
51
59
  LinePlotStyle = Literal["line", "stacked-area", "pct-area"]
52
60
  BarPlotStyle = Literal["bar", "boxplot", "violin"]
53
61
  FontSize = Literal["small", "medium", "large", "auto"]
@@ -609,14 +617,8 @@ class LinePlot(Panel):
609
617
 
610
618
 
611
619
  class GradientPoint(ReportAPIBaseModel):
612
- color: str
613
- offset: float = Field(0, ge=0, le=100)
614
-
615
- @validator("color")
616
- def validate_color(cls, v): # noqa: N805
617
- if not is_valid_color(v):
618
- raise ValueError("invalid color, value should be hex, rgb, or rgba")
619
- return v
620
+ color: Annotated[str, ColorStrConstraints]
621
+ offset: Annotated[float, Ge(0), Le(100)] = 0
620
622
 
621
623
 
622
624
  class ScatterPlotConfig(ReportAPIBaseModel):
@@ -866,38 +868,3 @@ block_type_mapping = {
866
868
  "table-of-contents": TableOfContents,
867
869
  "block-quote": BlockQuote,
868
870
  }
869
-
870
-
871
- def is_valid_color(color_str: str) -> bool:
872
- # Regular expression for hex color validation
873
- hex_color_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
874
-
875
- # Check if it's a valid hex color
876
- if re.match(hex_color_pattern, color_str):
877
- return True
878
-
879
- # Try parsing it as an RGB or RGBA tuple
880
- try:
881
- # Strip 'rgb(' or 'rgba(' and the closing ')'
882
- if color_str.startswith("rgb(") and color_str.endswith(")"):
883
- parts = color_str[4:-1].split(",")
884
- elif color_str.startswith("rgba(") and color_str.endswith(")"):
885
- parts = color_str[5:-1].split(",")
886
- else:
887
- return False
888
-
889
- # Convert parts to integers and validate ranges
890
- parts = [int(p.strip()) for p in parts]
891
- if len(parts) == 3 and all(0 <= p <= 255 for p in parts):
892
- return True # Valid RGB
893
- if (
894
- len(parts) == 4
895
- and all(0 <= p <= 255 for p in parts[:-1])
896
- and 0 <= parts[-1] <= 1
897
- ):
898
- return True # Valid RGBA
899
-
900
- except ValueError:
901
- pass
902
-
903
- return False
wandb/cli/cli.py CHANGED
@@ -1680,6 +1680,7 @@ def launch(
1680
1680
  hidden=True,
1681
1681
  help="a wandb client registration URL, this is generated in the UI",
1682
1682
  )
1683
+ @click.option("--verbose", "-v", count=True, help="Display verbose output")
1683
1684
  @display_error
1684
1685
  def launch_agent(
1685
1686
  ctx,
@@ -1690,6 +1691,7 @@ def launch_agent(
1690
1691
  config=None,
1691
1692
  url=None,
1692
1693
  log_file=None,
1694
+ verbose=0,
1693
1695
  ):
1694
1696
  logger.info(
1695
1697
  f"=== Launch-agent called with kwargs {locals()} CLI Version: {wandb.__version__} ==="
@@ -1707,7 +1709,7 @@ def launch_agent(
1707
1709
  api = _get_cling_api()
1708
1710
  wandb._sentry.configure_scope(process_context="launch_agent")
1709
1711
  agent_config, api = _launch.resolve_agent_config(
1710
- entity, project, max_jobs, queues, config
1712
+ entity, project, max_jobs, queues, config, verbose
1711
1713
  )
1712
1714
 
1713
1715
  if len(agent_config.get("queues")) == 0:
@@ -1905,7 +1907,7 @@ def describe(job):
1905
1907
  "--entry-point",
1906
1908
  "-E",
1907
1909
  "entrypoint",
1908
- help="Codepath to the main script, required for repo jobs",
1910
+ help="Entrypoint to the script, including an executable and an entrypoint file. Required for code or repo jobs",
1909
1911
  )
1910
1912
  @click.option(
1911
1913
  "--git-hash",
@@ -2345,8 +2347,30 @@ def artifact():
2345
2347
  default=None,
2346
2348
  help="Resume the last run from your current directory.",
2347
2349
  )
2350
+ @click.option(
2351
+ "--skip_cache",
2352
+ is_flag=True,
2353
+ default=False,
2354
+ help="Skip caching while uploading artifact files.",
2355
+ )
2356
+ @click.option(
2357
+ "--policy",
2358
+ default="mutable",
2359
+ type=click.Choice(["mutable", "immutable"]),
2360
+ help="Set the storage policy while uploading artifact files.",
2361
+ )
2348
2362
  @display_error
2349
- def put(path, name, description, type, alias, run_id, resume):
2363
+ def put(
2364
+ path,
2365
+ name,
2366
+ description,
2367
+ type,
2368
+ alias,
2369
+ run_id,
2370
+ resume,
2371
+ skip_cache,
2372
+ policy,
2373
+ ):
2350
2374
  if name is None:
2351
2375
  name = os.path.basename(path)
2352
2376
  public_api = PublicApi()
@@ -2361,10 +2385,10 @@ def put(path, name, description, type, alias, run_id, resume):
2361
2385
  artifact_path = f"{entity}/{project}/{artifact_name}:{alias[0]}"
2362
2386
  if os.path.isdir(path):
2363
2387
  wandb.termlog(f'Uploading directory {path} to: "{artifact_path}" ({type})')
2364
- artifact.add_dir(path)
2388
+ artifact.add_dir(path, skip_cache=skip_cache, policy=policy)
2365
2389
  elif os.path.isfile(path):
2366
2390
  wandb.termlog(f'Uploading file {path} to: "{artifact_path}" ({type})')
2367
- artifact.add_file(path)
2391
+ artifact.add_file(path, skip_cache=skip_cache, policy=policy)
2368
2392
  elif "://" in path:
2369
2393
  wandb.termlog(
2370
2394
  f'Logging reference artifact from {path} to: "{artifact_path}" ({type})'
wandb/integration/openai/fine_tuning.py CHANGED
@@ -1,9 +1,11 @@
1
1
  import datetime
2
2
  import io
3
3
  import json
4
+ import os
4
5
  import re
6
+ import tempfile
5
7
  import time
6
- from typing import Any, Dict, Optional, Tuple
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
9
 
8
10
  import wandb
9
11
  from wandb import util
@@ -26,7 +28,10 @@ if parse_version(openai.__version__) < parse_version("1.0.1"):
26
28
 
27
29
  from openai import OpenAI # noqa: E402
28
30
  from openai.types.fine_tuning import FineTuningJob # noqa: E402
29
- from openai.types.fine_tuning.fine_tuning_job import Hyperparameters # noqa: E402
31
+ from openai.types.fine_tuning.fine_tuning_job import ( # noqa: E402
32
+ Error,
33
+ Hyperparameters,
34
+ )
30
35
 
31
36
  np = util.get_module(
32
37
  name="numpy",
@@ -59,6 +64,7 @@ class WandbLogger:
59
64
  entity: Optional[str] = None,
60
65
  overwrite: bool = False,
61
66
  wait_for_job_success: bool = True,
67
+ log_datasets: bool = True,
62
68
  **kwargs_wandb_init: Dict[str, Any],
63
69
  ) -> str:
64
70
  """Sync fine-tunes to Weights & Biases.
@@ -150,6 +156,7 @@ class WandbLogger:
150
156
  entity,
151
157
  overwrite,
152
158
  show_individual_warnings,
159
+ log_datasets,
153
160
  **kwargs_wandb_init,
154
161
  )
155
162
 
@@ -160,11 +167,14 @@ class WandbLogger:
160
167
 
161
168
  @classmethod
162
169
  def _wait_for_job_success(cls, fine_tune: FineTuningJob) -> FineTuningJob:
163
- wandb.termlog("Waiting for the OpenAI fine-tuning job to be finished...")
170
+ wandb.termlog("Waiting for the OpenAI fine-tuning job to finish training...")
171
+ wandb.termlog(
172
+ "To avoid blocking, you can call `WandbLogger.sync` with `wait_for_job_success=False` after OpenAI training completes."
173
+ )
164
174
  while True:
165
175
  if fine_tune.status == "succeeded":
166
176
  wandb.termlog(
167
- "Fine-tuning finished, logging metrics, model metadata, and more to W&B"
177
+ "Fine-tuning finished, logging metrics, model metadata, and run metadata to Weights & Biases"
168
178
  )
169
179
  return fine_tune
170
180
  if fine_tune.status == "failed":
@@ -190,6 +200,7 @@ class WandbLogger:
190
200
  entity: Optional[str],
191
201
  overwrite: bool,
192
202
  show_individual_warnings: bool,
203
+ log_datasets: bool,
193
204
  **kwargs_wandb_init: Dict[str, Any],
194
205
  ):
195
206
  fine_tune_id = fine_tune.id
@@ -209,7 +220,7 @@ class WandbLogger:
209
220
  # check results are present
210
221
  try:
211
222
  results_id = fine_tune.result_files[0]
212
- results = cls.openai_client.files.retrieve_content(file_id=results_id)
223
+ results = cls.openai_client.files.content(file_id=results_id).text
213
224
  except openai.NotFoundError:
214
225
  if show_individual_warnings:
215
226
  wandb.termwarn(
@@ -233,7 +244,7 @@ class WandbLogger:
233
244
  cls._run.summary["fine_tuned_model"] = fine_tuned_model
234
245
 
235
246
  # training/validation files and fine-tune details
236
- cls._log_artifacts(fine_tune, project, entity)
247
+ cls._log_artifacts(fine_tune, project, entity, log_datasets, overwrite)
237
248
 
238
249
  # mark run as complete
239
250
  cls._run.summary["status"] = "succeeded"
@@ -249,7 +260,7 @@ class WandbLogger:
249
260
  else:
250
261
  raise Exception(
251
262
  "It appears you are not currently logged in to Weights & Biases. "
252
- "Please run `wandb login` in your terminal. "
263
+ "Please run `wandb login` in your terminal or `wandb.login()` in a notebook."
253
264
  "When prompted, you can obtain your API key by visiting wandb.ai/authorize."
254
265
  )
255
266
 
@@ -286,15 +297,9 @@ class WandbLogger:
286
297
  config["finished_at"]
287
298
  ).strftime("%Y-%m-%d %H:%M:%S")
288
299
  if config.get("hyperparameters"):
289
- hyperparameters = config.pop("hyperparameters")
290
- hyperparams = cls._unpack_hyperparameters(hyperparameters)
291
- if hyperparams is None:
292
- # If unpacking fails, log the object which will render as string
293
- config["hyperparameters"] = hyperparameters
294
- else:
295
- # nested rendering on hyperparameters
296
- config["hyperparameters"] = hyperparams
297
-
300
+ config["hyperparameters"] = cls.sanitize(config["hyperparameters"])
301
+ if config.get("error"):
302
+ config["error"] = cls.sanitize(config["error"])
298
303
  return config
299
304
 
300
305
  @classmethod
@@ -314,21 +319,44 @@ class WandbLogger:
314
319
 
315
320
  return hyperparams
316
321
 
322
+ @staticmethod
323
+ def sanitize(input: Any) -> Union[Dict, List, str]:
324
+ valid_types = [bool, int, float, str]
325
+ if isinstance(input, (Hyperparameters, Error)):
326
+ return dict(input)
327
+ if isinstance(input, dict):
328
+ return {
329
+ k: v if type(v) in valid_types else str(v) for k, v in input.items()
330
+ }
331
+ elif isinstance(input, list):
332
+ return [v if type(v) in valid_types else str(v) for v in input]
333
+ else:
334
+ return str(input)
335
+
317
336
  @classmethod
318
337
  def _log_artifacts(
319
- cls, fine_tune: FineTuningJob, project: str, entity: Optional[str]
338
+ cls,
339
+ fine_tune: FineTuningJob,
340
+ project: str,
341
+ entity: Optional[str],
342
+ log_datasets: bool,
343
+ overwrite: bool,
320
344
  ) -> None:
321
- # training/validation files
322
- training_file = fine_tune.training_file if fine_tune.training_file else None
323
- validation_file = (
324
- fine_tune.validation_file if fine_tune.validation_file else None
325
- )
326
- for file, prefix, artifact_type in (
327
- (training_file, "train", "training_files"),
328
- (validation_file, "valid", "validation_files"),
329
- ):
330
- if file is not None:
331
- cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
345
+ if log_datasets:
346
+ wandb.termlog("Logging training/validation files...")
347
+ # training/validation files
348
+ training_file = fine_tune.training_file if fine_tune.training_file else None
349
+ validation_file = (
350
+ fine_tune.validation_file if fine_tune.validation_file else None
351
+ )
352
+ for file, prefix, artifact_type in (
353
+ (training_file, "train", "training_files"),
354
+ (validation_file, "valid", "validation_files"),
355
+ ):
356
+ if file is not None:
357
+ cls._log_artifact_inputs(
358
+ file, prefix, artifact_type, project, entity, overwrite
359
+ )
332
360
 
333
361
  # fine-tune details
334
362
  fine_tune_id = fine_tune.id
@@ -337,9 +365,14 @@ class WandbLogger:
337
365
  type="model",
338
366
  metadata=dict(fine_tune),
339
367
  )
368
+
340
369
  with artifact.new_file("model_metadata.json", mode="w", encoding="utf-8") as f:
341
370
  dict_fine_tune = dict(fine_tune)
342
- dict_fine_tune["hyperparameters"] = dict(dict_fine_tune["hyperparameters"])
371
+ dict_fine_tune["hyperparameters"] = cls.sanitize(
372
+ dict_fine_tune["hyperparameters"]
373
+ )
374
+ dict_fine_tune["error"] = cls.sanitize(dict_fine_tune["error"])
375
+ dict_fine_tune = cls.sanitize(dict_fine_tune)
343
376
  json.dump(dict_fine_tune, f, indent=2)
344
377
  cls._run.log_artifact(
345
378
  artifact,
@@ -354,6 +387,7 @@ class WandbLogger:
354
387
  artifact_type: str,
355
388
  project: str,
356
389
  entity: Optional[str],
390
+ overwrite: bool,
357
391
  ) -> None:
358
392
  # get input artifact
359
393
  artifact_name = f"{prefix}-{file_id}"
@@ -366,23 +400,26 @@ class WandbLogger:
366
400
  artifact = cls._get_wandb_artifact(artifact_path)
367
401
 
368
402
  # create artifact if file not already logged previously
369
- if artifact is None:
403
+ if artifact is None or overwrite:
370
404
  # get file content
371
405
  try:
372
- file_content = cls.openai_client.files.retrieve_content(file_id=file_id)
406
+ file_content = cls.openai_client.files.content(file_id=file_id)
373
407
  except openai.NotFoundError:
374
408
  wandb.termerror(
375
- f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
409
+ f"File {file_id} could not be retrieved. Make sure you have OpenAI permissions to download training/validation files"
376
410
  )
377
411
  return
378
412
 
379
413
  artifact = wandb.Artifact(artifact_name, type=artifact_type)
380
- with artifact.new_file(file_id, mode="w", encoding="utf-8") as f:
381
- f.write(file_content)
414
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
415
+ tmp_file.write(file_content.content)
416
+ tmp_file_path = tmp_file.name
417
+ artifact.add_file(tmp_file_path, file_id)
418
+ os.unlink(tmp_file_path)
382
419
 
383
420
  # create a Table
384
421
  try:
385
- table, n_items = cls._make_table(file_content)
422
+ table, n_items = cls._make_table(file_content.text)
386
423
  # Add table to the artifact.
387
424
  artifact.add(table, file_id)
388
425
  # Add the same table to the workspace.
@@ -390,9 +427,9 @@ class WandbLogger:
390
427
  # Update the run config and artifact metadata
391
428
  cls._run.config.update({f"n_{prefix}": n_items})
392
429
  artifact.metadata["items"] = n_items
393
- except Exception:
430
+ except Exception as e:
394
431
  wandb.termerror(
395
- f"File {file_id} could not be read as a valid JSON file"
432
+ f"Issue saving {file_id} as a Table to Artifacts, exception:\n '{e}'"
396
433
  )
397
434
  else:
398
435
  # log number of items
wandb/integration/ultralytics/callback.py CHANGED
@@ -321,7 +321,6 @@ class WandBUltralyticsCallback:
321
321
  )
322
322
  if self.enable_model_checkpointing:
323
323
  self._save_model(trainer)
324
- self.model.to("cpu")
325
324
  trainer.model.to(self.device)
326
325
 
327
326
  def on_train_end(self, trainer: TRAINER_TYPE):