wandb 0.16.4__py3-none-any.whl → 0.16.6__py3-none-any.whl

Files changed (55)
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/public/api.py +6 -6
  4. wandb/apis/reports/v2/interface.py +4 -8
  5. wandb/apis/reports/v2/internal.py +12 -45
  6. wandb/cli/cli.py +29 -5
  7. wandb/integration/openai/fine_tuning.py +74 -37
  8. wandb/integration/ultralytics/callback.py +0 -1
  9. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  10. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  13. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  14. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  15. wandb/sdk/artifacts/artifact.py +92 -26
  16. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  17. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  18. wandb/sdk/artifacts/artifact_saver.py +16 -36
  19. wandb/sdk/artifacts/storage_handler.py +2 -1
  20. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +13 -5
  21. wandb/sdk/interface/interface.py +60 -15
  22. wandb/sdk/interface/interface_shared.py +13 -7
  23. wandb/sdk/internal/file_stream.py +19 -0
  24. wandb/sdk/internal/handler.py +1 -4
  25. wandb/sdk/internal/internal_api.py +2 -0
  26. wandb/sdk/internal/job_builder.py +45 -17
  27. wandb/sdk/internal/sender.py +53 -28
  28. wandb/sdk/internal/settings_static.py +9 -0
  29. wandb/sdk/internal/system/system_info.py +4 -1
  30. wandb/sdk/launch/_launch.py +5 -0
  31. wandb/sdk/launch/_project_spec.py +5 -20
  32. wandb/sdk/launch/agent/agent.py +80 -37
  33. wandb/sdk/launch/agent/config.py +8 -0
  34. wandb/sdk/launch/builder/kaniko_builder.py +149 -134
  35. wandb/sdk/launch/create_job.py +44 -48
  36. wandb/sdk/launch/runner/kubernetes_monitor.py +3 -1
  37. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  38. wandb/sdk/launch/sweeps/scheduler.py +3 -1
  39. wandb/sdk/launch/utils.py +23 -5
  40. wandb/sdk/lib/__init__.py +2 -5
  41. wandb/sdk/lib/_settings_toposort_generated.py +2 -0
  42. wandb/sdk/lib/filesystem.py +11 -1
  43. wandb/sdk/lib/run_moment.py +78 -0
  44. wandb/sdk/service/streams.py +1 -6
  45. wandb/sdk/wandb_init.py +12 -7
  46. wandb/sdk/wandb_login.py +43 -26
  47. wandb/sdk/wandb_run.py +179 -94
  48. wandb/sdk/wandb_settings.py +55 -16
  49. wandb/testing/relay.py +5 -6
  50. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/METADATA +1 -1
  51. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/RECORD +55 -54
  52. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/WHEEL +1 -1
  53. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/LICENSE +0 -0
  54. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/entry_points.txt +0 -0
  55. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/top_level.txt +0 -0
wandb/__init__.py CHANGED
@@ -11,8 +11,8 @@ For scripts and interactive notebooks, see https://github.com/wandb/examples.
 
 For reference documentation, see https://docs.wandb.com/ref/python.
 """
-__version__ = "0.16.4"
-_minimum_core_version = "0.17.0b9"
+__version__ = "0.16.6"
+_minimum_core_version = "0.17.0b10"
 
 # Used with pypi checks and other messages related to pip
 _wandb_module = "wandb"
wandb/agents/pyagent.py CHANGED
@@ -347,7 +347,7 @@ def pyagent(sweep_id, function, entity=None, project=None, count=None):
         count (int, optional): the number of trials to run.
     """
     if not callable(function):
-        raise Exception("function paramter must be callable!")
+        raise Exception("function parameter must be callable!")
     agent = Agent(
         sweep_id,
         function=function,
wandb/apis/public/api.py CHANGED
@@ -313,15 +313,15 @@ class Api:
             entity: (str) Optional name of the entity to create the queue. If None, will use the configured or default entity.
             prioritization_mode: (str) Optional version of prioritization to use. Either "V0" or None
             config: (dict) Optional default resource configuration to be used for the queue. Use handlebars (eg. "{{var}}") to specify template variables.
-            template_variables (dict): A dictionary of template variable schemas to be used with the config. Expected format of:
+            template_variables: (dict) A dictionary of template variable schemas to be used with the config. Expected format of:
                 {
                     "var-name": {
                         "schema": {
-                            "type": "<string | number | integer>",
-                            "default": <optional value>,
-                            "minimum": <optional minimum>,
-                            "maximum": <optional maximum>,
-                            "enum": [..."<options>"]
+                            "type": ("string", "number", or "integer"),
+                            "default": (optional value),
+                            "minimum": (optional minimum),
+                            "maximum": (optional maximum),
+                            "enum": [..."(options)"]
                         }
                     }
                 }
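
As a usage illustration of the documented schema, a minimal sketch of calling this API (the `name` and `type` keyword names are assumptions, not taken from this diff; only entity, prioritization_mode, config, and template_variables are confirmed by the docstring above):

import wandb

api = wandb.Api()
queue = api.create_run_queue(
    name="my-queue",          # hypothetical queue name
    type="kubernetes",        # hypothetical resource type
    config={"resource_args": {"memory": "{{memory}}"}},
    template_variables={
        "memory": {
            "schema": {
                "type": "string",
                "default": "2Gi",
                "enum": ["2Gi", "4Gi"],
            }
        }
    },
)
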
wandb/apis/reports/v2/interface.py CHANGED
@@ -4,6 +4,8 @@ from datetime import datetime
 from typing import Dict, Iterable, Optional, Tuple, Union
 from typing import List as LList
 
+from annotated_types import Annotated, Ge, Le
+
 try:
     from typing import Literal
 except ImportError:
@@ -758,14 +760,8 @@ block_mapping = {
 
 @dataclass(config=dataclass_config)
 class GradientPoint(Base):
-    color: str
-    offset: float = Field(0, ge=0, le=100)
-
-    @validator("color")
-    def validate_color(cls, v):  # noqa: N805
-        if not internal.is_valid_color(v):
-            raise ValueError("invalid color, value should be hex, rgb, or rgba")
-        return v
+    color: Annotated[str, internal.ColorStrConstraints]
+    offset: Annotated[float, Ge(0), Le(100)] = 0
 
     def to_model(self):
         return internal.GradientPoint(color=self.color, offset=self.offset)
wandb/apis/reports/v2/internal.py CHANGED
@@ -1,18 +1,19 @@
 """JSONSchema for internal types. Hopefully this is auto-generated one day!"""
 import json
 import random
-import re
 from copy import deepcopy
 from datetime import datetime
 from typing import Any, Dict, Optional, Tuple, Union
 from typing import List as LList
 
+from annotated_types import Annotated, Ge, Le
+
 try:
     from typing import Literal
 except ImportError:
     from typing_extensions import Literal
 
-from pydantic import BaseModel, ConfigDict, Field, validator
+from pydantic import BaseModel, ConfigDict, Field, StringConstraints, validator
 from pydantic.alias_generators import to_camel
 
 
@@ -48,6 +49,13 @@ def _generate_name(length: int = 12) -> str:
     return rand36.lower()[:length]
 
 
+hex_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
+rgb_pattern = r"^rgb\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*\)$"
+rgba_pattern = r"^rgba\(\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\s*,\s*(1|0|0?\.\d+)\s*\)$"
+ColorStrConstraints = StringConstraints(
+    pattern=f"{hex_pattern}|{rgb_pattern}|{rgba_pattern}"
+)
+
 LinePlotStyle = Literal["line", "stacked-area", "pct-area"]
 BarPlotStyle = Literal["bar", "boxplot", "violin"]
 FontSize = Literal["small", "medium", "large", "auto"]
@@ -609,14 +617,8 @@ class LinePlot(Panel):
 
 
 class GradientPoint(ReportAPIBaseModel):
-    color: str
-    offset: float = Field(0, ge=0, le=100)
-
-    @validator("color")
-    def validate_color(cls, v):  # noqa: N805
-        if not is_valid_color(v):
-            raise ValueError("invalid color, value should be hex, rgb, or rgba")
-        return v
+    color: Annotated[str, ColorStrConstraints]
+    offset: Annotated[float, Ge(0), Le(100)] = 0
 
 
 class ScatterPlotConfig(ReportAPIBaseModel):
@@ -866,38 +868,3 @@ block_type_mapping = {
     "table-of-contents": TableOfContents,
     "block-quote": BlockQuote,
 }
-
-
-def is_valid_color(color_str: str) -> bool:
-    # Regular expression for hex color validation
-    hex_color_pattern = r"^#(?:[0-9a-fA-F]{3}){1,2}$"
-
-    # Check if it's a valid hex color
-    if re.match(hex_color_pattern, color_str):
-        return True
-
-    # Try parsing it as an RGB or RGBA tuple
-    try:
-        # Strip 'rgb(' or 'rgba(' and the closing ')'
-        if color_str.startswith("rgb(") and color_str.endswith(")"):
-            parts = color_str[4:-1].split(",")
-        elif color_str.startswith("rgba(") and color_str.endswith(")"):
-            parts = color_str[5:-1].split(",")
-        else:
-            return False
-
-        # Convert parts to integers and validate ranges
-        parts = [int(p.strip()) for p in parts]
-        if len(parts) == 3 and all(0 <= p <= 255 for p in parts):
-            return True  # Valid RGB
-        if (
-            len(parts) == 4
-            and all(0 <= p <= 255 for p in parts[:-1])
-            and 0 <= parts[-1] <= 1
-        ):
-            return True  # Valid RGBA
-
-    except ValueError:
-        pass
-
-    return False
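
For context, a minimal runnable sketch (not part of the diff) of how the new Annotated-based constraints replace the removed is_valid_color() validator, assuming pydantic v2 and the annotated_types package; the hex-only pattern here is abbreviated from the patterns added above:

from typing import Annotated
from annotated_types import Ge, Le
from pydantic import BaseModel, StringConstraints, ValidationError

ColorStr = Annotated[str, StringConstraints(pattern=r"^#(?:[0-9a-fA-F]{3}){1,2}$")]

class GradientPoint(BaseModel):
    color: ColorStr
    offset: Annotated[float, Ge(0), Le(100)] = 0

GradientPoint(color="#ff0000", offset=50)    # passes validation
try:
    GradientPoint(color="not-a-color")       # rejected by the pattern constraint
except ValidationError:
    pass
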
wandb/cli/cli.py CHANGED
@@ -1680,6 +1680,7 @@ def launch(
     hidden=True,
     help="a wandb client registration URL, this is generated in the UI",
 )
+@click.option("--verbose", "-v", count=True, help="Display verbose output")
 @display_error
 def launch_agent(
     ctx,
@@ -1690,6 +1691,7 @@ def launch_agent(
     config=None,
     url=None,
     log_file=None,
+    verbose=0,
 ):
     logger.info(
         f"=== Launch-agent called with kwargs {locals()} CLI Version: {wandb.__version__} ==="
@@ -1707,7 +1709,7 @@ def launch_agent(
     api = _get_cling_api()
     wandb._sentry.configure_scope(process_context="launch_agent")
     agent_config, api = _launch.resolve_agent_config(
-        entity, project, max_jobs, queues, config
+        entity, project, max_jobs, queues, config, verbose
     )
 
     if len(agent_config.get("queues")) == 0:
@@ -1905,7 +1907,7 @@ def describe(job):
     "--entry-point",
     "-E",
     "entrypoint",
-    help="Codepath to the main script, required for repo jobs",
+    help="Entrypoint to the script, including an executable and an entrypoint file. Required for code or repo jobs",
 )
 @click.option(
     "--git-hash",
@@ -2345,8 +2347,30 @@ def artifact():
     default=None,
     help="Resume the last run from your current directory.",
 )
+@click.option(
+    "--skip_cache",
+    is_flag=True,
+    default=False,
+    help="Skip caching while uploading artifact files.",
+)
+@click.option(
+    "--policy",
+    default="mutable",
+    type=click.Choice(["mutable", "immutable"]),
+    help="Set the storage policy while uploading artifact files.",
+)
 @display_error
-def put(path, name, description, type, alias, run_id, resume):
+def put(
+    path,
+    name,
+    description,
+    type,
+    alias,
+    run_id,
+    resume,
+    skip_cache,
+    policy,
+):
     if name is None:
         name = os.path.basename(path)
     public_api = PublicApi()
@@ -2361,10 +2385,10 @@ def put(path, name, description, type, alias, run_id, resume):
     artifact_path = f"{entity}/{project}/{artifact_name}:{alias[0]}"
     if os.path.isdir(path):
         wandb.termlog(f'Uploading directory {path} to: "{artifact_path}" ({type})')
-        artifact.add_dir(path)
+        artifact.add_dir(path, skip_cache=skip_cache, policy=policy)
     elif os.path.isfile(path):
         wandb.termlog(f'Uploading file {path} to: "{artifact_path}" ({type})')
-        artifact.add_file(path)
+        artifact.add_file(path, skip_cache=skip_cache, policy=policy)
     elif "://" in path:
         wandb.termlog(
             f'Logging reference artifact from {path} to: "{artifact_path}" ({type})'
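
The new --skip_cache and --policy flags mirror keyword arguments on Artifact.add_file/add_dir, as the diff above shows; a minimal sketch of the equivalent Python usage (project and file names are illustrative):

import wandb

with wandb.init(project="demo") as run:
    artifact = wandb.Artifact("my-dataset", type="dataset")
    # skip_cache avoids caching the file locally during upload;
    # policy="immutable" corresponds to the --policy CLI choice above
    artifact.add_file("data.csv", skip_cache=True, policy="immutable")
    run.log_artifact(artifact)
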
wandb/integration/openai/fine_tuning.py CHANGED
@@ -1,9 +1,11 @@
 import datetime
 import io
 import json
+import os
 import re
+import tempfile
 import time
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import wandb
 from wandb import util
@@ -26,7 +28,10 @@ if parse_version(openai.__version__) < parse_version("1.0.1"):
 
 from openai import OpenAI  # noqa: E402
 from openai.types.fine_tuning import FineTuningJob  # noqa: E402
-from openai.types.fine_tuning.fine_tuning_job import Hyperparameters  # noqa: E402
+from openai.types.fine_tuning.fine_tuning_job import (  # noqa: E402
+    Error,
+    Hyperparameters,
+)
 
 np = util.get_module(
     name="numpy",
@@ -59,6 +64,7 @@ class WandbLogger:
         entity: Optional[str] = None,
         overwrite: bool = False,
         wait_for_job_success: bool = True,
+        log_datasets: bool = True,
         **kwargs_wandb_init: Dict[str, Any],
     ) -> str:
         """Sync fine-tunes to Weights & Biases.
@@ -150,6 +156,7 @@ class WandbLogger:
             entity,
             overwrite,
             show_individual_warnings,
+            log_datasets,
             **kwargs_wandb_init,
         )
 
@@ -160,11 +167,14 @@ class WandbLogger:
 
     @classmethod
     def _wait_for_job_success(cls, fine_tune: FineTuningJob) -> FineTuningJob:
-        wandb.termlog("Waiting for the OpenAI fine-tuning job to be finished...")
+        wandb.termlog("Waiting for the OpenAI fine-tuning job to finish training...")
+        wandb.termlog(
+            "To avoid blocking, you can call `WandbLogger.sync` with `wait_for_job_success=False` after OpenAI training completes."
+        )
         while True:
             if fine_tune.status == "succeeded":
                 wandb.termlog(
-                    "Fine-tuning finished, logging metrics, model metadata, and more to W&B"
+                    "Fine-tuning finished, logging metrics, model metadata, and run metadata to Weights & Biases"
                 )
                 return fine_tune
             if fine_tune.status == "failed":
@@ -190,6 +200,7 @@ class WandbLogger:
         entity: Optional[str],
         overwrite: bool,
         show_individual_warnings: bool,
+        log_datasets: bool,
         **kwargs_wandb_init: Dict[str, Any],
     ):
         fine_tune_id = fine_tune.id
@@ -209,7 +220,7 @@ class WandbLogger:
         # check results are present
         try:
             results_id = fine_tune.result_files[0]
-            results = cls.openai_client.files.retrieve_content(file_id=results_id)
+            results = cls.openai_client.files.content(file_id=results_id).text
         except openai.NotFoundError:
             if show_individual_warnings:
                 wandb.termwarn(
@@ -233,7 +244,7 @@ class WandbLogger:
         cls._run.summary["fine_tuned_model"] = fine_tuned_model
 
         # training/validation files and fine-tune details
-        cls._log_artifacts(fine_tune, project, entity)
+        cls._log_artifacts(fine_tune, project, entity, log_datasets, overwrite)
 
         # mark run as complete
         cls._run.summary["status"] = "succeeded"
@@ -249,7 +260,7 @@ class WandbLogger:
         else:
             raise Exception(
                 "It appears you are not currently logged in to Weights & Biases. "
-                "Please run `wandb login` in your terminal. "
+                "Please run `wandb login` in your terminal or `wandb.login()` in a notebook."
                 "When prompted, you can obtain your API key by visiting wandb.ai/authorize."
             )
 
@@ -286,15 +297,9 @@ class WandbLogger:
                 config["finished_at"]
             ).strftime("%Y-%m-%d %H:%M:%S")
         if config.get("hyperparameters"):
-            hyperparameters = config.pop("hyperparameters")
-            hyperparams = cls._unpack_hyperparameters(hyperparameters)
-            if hyperparams is None:
-                # If unpacking fails, log the object which will render as string
-                config["hyperparameters"] = hyperparameters
-            else:
-                # nested rendering on hyperparameters
-                config["hyperparameters"] = hyperparams
-
+            config["hyperparameters"] = cls.sanitize(config["hyperparameters"])
+        if config.get("error"):
+            config["error"] = cls.sanitize(config["error"])
         return config
 
     @classmethod
@@ -314,21 +319,44 @@ class WandbLogger:
 
         return hyperparams
 
+    @staticmethod
+    def sanitize(input: Any) -> Union[Dict, List, str]:
+        valid_types = [bool, int, float, str]
+        if isinstance(input, (Hyperparameters, Error)):
+            return dict(input)
+        if isinstance(input, dict):
+            return {
+                k: v if type(v) in valid_types else str(v) for k, v in input.items()
+            }
+        elif isinstance(input, list):
+            return [v if type(v) in valid_types else str(v) for v in input]
+        else:
+            return str(input)
+
     @classmethod
     def _log_artifacts(
-        cls, fine_tune: FineTuningJob, project: str, entity: Optional[str]
+        cls,
+        fine_tune: FineTuningJob,
+        project: str,
+        entity: Optional[str],
+        log_datasets: bool,
+        overwrite: bool,
     ) -> None:
-        # training/validation files
-        training_file = fine_tune.training_file if fine_tune.training_file else None
-        validation_file = (
-            fine_tune.validation_file if fine_tune.validation_file else None
-        )
-        for file, prefix, artifact_type in (
-            (training_file, "train", "training_files"),
-            (validation_file, "valid", "validation_files"),
-        ):
-            if file is not None:
-                cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
+        if log_datasets:
+            wandb.termlog("Logging training/validation files...")
+            # training/validation files
+            training_file = fine_tune.training_file if fine_tune.training_file else None
+            validation_file = (
+                fine_tune.validation_file if fine_tune.validation_file else None
+            )
+            for file, prefix, artifact_type in (
+                (training_file, "train", "training_files"),
+                (validation_file, "valid", "validation_files"),
+            ):
+                if file is not None:
+                    cls._log_artifact_inputs(
+                        file, prefix, artifact_type, project, entity, overwrite
+                    )
 
         # fine-tune details
         fine_tune_id = fine_tune.id
@@ -337,9 +365,14 @@ class WandbLogger:
             type="model",
             metadata=dict(fine_tune),
         )
+
         with artifact.new_file("model_metadata.json", mode="w", encoding="utf-8") as f:
             dict_fine_tune = dict(fine_tune)
-            dict_fine_tune["hyperparameters"] = dict(dict_fine_tune["hyperparameters"])
+            dict_fine_tune["hyperparameters"] = cls.sanitize(
+                dict_fine_tune["hyperparameters"]
+            )
+            dict_fine_tune["error"] = cls.sanitize(dict_fine_tune["error"])
+            dict_fine_tune = cls.sanitize(dict_fine_tune)
             json.dump(dict_fine_tune, f, indent=2)
         cls._run.log_artifact(
             artifact,
@@ -354,6 +387,7 @@ class WandbLogger:
         artifact_type: str,
         project: str,
         entity: Optional[str],
+        overwrite: bool,
     ) -> None:
         # get input artifact
         artifact_name = f"{prefix}-{file_id}"
@@ -366,23 +400,26 @@ class WandbLogger:
         artifact = cls._get_wandb_artifact(artifact_path)
 
         # create artifact if file not already logged previously
-        if artifact is None:
+        if artifact is None or overwrite:
             # get file content
             try:
-                file_content = cls.openai_client.files.retrieve_content(file_id=file_id)
+                file_content = cls.openai_client.files.content(file_id=file_id)
             except openai.NotFoundError:
                 wandb.termerror(
-                    f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
+                    f"File {file_id} could not be retrieved. Make sure you have OpenAI permissions to download training/validation files"
                 )
                 return
 
             artifact = wandb.Artifact(artifact_name, type=artifact_type)
-            with artifact.new_file(file_id, mode="w", encoding="utf-8") as f:
-                f.write(file_content)
+            with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
+                tmp_file.write(file_content.content)
+                tmp_file_path = tmp_file.name
+            artifact.add_file(tmp_file_path, file_id)
+            os.unlink(tmp_file_path)
 
             # create a Table
             try:
-                table, n_items = cls._make_table(file_content)
+                table, n_items = cls._make_table(file_content.text)
                 # Add table to the artifact.
                 artifact.add(table, file_id)
                 # Add the same table to the workspace.
@@ -390,9 +427,9 @@ class WandbLogger:
                 # Update the run config and artifact metadata
                 cls._run.config.update({f"n_{prefix}": n_items})
                 artifact.metadata["items"] = n_items
-            except Exception:
+            except Exception as e:
                 wandb.termerror(
-                    f"File {file_id} could not be read as a valid JSON file"
+                    f"Issue saving {file_id} as a Table to Artifacts, exception:\n '{e}'"
                 )
             else:
                 # log number of items
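
Taken together, a minimal sketch of how the new options are exercised from user code (log_datasets and wait_for_job_success are confirmed by the diff above; fine_tune_job_id and project are assumed keyword names for this integration and may differ):

from wandb.integration.openai.fine_tuning import WandbLogger

WandbLogger.sync(
    fine_tune_job_id="ftjob-abc123",  # hypothetical OpenAI fine-tuning job id
    project="openai-fine-tunes",      # hypothetical W&B project name
    wait_for_job_success=False,       # return immediately instead of polling the job
    log_datasets=False,               # skip uploading training/validation files as artifacts
)
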
wandb/integration/ultralytics/callback.py CHANGED
@@ -321,7 +321,6 @@ class WandBUltralyticsCallback:
         )
         if self.enable_model_checkpointing:
             self._save_model(trainer)
-            self.model.to("cpu")
         trainer.model.to(self.device)
 
     def on_train_end(self, trainer: TRAINER_TYPE):