wandb 0.16.5__py3-none-any.whl → 0.17.0rc1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (141) hide show
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -2
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/protocols.py +30 -56
  6. wandb/apis/importers/mlflow.py +13 -26
  7. wandb/apis/importers/wandb.py +8 -14
  8. wandb/apis/public/api.py +1 -0
  9. wandb/apis/public/artifacts.py +1 -0
  10. wandb/apis/public/files.py +1 -0
  11. wandb/apis/public/history.py +1 -0
  12. wandb/apis/public/jobs.py +1 -0
  13. wandb/apis/public/projects.py +1 -0
  14. wandb/apis/public/reports.py +1 -0
  15. wandb/apis/public/runs.py +1 -0
  16. wandb/apis/public/sweeps.py +1 -0
  17. wandb/apis/public/teams.py +1 -0
  18. wandb/apis/public/users.py +1 -0
  19. wandb/apis/reports/v1/_blocks.py +2 -6
  20. wandb/apis/reports/v2/gql.py +1 -0
  21. wandb/apis/reports/v2/interface.py +3 -4
  22. wandb/apis/reports/v2/internal.py +5 -8
  23. wandb/cli/cli.py +7 -4
  24. wandb/data_types.py +3 -3
  25. wandb/env.py +35 -5
  26. wandb/errors/__init__.py +5 -0
  27. wandb/integration/catboost/catboost.py +1 -1
  28. wandb/integration/fastai/__init__.py +1 -0
  29. wandb/integration/keras/__init__.py +1 -0
  30. wandb/integration/keras/keras.py +6 -6
  31. wandb/integration/langchain/wandb_tracer.py +1 -0
  32. wandb/integration/lightning/fabric/logger.py +1 -3
  33. wandb/integration/metaflow/metaflow.py +41 -6
  34. wandb/integration/openai/fine_tuning.py +77 -40
  35. wandb/keras/__init__.py +1 -0
  36. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  37. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  38. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  39. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  40. wandb/proto/wandb_internal_codegen.py +0 -25
  41. wandb/sdk/artifacts/artifact.py +41 -13
  42. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  43. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  44. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  45. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  46. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -0
  47. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  48. wandb/sdk/artifacts/artifact_saver.py +21 -21
  49. wandb/sdk/artifacts/artifact_state.py +1 -0
  50. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  51. wandb/sdk/artifacts/exceptions.py +1 -0
  52. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  53. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  54. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  55. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  56. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  57. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  58. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  59. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  60. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  61. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -0
  62. wandb/sdk/artifacts/storage_policy.py +1 -0
  63. wandb/sdk/data_types/base_types/media.py +3 -6
  64. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  65. wandb/sdk/integration_utils/auto_logging.py +5 -6
  66. wandb/sdk/integration_utils/data_logging.py +5 -1
  67. wandb/sdk/interface/interface.py +72 -37
  68. wandb/sdk/interface/interface_shared.py +7 -13
  69. wandb/sdk/internal/datastore.py +1 -1
  70. wandb/sdk/internal/handler.py +18 -2
  71. wandb/sdk/internal/internal.py +0 -1
  72. wandb/sdk/internal/internal_util.py +0 -1
  73. wandb/sdk/internal/job_builder.py +4 -3
  74. wandb/sdk/internal/profiler.py +1 -0
  75. wandb/sdk/internal/run.py +1 -0
  76. wandb/sdk/internal/sender.py +1 -1
  77. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  78. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  79. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  80. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  81. wandb/sdk/internal/system/assets/trainium.py +1 -3
  82. wandb/sdk/launch/_launch.py +5 -0
  83. wandb/sdk/launch/_project_spec.py +10 -23
  84. wandb/sdk/launch/agent/agent.py +81 -37
  85. wandb/sdk/launch/agent/config.py +80 -11
  86. wandb/sdk/launch/builder/abstract.py +1 -0
  87. wandb/sdk/launch/builder/build.py +28 -1
  88. wandb/sdk/launch/builder/docker_builder.py +1 -0
  89. wandb/sdk/launch/builder/kaniko_builder.py +149 -134
  90. wandb/sdk/launch/builder/noop.py +1 -0
  91. wandb/sdk/launch/create_job.py +61 -48
  92. wandb/sdk/launch/environment/abstract.py +1 -0
  93. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  94. wandb/sdk/launch/environment/local_environment.py +1 -0
  95. wandb/sdk/launch/loader.py +1 -0
  96. wandb/sdk/launch/registry/abstract.py +1 -0
  97. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  98. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  99. wandb/sdk/launch/registry/google_artifact_registry.py +1 -0
  100. wandb/sdk/launch/registry/local_registry.py +1 -0
  101. wandb/sdk/launch/runner/abstract.py +1 -0
  102. wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
  103. wandb/sdk/launch/runner/kubernetes_runner.py +4 -3
  104. wandb/sdk/launch/runner/sagemaker_runner.py +11 -10
  105. wandb/sdk/launch/sweeps/scheduler.py +4 -1
  106. wandb/sdk/launch/sweeps/scheduler_sweep.py +1 -0
  107. wandb/sdk/launch/sweeps/utils.py +1 -1
  108. wandb/sdk/launch/utils.py +21 -3
  109. wandb/sdk/lib/_settings_toposort_generated.py +1 -0
  110. wandb/sdk/lib/fsm.py +8 -12
  111. wandb/sdk/lib/gitlib.py +4 -4
  112. wandb/sdk/lib/lazyloader.py +0 -1
  113. wandb/sdk/lib/proto_util.py +1 -1
  114. wandb/sdk/lib/retry.py +3 -2
  115. wandb/sdk/lib/run_moment.py +7 -1
  116. wandb/sdk/service/service.py +17 -15
  117. wandb/sdk/verify/verify.py +2 -1
  118. wandb/sdk/wandb_init.py +2 -8
  119. wandb/sdk/wandb_manager.py +2 -2
  120. wandb/sdk/wandb_require.py +5 -0
  121. wandb/sdk/wandb_run.py +64 -46
  122. wandb/sdk/wandb_settings.py +2 -1
  123. wandb/sklearn/__init__.py +1 -0
  124. wandb/sklearn/plot/__init__.py +1 -0
  125. wandb/sklearn/plot/classifier.py +1 -0
  126. wandb/sklearn/plot/clusterer.py +1 -0
  127. wandb/sklearn/plot/regressor.py +1 -0
  128. wandb/sklearn/plot/shared.py +1 -0
  129. wandb/sklearn/utils.py +1 -0
  130. wandb/testing/relay.py +4 -4
  131. wandb/trigger.py +1 -0
  132. wandb/util.py +40 -17
  133. wandb/wandb_controller.py +0 -1
  134. wandb/wandb_torch.py +1 -2
  135. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/METADATA +68 -69
  136. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/RECORD +139 -140
  137. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/WHEEL +1 -2
  138. wandb/bin/apple_gpu_stats +0 -0
  139. wandb-0.16.5.dist-info/top_level.txt +0 -1
  140. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/entry_points.txt +0 -0
  141. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info/licenses}/LICENSE +0 -0
wandb/env.py CHANGED
@@ -13,7 +13,6 @@ these values in many cases.
13
13
  import json
14
14
  import os
15
15
  import sys
16
- from distutils.util import strtobool
17
16
  from pathlib import Path
18
17
  from typing import List, MutableMapping, Optional, Union
19
18
 
@@ -61,6 +60,8 @@ SAVE_CODE = "WANDB_SAVE_CODE"
61
60
  TAGS = "WANDB_TAGS"
62
61
  IGNORE = "WANDB_IGNORE_GLOBS"
63
62
  ERROR_REPORTING = "WANDB_ERROR_REPORTING"
63
+ CORE_ERROR_REPORTING = "WANDB_CORE_ERROR_REPORTING"
64
+ CORE_DEBUG = "WANDB_CORE_DEBUG"
64
65
  DOCKER = "WANDB_DOCKER"
65
66
  AGENT_REPORT_INTERVAL = "WANDB_AGENT_REPORT_INTERVAL"
66
67
  AGENT_KILL_DELAY = "WANDB_AGENT_KILL_DELAY"
@@ -87,6 +88,7 @@ _EXECUTABLE = "WANDB_EXECUTABLE"
87
88
  LAUNCH_QUEUE_NAME = "WANDB_LAUNCH_QUEUE_NAME"
88
89
  LAUNCH_QUEUE_ENTITY = "WANDB_LAUNCH_QUEUE_ENTITY"
89
90
  LAUNCH_TRACE_ID = "WANDB_LAUNCH_TRACE_ID"
91
+ _REQUIRE_CORE = "WANDB__REQUIRE_CORE"
90
92
 
91
93
  # For testing, to be removed in future version
92
94
  USE_V1_ARTIFACTS = "_WANDB_USE_V1_ARTIFACTS"
@@ -139,11 +141,16 @@ def _env_as_bool(
139
141
  if env is None:
140
142
  env = os.environ
141
143
  val = env.get(var, default)
144
+ if not isinstance(val, str):
145
+ return False
142
146
  try:
143
- val = bool(strtobool(val)) # type: ignore
144
- except (AttributeError, ValueError):
145
- pass
146
- return val if isinstance(val, bool) else False
147
+ return strtobool(val)
148
+ except ValueError:
149
+ return False
150
+
151
+
152
+ def is_require_core(env: Optional[Env] = None) -> bool:
153
+ return _env_as_bool(_REQUIRE_CORE, default="False", env=env)
147
154
 
148
155
 
149
156
  def is_debug(default: Optional[str] = None, env: Optional[Env] = None) -> bool:
@@ -154,6 +161,14 @@ def error_reporting_enabled() -> bool:
154
161
  return _env_as_bool(ERROR_REPORTING, default="True")
155
162
 
156
163
 
164
+ def core_error_reporting_enabled(default: Optional[str] = None) -> bool:
165
+ return _env_as_bool(CORE_ERROR_REPORTING, default=default)
166
+
167
+
168
+ def core_debug(default: Optional[str] = None) -> bool:
169
+ return _env_as_bool(CORE_DEBUG, default=default)
170
+
171
+
157
172
  def ssl_disabled() -> bool:
158
173
  return _env_as_bool(DISABLE_SSL, default="False")
159
174
 
@@ -464,3 +479,18 @@ def get_launch_trace_id(env: Optional[Env] = None) -> Optional[str]:
464
479
  env = os.environ
465
480
  val = env.get(LAUNCH_TRACE_ID, None)
466
481
  return val
482
+
483
+
484
+ def strtobool(val: str) -> bool:
485
+ """Convert a string representation of truth to true or false.
486
+
487
+ Copied from distutils. distutils was removed in Python 3.12.
488
+ """
489
+ val = val.lower()
490
+
491
+ if val in ("y", "yes", "t", "true", "on", "1"):
492
+ return True
493
+ elif val in ("n", "no", "f", "false", "off", "0"):
494
+ return False
495
+ else:
496
+ raise ValueError(f"invalid truth value {val!r}")
wandb/errors/__init__.py CHANGED
@@ -4,6 +4,7 @@ __all__ = [
4
4
  "AuthenticationError",
5
5
  "UsageError",
6
6
  "UnsupportedError",
7
+ "WandbCoreNotAvailableError",
7
8
  ]
8
9
 
9
10
  from typing import Optional
@@ -39,3 +40,7 @@ class UsageError(Error):
39
40
 
40
41
  class UnsupportedError(UsageError):
41
42
  """Raised when trying to use a feature that is not supported."""
43
+
44
+
45
+ class WandbCoreNotAvailableError(Error):
46
+ """Raised when wandb core is not available."""
@@ -81,7 +81,7 @@ def _checkpoint_artifact(
81
81
 
82
82
 
83
83
  def _log_feature_importance(
84
- model: Union[CatBoostClassifier, CatBoostRegressor]
84
+ model: Union[CatBoostClassifier, CatBoostRegressor],
85
85
  ) -> None:
86
86
  """Log feature importance with default settings."""
87
87
  if wandb.run is None:
@@ -35,6 +35,7 @@ Examples:
35
35
  learn.fit(..., callbacks=WandbCallback(learn, ...))
36
36
  ```
37
37
  """
38
+
38
39
  import random
39
40
  import sys
40
41
  from pathlib import Path
@@ -2,6 +2,7 @@
2
2
 
3
3
  Keras is a deep learning API for [`TensorFlow`](https://www.tensorflow.org/).
4
4
  """
5
+
5
6
  __all__ = (
6
7
  "WandbCallback",
7
8
  "WandbMetricsLogger",
@@ -616,9 +616,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
616
616
  self.current = logs.get(self.monitor)
617
617
  if self.current and self.monitor_op(self.current, self.best):
618
618
  if self.log_best_prefix:
619
- wandb.run.summary[
620
- f"{self.log_best_prefix}{self.monitor}"
621
- ] = self.current
619
+ wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
620
+ self.current
621
+ )
622
622
  wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
623
623
  if self.verbose and not self.save_model:
624
624
  print(
@@ -937,9 +937,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
937
937
  grads = self._grad_accumulator_callback.grads
938
938
  metrics = {}
939
939
  for weight, grad in zip(weights, grads):
940
- metrics[
941
- "gradients/" + weight.name.split(":")[0] + ".gradient"
942
- ] = wandb.Histogram(grad)
940
+ metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
941
+ wandb.Histogram(grad)
942
+ )
943
943
  return metrics
944
944
 
945
945
  def _log_dataframe(self):
@@ -14,6 +14,7 @@ integration will not break user code. The one exception to the rule is at import
14
14
  LangChain is not installed, or the symbols are not in the same place, the appropriate error
15
15
  will be raised when importing this module.
16
16
  """
17
+
17
18
  from packaging import version
18
19
 
19
20
  import wandb.util
@@ -401,9 +401,7 @@ class WandbLogger(Logger):
401
401
  "*", step_metric="trainer/global_step", step_sync=True
402
402
  )
403
403
 
404
- self._experiment._label(
405
- repo="lightning_fabric_logger"
406
- ) # pylint: disable=protected-access
404
+ self._experiment._label(repo="lightning_fabric_logger") # pylint: disable=protected-access
407
405
  with telemetry.context(run=self._experiment) as tel:
408
406
  tel.feature.lightning_fabric_logger = True
409
407
  return self._experiment
@@ -36,7 +36,15 @@ try:
36
36
  import pandas as pd
37
37
 
38
38
  @typedispatch # noqa: F811
39
- def _wandb_use(name: str, data: pd.DataFrame, datasets=False, run=None, testing=False, *args, **kwargs): # type: ignore
39
+ def _wandb_use(
40
+ name: str,
41
+ data: pd.DataFrame,
42
+ datasets=False,
43
+ run=None,
44
+ testing=False,
45
+ *args,
46
+ **kwargs,
47
+ ): # type: ignore
40
48
  if testing:
41
49
  return "datasets" if datasets else None
42
50
 
@@ -74,7 +82,15 @@ try:
74
82
  import torch.nn as nn
75
83
 
76
84
  @typedispatch # noqa: F811
77
- def _wandb_use(name: str, data: nn.Module, models=False, run=None, testing=False, *args, **kwargs): # type: ignore
85
+ def _wandb_use(
86
+ name: str,
87
+ data: nn.Module,
88
+ models=False,
89
+ run=None,
90
+ testing=False,
91
+ *args,
92
+ **kwargs,
93
+ ): # type: ignore
78
94
  if testing:
79
95
  return "models" if models else None
80
96
 
@@ -111,7 +127,15 @@ try:
111
127
  from sklearn.base import BaseEstimator
112
128
 
113
129
  @typedispatch # noqa: F811
114
- def _wandb_use(name: str, data: BaseEstimator, models=False, run=None, testing=False, *args, **kwargs): # type: ignore
130
+ def _wandb_use(
131
+ name: str,
132
+ data: BaseEstimator,
133
+ models=False,
134
+ run=None,
135
+ testing=False,
136
+ *args,
137
+ **kwargs,
138
+ ): # type: ignore
115
139
  if testing:
116
140
  return "models" if models else None
117
141
 
@@ -169,7 +193,14 @@ class ArtifactProxy:
169
193
 
170
194
 
171
195
  @typedispatch # noqa: F811
172
- def wandb_track(name: str, data: (dict, list, set, str, int, float, bool), run=None, testing=False, *args, **kwargs): # type: ignore
196
+ def wandb_track(
197
+ name: str,
198
+ data: (dict, list, set, str, int, float, bool),
199
+ run=None,
200
+ testing=False,
201
+ *args,
202
+ **kwargs,
203
+ ): # type: ignore
173
204
  if testing:
174
205
  return "scalar"
175
206
 
@@ -222,12 +253,16 @@ def wandb_use(name: str, data, *args, **kwargs):
222
253
 
223
254
 
224
255
  @typedispatch # noqa: F811
225
- def wandb_use(name: str, data: (dict, list, set, str, int, float, bool), *args, **kwargs): # type: ignore
256
+ def wandb_use(
257
+ name: str, data: (dict, list, set, str, int, float, bool), *args, **kwargs
258
+ ): # type: ignore
226
259
  pass # do nothing for these types
227
260
 
228
261
 
229
262
  @typedispatch # noqa: F811
230
- def _wandb_use(name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs): # type: ignore
263
+ def _wandb_use(
264
+ name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
265
+ ): # type: ignore
231
266
  if testing:
232
267
  return "datasets" if datasets else None
233
268
 
@@ -1,9 +1,11 @@
1
1
  import datetime
2
2
  import io
3
3
  import json
4
+ import os
4
5
  import re
6
+ import tempfile
5
7
  import time
6
- from typing import Any, Dict, Optional, Tuple
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
9
 
8
10
  import wandb
9
11
  from wandb import util
@@ -26,7 +28,10 @@ if parse_version(openai.__version__) < parse_version("1.0.1"):
26
28
 
27
29
  from openai import OpenAI # noqa: E402
28
30
  from openai.types.fine_tuning import FineTuningJob # noqa: E402
29
- from openai.types.fine_tuning.fine_tuning_job import Hyperparameters # noqa: E402
31
+ from openai.types.fine_tuning.fine_tuning_job import ( # noqa: E402
32
+ Error,
33
+ Hyperparameters,
34
+ )
30
35
 
31
36
  np = util.get_module(
32
37
  name="numpy",
@@ -59,6 +64,7 @@ class WandbLogger:
59
64
  entity: Optional[str] = None,
60
65
  overwrite: bool = False,
61
66
  wait_for_job_success: bool = True,
67
+ log_datasets: bool = True,
62
68
  **kwargs_wandb_init: Dict[str, Any],
63
69
  ) -> str:
64
70
  """Sync fine-tunes to Weights & Biases.
@@ -150,6 +156,7 @@ class WandbLogger:
150
156
  entity,
151
157
  overwrite,
152
158
  show_individual_warnings,
159
+ log_datasets,
153
160
  **kwargs_wandb_init,
154
161
  )
155
162
 
@@ -160,11 +167,14 @@ class WandbLogger:
160
167
 
161
168
  @classmethod
162
169
  def _wait_for_job_success(cls, fine_tune: FineTuningJob) -> FineTuningJob:
163
- wandb.termlog("Waiting for the OpenAI fine-tuning job to be finished...")
170
+ wandb.termlog("Waiting for the OpenAI fine-tuning job to finish training...")
171
+ wandb.termlog(
172
+ "To avoid blocking, you can call `WandbLogger.sync` with `wait_for_job_success=False` after OpenAI training completes."
173
+ )
164
174
  while True:
165
175
  if fine_tune.status == "succeeded":
166
176
  wandb.termlog(
167
- "Fine-tuning finished, logging metrics, model metadata, and more to W&B"
177
+ "Fine-tuning finished, logging metrics, model metadata, and run metadata to Weights & Biases"
168
178
  )
169
179
  return fine_tune
170
180
  if fine_tune.status == "failed":
@@ -190,6 +200,7 @@ class WandbLogger:
190
200
  entity: Optional[str],
191
201
  overwrite: bool,
192
202
  show_individual_warnings: bool,
203
+ log_datasets: bool,
193
204
  **kwargs_wandb_init: Dict[str, Any],
194
205
  ):
195
206
  fine_tune_id = fine_tune.id
@@ -209,7 +220,7 @@ class WandbLogger:
209
220
  # check results are present
210
221
  try:
211
222
  results_id = fine_tune.result_files[0]
212
- results = cls.openai_client.files.retrieve_content(file_id=results_id)
223
+ results = cls.openai_client.files.content(file_id=results_id).text
213
224
  except openai.NotFoundError:
214
225
  if show_individual_warnings:
215
226
  wandb.termwarn(
@@ -233,7 +244,7 @@ class WandbLogger:
233
244
  cls._run.summary["fine_tuned_model"] = fine_tuned_model
234
245
 
235
246
  # training/validation files and fine-tune details
236
- cls._log_artifacts(fine_tune, project, entity)
247
+ cls._log_artifacts(fine_tune, project, entity, log_datasets, overwrite)
237
248
 
238
249
  # mark run as complete
239
250
  cls._run.summary["status"] = "succeeded"
@@ -249,7 +260,7 @@ class WandbLogger:
249
260
  else:
250
261
  raise Exception(
251
262
  "It appears you are not currently logged in to Weights & Biases. "
252
- "Please run `wandb login` in your terminal. "
263
+ "Please run `wandb login` in your terminal or `wandb.login()` in a notebook."
253
264
  "When prompted, you can obtain your API key by visiting wandb.ai/authorize."
254
265
  )
255
266
 
@@ -286,15 +297,9 @@ class WandbLogger:
286
297
  config["finished_at"]
287
298
  ).strftime("%Y-%m-%d %H:%M:%S")
288
299
  if config.get("hyperparameters"):
289
- hyperparameters = config.pop("hyperparameters")
290
- hyperparams = cls._unpack_hyperparameters(hyperparameters)
291
- if hyperparams is None:
292
- # If unpacking fails, log the object which will render as string
293
- config["hyperparameters"] = hyperparameters
294
- else:
295
- # nested rendering on hyperparameters
296
- config["hyperparameters"] = hyperparams
297
-
300
+ config["hyperparameters"] = cls.sanitize(config["hyperparameters"])
301
+ if config.get("error"):
302
+ config["error"] = cls.sanitize(config["error"])
298
303
  return config
299
304
 
300
305
  @classmethod
@@ -305,30 +310,53 @@ class WandbLogger:
305
310
  try:
306
311
  hyperparams["n_epochs"] = hyperparameters.n_epochs
307
312
  hyperparams["batch_size"] = hyperparameters.batch_size
308
- hyperparams[
309
- "learning_rate_multiplier"
310
- ] = hyperparameters.learning_rate_multiplier
313
+ hyperparams["learning_rate_multiplier"] = (
314
+ hyperparameters.learning_rate_multiplier
315
+ )
311
316
  except Exception:
312
317
  # If unpacking fails, return the object to be logged as config
313
318
  return None
314
319
 
315
320
  return hyperparams
316
321
 
322
+ @staticmethod
323
+ def sanitize(input: Any) -> Union[Dict, List, str]:
324
+ valid_types = [bool, int, float, str]
325
+ if isinstance(input, (Hyperparameters, Error)):
326
+ return dict(input)
327
+ if isinstance(input, dict):
328
+ return {
329
+ k: v if type(v) in valid_types else str(v) for k, v in input.items()
330
+ }
331
+ elif isinstance(input, list):
332
+ return [v if type(v) in valid_types else str(v) for v in input]
333
+ else:
334
+ return str(input)
335
+
317
336
  @classmethod
318
337
  def _log_artifacts(
319
- cls, fine_tune: FineTuningJob, project: str, entity: Optional[str]
338
+ cls,
339
+ fine_tune: FineTuningJob,
340
+ project: str,
341
+ entity: Optional[str],
342
+ log_datasets: bool,
343
+ overwrite: bool,
320
344
  ) -> None:
321
- # training/validation files
322
- training_file = fine_tune.training_file if fine_tune.training_file else None
323
- validation_file = (
324
- fine_tune.validation_file if fine_tune.validation_file else None
325
- )
326
- for file, prefix, artifact_type in (
327
- (training_file, "train", "training_files"),
328
- (validation_file, "valid", "validation_files"),
329
- ):
330
- if file is not None:
331
- cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
345
+ if log_datasets:
346
+ wandb.termlog("Logging training/validation files...")
347
+ # training/validation files
348
+ training_file = fine_tune.training_file if fine_tune.training_file else None
349
+ validation_file = (
350
+ fine_tune.validation_file if fine_tune.validation_file else None
351
+ )
352
+ for file, prefix, artifact_type in (
353
+ (training_file, "train", "training_files"),
354
+ (validation_file, "valid", "validation_files"),
355
+ ):
356
+ if file is not None:
357
+ cls._log_artifact_inputs(
358
+ file, prefix, artifact_type, project, entity, overwrite
359
+ )
332
360
 
333
361
  # fine-tune details
334
362
  fine_tune_id = fine_tune.id
@@ -337,9 +365,14 @@ class WandbLogger:
337
365
  type="model",
338
366
  metadata=dict(fine_tune),
339
367
  )
368
+
340
369
  with artifact.new_file("model_metadata.json", mode="w", encoding="utf-8") as f:
341
370
  dict_fine_tune = dict(fine_tune)
342
- dict_fine_tune["hyperparameters"] = dict(dict_fine_tune["hyperparameters"])
371
+ dict_fine_tune["hyperparameters"] = cls.sanitize(
372
+ dict_fine_tune["hyperparameters"]
373
+ )
374
+ dict_fine_tune["error"] = cls.sanitize(dict_fine_tune["error"])
375
+ dict_fine_tune = cls.sanitize(dict_fine_tune)
343
376
  json.dump(dict_fine_tune, f, indent=2)
344
377
  cls._run.log_artifact(
345
378
  artifact,
@@ -354,6 +387,7 @@ class WandbLogger:
354
387
  artifact_type: str,
355
388
  project: str,
356
389
  entity: Optional[str],
390
+ overwrite: bool,
357
391
  ) -> None:
358
392
  # get input artifact
359
393
  artifact_name = f"{prefix}-{file_id}"
@@ -366,23 +400,26 @@ class WandbLogger:
366
400
  artifact = cls._get_wandb_artifact(artifact_path)
367
401
 
368
402
  # create artifact if file not already logged previously
369
- if artifact is None:
403
+ if artifact is None or overwrite:
370
404
  # get file content
371
405
  try:
372
- file_content = cls.openai_client.files.retrieve_content(file_id=file_id)
406
+ file_content = cls.openai_client.files.content(file_id=file_id)
373
407
  except openai.NotFoundError:
374
408
  wandb.termerror(
375
- f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
409
+ f"File {file_id} could not be retrieved. Make sure you have OpenAI permissions to download training/validation files"
376
410
  )
377
411
  return
378
412
 
379
413
  artifact = wandb.Artifact(artifact_name, type=artifact_type)
380
- with artifact.new_file(file_id, mode="w", encoding="utf-8") as f:
381
- f.write(file_content)
414
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
415
+ tmp_file.write(file_content.content)
416
+ tmp_file_path = tmp_file.name
417
+ artifact.add_file(tmp_file_path, file_id)
418
+ os.unlink(tmp_file_path)
382
419
 
383
420
  # create a Table
384
421
  try:
385
- table, n_items = cls._make_table(file_content)
422
+ table, n_items = cls._make_table(file_content.text)
386
423
  # Add table to the artifact.
387
424
  artifact.add(table, file_id)
388
425
  # Add the same table to the workspace.
@@ -390,9 +427,9 @@ class WandbLogger:
390
427
  # Update the run config and artifact metadata
391
428
  cls._run.config.update({f"n_{prefix}": n_items})
392
429
  artifact.metadata["items"] = n_items
393
- except Exception:
430
+ except Exception as e:
394
431
  wandb.termerror(
395
- f"File {file_id} could not be read as a valid JSON file"
432
+ f"Issue saving {file_id} as a Table to Artifacts, exception:\n '{e}'"
396
433
  )
397
434
  else:
398
435
  # log number of items
wandb/keras/__init__.py CHANGED
@@ -3,6 +3,7 @@
3
3
  In the future use e.g.:
4
4
  from wandb.integration.keras import WandbCallback
5
5
  """
6
+
6
7
  __all__ = (
7
8
  "WandbCallback",
8
9
  "WandbMetricsLogger",