wandb 0.16.5__py3-none-any.whl → 0.17.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (194) hide show
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -3
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/internal.py +0 -1
  6. wandb/apis/importers/internals/protocols.py +30 -56
  7. wandb/apis/importers/mlflow.py +13 -26
  8. wandb/apis/importers/wandb.py +8 -14
  9. wandb/apis/internal.py +0 -3
  10. wandb/apis/public/api.py +55 -3
  11. wandb/apis/public/artifacts.py +1 -0
  12. wandb/apis/public/files.py +1 -0
  13. wandb/apis/public/history.py +1 -0
  14. wandb/apis/public/jobs.py +17 -4
  15. wandb/apis/public/projects.py +1 -0
  16. wandb/apis/public/reports.py +1 -0
  17. wandb/apis/public/runs.py +15 -17
  18. wandb/apis/public/sweeps.py +1 -0
  19. wandb/apis/public/teams.py +1 -0
  20. wandb/apis/public/users.py +1 -0
  21. wandb/apis/reports/v1/_blocks.py +3 -7
  22. wandb/apis/reports/v2/gql.py +1 -0
  23. wandb/apis/reports/v2/interface.py +3 -4
  24. wandb/apis/reports/v2/internal.py +5 -8
  25. wandb/cli/cli.py +95 -22
  26. wandb/data_types.py +9 -6
  27. wandb/docker/__init__.py +1 -1
  28. wandb/env.py +38 -8
  29. wandb/errors/__init__.py +5 -0
  30. wandb/errors/term.py +10 -2
  31. wandb/filesync/step_checksum.py +1 -4
  32. wandb/filesync/step_prepare.py +4 -24
  33. wandb/filesync/step_upload.py +4 -106
  34. wandb/filesync/upload_job.py +0 -76
  35. wandb/integration/catboost/catboost.py +1 -1
  36. wandb/integration/fastai/__init__.py +1 -0
  37. wandb/integration/huggingface/resolver.py +2 -2
  38. wandb/integration/keras/__init__.py +1 -0
  39. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  40. wandb/integration/keras/keras.py +7 -7
  41. wandb/integration/langchain/wandb_tracer.py +1 -0
  42. wandb/integration/lightning/fabric/logger.py +1 -3
  43. wandb/integration/metaflow/metaflow.py +41 -6
  44. wandb/integration/openai/fine_tuning.py +77 -40
  45. wandb/integration/prodigy/prodigy.py +1 -1
  46. wandb/old/summary.py +1 -1
  47. wandb/plot/confusion_matrix.py +1 -1
  48. wandb/plot/pr_curve.py +2 -1
  49. wandb/plot/roc_curve.py +2 -1
  50. wandb/{plots → plot}/utils.py +13 -25
  51. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  52. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  53. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  54. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  55. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  56. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  57. wandb/proto/wandb_deprecated.py +7 -1
  58. wandb/proto/wandb_internal_codegen.py +3 -29
  59. wandb/sdk/artifacts/artifact.py +51 -20
  60. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  61. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  62. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  63. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  64. wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
  65. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  66. wandb/sdk/artifacts/artifact_saver.py +18 -27
  67. wandb/sdk/artifacts/artifact_state.py +1 -0
  68. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  69. wandb/sdk/artifacts/exceptions.py +1 -0
  70. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  71. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  72. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  73. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  74. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  75. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  76. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  77. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  78. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  79. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
  80. wandb/sdk/artifacts/storage_policy.py +2 -12
  81. wandb/sdk/data_types/_dtypes.py +8 -8
  82. wandb/sdk/data_types/base_types/media.py +3 -6
  83. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  84. wandb/sdk/data_types/image.py +1 -1
  85. wandb/sdk/data_types/video.py +1 -1
  86. wandb/sdk/integration_utils/auto_logging.py +5 -6
  87. wandb/sdk/integration_utils/data_logging.py +10 -6
  88. wandb/sdk/interface/interface.py +86 -38
  89. wandb/sdk/interface/interface_shared.py +7 -13
  90. wandb/sdk/internal/datastore.py +1 -1
  91. wandb/sdk/internal/file_pusher.py +2 -5
  92. wandb/sdk/internal/file_stream.py +5 -18
  93. wandb/sdk/internal/handler.py +18 -2
  94. wandb/sdk/internal/internal.py +0 -1
  95. wandb/sdk/internal/internal_api.py +1 -129
  96. wandb/sdk/internal/internal_util.py +0 -1
  97. wandb/sdk/internal/job_builder.py +159 -45
  98. wandb/sdk/internal/profiler.py +1 -0
  99. wandb/sdk/internal/progress.py +0 -28
  100. wandb/sdk/internal/run.py +1 -0
  101. wandb/sdk/internal/sender.py +1 -2
  102. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  103. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  104. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  105. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  106. wandb/sdk/internal/system/assets/trainium.py +1 -3
  107. wandb/sdk/launch/__init__.py +9 -1
  108. wandb/sdk/launch/_launch.py +9 -24
  109. wandb/sdk/launch/_launch_add.py +1 -3
  110. wandb/sdk/launch/_project_spec.py +188 -241
  111. wandb/sdk/launch/agent/agent.py +115 -48
  112. wandb/sdk/launch/agent/config.py +80 -14
  113. wandb/sdk/launch/builder/abstract.py +69 -1
  114. wandb/sdk/launch/builder/build.py +156 -555
  115. wandb/sdk/launch/builder/context_manager.py +235 -0
  116. wandb/sdk/launch/builder/docker_builder.py +8 -23
  117. wandb/sdk/launch/builder/kaniko_builder.py +161 -159
  118. wandb/sdk/launch/builder/noop.py +1 -0
  119. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  120. wandb/sdk/launch/create_job.py +68 -63
  121. wandb/sdk/launch/environment/abstract.py +1 -0
  122. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  123. wandb/sdk/launch/environment/local_environment.py +1 -0
  124. wandb/sdk/launch/inputs/files.py +148 -0
  125. wandb/sdk/launch/inputs/internal.py +217 -0
  126. wandb/sdk/launch/inputs/manage.py +95 -0
  127. wandb/sdk/launch/loader.py +1 -0
  128. wandb/sdk/launch/registry/abstract.py +1 -0
  129. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  130. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  131. wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
  132. wandb/sdk/launch/registry/local_registry.py +1 -0
  133. wandb/sdk/launch/runner/abstract.py +1 -0
  134. wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
  135. wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
  136. wandb/sdk/launch/runner/local_container.py +2 -3
  137. wandb/sdk/launch/runner/local_process.py +8 -29
  138. wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
  139. wandb/sdk/launch/runner/vertex_runner.py +8 -7
  140. wandb/sdk/launch/sweeps/scheduler.py +7 -4
  141. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  142. wandb/sdk/launch/sweeps/utils.py +3 -3
  143. wandb/sdk/launch/utils.py +33 -140
  144. wandb/sdk/lib/_settings_toposort_generated.py +1 -5
  145. wandb/sdk/lib/fsm.py +8 -12
  146. wandb/sdk/lib/gitlib.py +4 -4
  147. wandb/sdk/lib/import_hooks.py +1 -1
  148. wandb/sdk/lib/lazyloader.py +0 -1
  149. wandb/sdk/lib/proto_util.py +23 -2
  150. wandb/sdk/lib/redirect.py +19 -14
  151. wandb/sdk/lib/retry.py +3 -2
  152. wandb/sdk/lib/run_moment.py +7 -1
  153. wandb/sdk/lib/tracelog.py +1 -1
  154. wandb/sdk/service/service.py +19 -16
  155. wandb/sdk/verify/verify.py +2 -1
  156. wandb/sdk/wandb_init.py +16 -63
  157. wandb/sdk/wandb_manager.py +2 -2
  158. wandb/sdk/wandb_require.py +5 -0
  159. wandb/sdk/wandb_run.py +164 -90
  160. wandb/sdk/wandb_settings.py +2 -48
  161. wandb/sdk/wandb_setup.py +1 -1
  162. wandb/sklearn/__init__.py +1 -0
  163. wandb/sklearn/plot/__init__.py +1 -0
  164. wandb/sklearn/plot/classifier.py +11 -12
  165. wandb/sklearn/plot/clusterer.py +2 -1
  166. wandb/sklearn/plot/regressor.py +1 -0
  167. wandb/sklearn/plot/shared.py +1 -0
  168. wandb/sklearn/utils.py +1 -0
  169. wandb/testing/relay.py +4 -4
  170. wandb/trigger.py +1 -0
  171. wandb/util.py +67 -54
  172. wandb/wandb_controller.py +2 -3
  173. wandb/wandb_torch.py +1 -2
  174. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
  175. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/RECORD +178 -188
  176. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
  177. wandb/bin/apple_gpu_stats +0 -0
  178. wandb/catboost/__init__.py +0 -9
  179. wandb/fastai/__init__.py +0 -9
  180. wandb/keras/__init__.py +0 -18
  181. wandb/lightgbm/__init__.py +0 -9
  182. wandb/plots/__init__.py +0 -6
  183. wandb/plots/explain_text.py +0 -36
  184. wandb/plots/heatmap.py +0 -81
  185. wandb/plots/named_entity.py +0 -43
  186. wandb/plots/part_of_speech.py +0 -50
  187. wandb/plots/plot_definitions.py +0 -768
  188. wandb/plots/precision_recall.py +0 -121
  189. wandb/plots/roc.py +0 -103
  190. wandb/sacred/__init__.py +0 -3
  191. wandb/xgboost/__init__.py +0 -9
  192. wandb-0.16.5.dist-info/top_level.txt +0 -1
  193. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
  194. {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
@@ -1,9 +1,11 @@
1
1
  import datetime
2
2
  import io
3
3
  import json
4
+ import os
4
5
  import re
6
+ import tempfile
5
7
  import time
6
- from typing import Any, Dict, Optional, Tuple
8
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
9
 
8
10
  import wandb
9
11
  from wandb import util
@@ -26,7 +28,10 @@ if parse_version(openai.__version__) < parse_version("1.0.1"):
26
28
 
27
29
  from openai import OpenAI # noqa: E402
28
30
  from openai.types.fine_tuning import FineTuningJob # noqa: E402
29
- from openai.types.fine_tuning.fine_tuning_job import Hyperparameters # noqa: E402
31
+ from openai.types.fine_tuning.fine_tuning_job import ( # noqa: E402
32
+ Error,
33
+ Hyperparameters,
34
+ )
30
35
 
31
36
  np = util.get_module(
32
37
  name="numpy",
@@ -59,6 +64,7 @@ class WandbLogger:
59
64
  entity: Optional[str] = None,
60
65
  overwrite: bool = False,
61
66
  wait_for_job_success: bool = True,
67
+ log_datasets: bool = True,
62
68
  **kwargs_wandb_init: Dict[str, Any],
63
69
  ) -> str:
64
70
  """Sync fine-tunes to Weights & Biases.
@@ -150,6 +156,7 @@ class WandbLogger:
150
156
  entity,
151
157
  overwrite,
152
158
  show_individual_warnings,
159
+ log_datasets,
153
160
  **kwargs_wandb_init,
154
161
  )
155
162
 
@@ -160,11 +167,14 @@ class WandbLogger:
160
167
 
161
168
  @classmethod
162
169
  def _wait_for_job_success(cls, fine_tune: FineTuningJob) -> FineTuningJob:
163
- wandb.termlog("Waiting for the OpenAI fine-tuning job to be finished...")
170
+ wandb.termlog("Waiting for the OpenAI fine-tuning job to finish training...")
171
+ wandb.termlog(
172
+ "To avoid blocking, you can call `WandbLogger.sync` with `wait_for_job_success=False` after OpenAI training completes."
173
+ )
164
174
  while True:
165
175
  if fine_tune.status == "succeeded":
166
176
  wandb.termlog(
167
- "Fine-tuning finished, logging metrics, model metadata, and more to W&B"
177
+ "Fine-tuning finished, logging metrics, model metadata, and run metadata to Weights & Biases"
168
178
  )
169
179
  return fine_tune
170
180
  if fine_tune.status == "failed":
@@ -190,6 +200,7 @@ class WandbLogger:
190
200
  entity: Optional[str],
191
201
  overwrite: bool,
192
202
  show_individual_warnings: bool,
203
+ log_datasets: bool,
193
204
  **kwargs_wandb_init: Dict[str, Any],
194
205
  ):
195
206
  fine_tune_id = fine_tune.id
@@ -209,7 +220,7 @@ class WandbLogger:
209
220
  # check results are present
210
221
  try:
211
222
  results_id = fine_tune.result_files[0]
212
- results = cls.openai_client.files.retrieve_content(file_id=results_id)
223
+ results = cls.openai_client.files.content(file_id=results_id).text
213
224
  except openai.NotFoundError:
214
225
  if show_individual_warnings:
215
226
  wandb.termwarn(
@@ -233,7 +244,7 @@ class WandbLogger:
233
244
  cls._run.summary["fine_tuned_model"] = fine_tuned_model
234
245
 
235
246
  # training/validation files and fine-tune details
236
- cls._log_artifacts(fine_tune, project, entity)
247
+ cls._log_artifacts(fine_tune, project, entity, log_datasets, overwrite)
237
248
 
238
249
  # mark run as complete
239
250
  cls._run.summary["status"] = "succeeded"
@@ -249,7 +260,7 @@ class WandbLogger:
249
260
  else:
250
261
  raise Exception(
251
262
  "It appears you are not currently logged in to Weights & Biases. "
252
- "Please run `wandb login` in your terminal. "
263
+ "Please run `wandb login` in your terminal or `wandb.login()` in a notebook."
253
264
  "When prompted, you can obtain your API key by visiting wandb.ai/authorize."
254
265
  )
255
266
 
@@ -286,15 +297,9 @@ class WandbLogger:
286
297
  config["finished_at"]
287
298
  ).strftime("%Y-%m-%d %H:%M:%S")
288
299
  if config.get("hyperparameters"):
289
- hyperparameters = config.pop("hyperparameters")
290
- hyperparams = cls._unpack_hyperparameters(hyperparameters)
291
- if hyperparams is None:
292
- # If unpacking fails, log the object which will render as string
293
- config["hyperparameters"] = hyperparameters
294
- else:
295
- # nested rendering on hyperparameters
296
- config["hyperparameters"] = hyperparams
297
-
300
+ config["hyperparameters"] = cls.sanitize(config["hyperparameters"])
301
+ if config.get("error"):
302
+ config["error"] = cls.sanitize(config["error"])
298
303
  return config
299
304
 
300
305
  @classmethod
@@ -305,30 +310,53 @@ class WandbLogger:
305
310
  try:
306
311
  hyperparams["n_epochs"] = hyperparameters.n_epochs
307
312
  hyperparams["batch_size"] = hyperparameters.batch_size
308
- hyperparams[
309
- "learning_rate_multiplier"
310
- ] = hyperparameters.learning_rate_multiplier
313
+ hyperparams["learning_rate_multiplier"] = (
314
+ hyperparameters.learning_rate_multiplier
315
+ )
311
316
  except Exception:
312
317
  # If unpacking fails, return the object to be logged as config
313
318
  return None
314
319
 
315
320
  return hyperparams
316
321
 
322
+ @staticmethod
323
+ def sanitize(input: Any) -> Union[Dict, List, str]:
324
+ valid_types = [bool, int, float, str]
325
+ if isinstance(input, (Hyperparameters, Error)):
326
+ return dict(input)
327
+ if isinstance(input, dict):
328
+ return {
329
+ k: v if type(v) in valid_types else str(v) for k, v in input.items()
330
+ }
331
+ elif isinstance(input, list):
332
+ return [v if type(v) in valid_types else str(v) for v in input]
333
+ else:
334
+ return str(input)
335
+
317
336
  @classmethod
318
337
  def _log_artifacts(
319
- cls, fine_tune: FineTuningJob, project: str, entity: Optional[str]
338
+ cls,
339
+ fine_tune: FineTuningJob,
340
+ project: str,
341
+ entity: Optional[str],
342
+ log_datasets: bool,
343
+ overwrite: bool,
320
344
  ) -> None:
321
- # training/validation files
322
- training_file = fine_tune.training_file if fine_tune.training_file else None
323
- validation_file = (
324
- fine_tune.validation_file if fine_tune.validation_file else None
325
- )
326
- for file, prefix, artifact_type in (
327
- (training_file, "train", "training_files"),
328
- (validation_file, "valid", "validation_files"),
329
- ):
330
- if file is not None:
331
- cls._log_artifact_inputs(file, prefix, artifact_type, project, entity)
345
+ if log_datasets:
346
+ wandb.termlog("Logging training/validation files...")
347
+ # training/validation files
348
+ training_file = fine_tune.training_file if fine_tune.training_file else None
349
+ validation_file = (
350
+ fine_tune.validation_file if fine_tune.validation_file else None
351
+ )
352
+ for file, prefix, artifact_type in (
353
+ (training_file, "train", "training_files"),
354
+ (validation_file, "valid", "validation_files"),
355
+ ):
356
+ if file is not None:
357
+ cls._log_artifact_inputs(
358
+ file, prefix, artifact_type, project, entity, overwrite
359
+ )
332
360
 
333
361
  # fine-tune details
334
362
  fine_tune_id = fine_tune.id
@@ -337,9 +365,14 @@ class WandbLogger:
337
365
  type="model",
338
366
  metadata=dict(fine_tune),
339
367
  )
368
+
340
369
  with artifact.new_file("model_metadata.json", mode="w", encoding="utf-8") as f:
341
370
  dict_fine_tune = dict(fine_tune)
342
- dict_fine_tune["hyperparameters"] = dict(dict_fine_tune["hyperparameters"])
371
+ dict_fine_tune["hyperparameters"] = cls.sanitize(
372
+ dict_fine_tune["hyperparameters"]
373
+ )
374
+ dict_fine_tune["error"] = cls.sanitize(dict_fine_tune["error"])
375
+ dict_fine_tune = cls.sanitize(dict_fine_tune)
343
376
  json.dump(dict_fine_tune, f, indent=2)
344
377
  cls._run.log_artifact(
345
378
  artifact,
@@ -354,6 +387,7 @@ class WandbLogger:
354
387
  artifact_type: str,
355
388
  project: str,
356
389
  entity: Optional[str],
390
+ overwrite: bool,
357
391
  ) -> None:
358
392
  # get input artifact
359
393
  artifact_name = f"{prefix}-{file_id}"
@@ -366,23 +400,26 @@ class WandbLogger:
366
400
  artifact = cls._get_wandb_artifact(artifact_path)
367
401
 
368
402
  # create artifact if file not already logged previously
369
- if artifact is None:
403
+ if artifact is None or overwrite:
370
404
  # get file content
371
405
  try:
372
- file_content = cls.openai_client.files.retrieve_content(file_id=file_id)
406
+ file_content = cls.openai_client.files.content(file_id=file_id)
373
407
  except openai.NotFoundError:
374
408
  wandb.termerror(
375
- f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
409
+ f"File {file_id} could not be retrieved. Make sure you have OpenAI permissions to download training/validation files"
376
410
  )
377
411
  return
378
412
 
379
413
  artifact = wandb.Artifact(artifact_name, type=artifact_type)
380
- with artifact.new_file(file_id, mode="w", encoding="utf-8") as f:
381
- f.write(file_content)
414
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
415
+ tmp_file.write(file_content.content)
416
+ tmp_file_path = tmp_file.name
417
+ artifact.add_file(tmp_file_path, file_id)
418
+ os.unlink(tmp_file_path)
382
419
 
383
420
  # create a Table
384
421
  try:
385
- table, n_items = cls._make_table(file_content)
422
+ table, n_items = cls._make_table(file_content.text)
386
423
  # Add table to the artifact.
387
424
  artifact.add(table, file_id)
388
425
  # Add the same table to the workspace.
@@ -390,9 +427,9 @@ class WandbLogger:
390
427
  # Update the run config and artifact metadata
391
428
  cls._run.config.update({f"n_{prefix}": n_items})
392
429
  artifact.metadata["items"] = n_items
393
- except Exception:
430
+ except Exception as e:
394
431
  wandb.termerror(
395
- f"File {file_id} could not be read as a valid JSON file"
432
+ f"Issue saving {file_id} as a Table to Artifacts, exception:\n '{e}'"
396
433
  )
397
434
  else:
398
435
  # log number of items
@@ -26,7 +26,7 @@ from PIL import Image
26
26
 
27
27
  import wandb
28
28
  from wandb import util
29
- from wandb.plots.utils import test_missing
29
+ from wandb.plot.utils import test_missing
30
30
  from wandb.sdk.lib import telemetry as wb_telemetry
31
31
 
32
32
 
wandb/old/summary.py CHANGED
@@ -61,7 +61,7 @@ class SummarySubDict:
61
61
  This should only be implemented by the "_root" child class.
62
62
 
63
63
  We pass the child_dict so the item can be set on it or not as
64
- appropriate. Returning None for a nonexistant path wouldn't be
64
+ appropriate. Returning None for a nonexistent path wouldn't be
65
65
  distinguishable from that path being set to the value None.
66
66
  """
67
67
  raise NotImplementedError
@@ -48,7 +48,7 @@ def confusion_matrix(
48
48
 
49
49
  assert (probs is None or preds is None) and not (
50
50
  probs is None and preds is None
51
- ), "Must provide probabilties or predictions but not both to confusion matrix"
51
+ ), "Must provide probabilities or predictions but not both to confusion matrix"
52
52
 
53
53
  if probs is not None:
54
54
  preds = np.argmax(probs, axis=1).tolist()
wandb/plot/pr_curve.py CHANGED
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.plots.utils import test_missing, test_types
5
+
6
+ from .utils import test_missing, test_types
6
7
 
7
8
 
8
9
  def pr_curve(
wandb/plot/roc_curve.py CHANGED
@@ -2,7 +2,8 @@ from typing import Optional
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.plots.utils import test_missing, test_types
5
+
6
+ from .utils import test_missing, test_types
6
7
 
7
8
 
8
9
  def roc_curve(
@@ -1,21 +1,9 @@
1
- from collections.abc import Iterable, Sequence
1
+ from typing import Iterable, Sequence
2
2
 
3
3
  import wandb
4
4
  from wandb import util
5
- from wandb.sdk.lib import deprecate
6
5
 
7
6
 
8
- def deprecation_notice() -> None:
9
- deprecate.deprecate(
10
- field_name=deprecate.Deprecated.plots,
11
- warning_message=(
12
- "wandb.plots.* functions are deprecated and will be removed in a future release. "
13
- "Please use wandb.plot.* instead."
14
- ),
15
- )
16
-
17
-
18
- # Test assumptions for plotting parameters and datasets
19
7
  def test_missing(**kwargs):
20
8
  np = util.get_module("numpy", required="Logging plots requires numpy")
21
9
  pd = util.get_module("pandas", required="Logging dataframes requires pandas")
@@ -25,7 +13,7 @@ def test_missing(**kwargs):
25
13
  for k, v in kwargs.items():
26
14
  # Missing/empty params/datapoint arrays
27
15
  if v is None:
28
- wandb.termerror("%s is None. Please try again." % (k))
16
+ wandb.termerror("{} is None. Please try again.".format(k))
29
17
  test_passed = False
30
18
  if (k == "X") or (k == "X_test"):
31
19
  if isinstance(v, scipy.sparse.csr.csr_matrix):
@@ -64,8 +52,8 @@ def test_missing(**kwargs):
64
52
  )
65
53
  if non_nums > 0:
66
54
  wandb.termerror(
67
- "%s contains values that are not numbers. Please vectorize, label encode or one hot encode %s and call the plotting function again."
68
- % (k, k)
55
+ f"{k} contains values that are not numbers. Please vectorize, "
56
+ f"label encode or one hot encode {k} and call the plotting function again."
69
57
  )
70
58
  test_passed = False
71
59
  return test_passed
@@ -73,8 +61,8 @@ def test_missing(**kwargs):
73
61
 
74
62
  def test_fitted(model):
75
63
  np = util.get_module("numpy", required="Logging plots requires numpy")
76
- pd = util.get_module("pandas", required="Logging dataframes requires pandas")
77
- scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
64
+ _ = util.get_module("pandas", required="Logging dataframes requires pandas")
65
+ _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
78
66
  scikit_utils = util.get_module(
79
67
  "sklearn.utils",
80
68
  required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
@@ -120,7 +108,7 @@ def test_fitted(model):
120
108
 
121
109
 
122
110
  def encode_labels(df):
123
- pd = util.get_module("pandas", required="Logging dataframes requires pandas")
111
+ _ = util.get_module("pandas", required="Logging dataframes requires pandas")
124
112
  preprocessing = util.get_module(
125
113
  "sklearn.preprocessing",
126
114
  "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
@@ -137,7 +125,7 @@ def encode_labels(df):
137
125
  def test_types(**kwargs):
138
126
  np = util.get_module("numpy", required="Logging plots requires numpy")
139
127
  pd = util.get_module("pandas", required="Logging dataframes requires pandas")
140
- scipy = util.get_module("scipy", required="Logging scipy matrices requires scipy")
128
+ _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
141
129
 
142
130
  base = util.get_module(
143
131
  "sklearn.base",
@@ -171,25 +159,25 @@ def test_types(**kwargs):
171
159
  list,
172
160
  ),
173
161
  ):
174
- wandb.termerror("%s is not an array. Please try again." % (k))
162
+ wandb.termerror("{} is not an array. Please try again.".format(k))
175
163
  test_passed = False
176
164
  # check for classifier types
177
165
  if k == "model":
178
166
  if (not base.is_classifier(v)) and (not base.is_regressor(v)):
179
167
  wandb.termerror(
180
- "%s is not a classifier or regressor. Please try again." % (k)
168
+ "{} is not a classifier or regressor. Please try again.".format(k)
181
169
  )
182
170
  test_passed = False
183
171
  elif k == "clf" or k == "binary_clf":
184
172
  if not (base.is_classifier(v)):
185
- wandb.termerror("%s is not a classifier. Please try again." % (k))
173
+ wandb.termerror("{} is not a classifier. Please try again.".format(k))
186
174
  test_passed = False
187
175
  elif k == "regressor":
188
176
  if not base.is_regressor(v):
189
- wandb.termerror("%s is not a regressor. Please try again." % (k))
177
+ wandb.termerror("{} is not a regressor. Please try again.".format(k))
190
178
  test_passed = False
191
179
  elif k == "clusterer":
192
180
  if not (getattr(v, "_estimator_type", None) == "clusterer"):
193
- wandb.termerror("%s is not a clusterer. Please try again." % (k))
181
+ wandb.termerror("{} is not a clusterer. Please try again.".format(k))
194
182
  test_passed = False
195
183
  return test_passed