wandb 0.16.5__py3-none-any.whl → 0.17.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -3
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/internal.py +0 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/internal.py +0 -3
- wandb/apis/public/api.py +55 -3
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +17 -4
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +15 -17
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +95 -22
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/errors/term.py +10 -2
- wandb/filesync/step_checksum.py +1 -4
- wandb/filesync/step_prepare.py +4 -24
- wandb/filesync/step_upload.py +4 -106
- wandb/filesync/upload_job.py +0 -76
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +77 -40
- wandb/integration/prodigy/prodigy.py +1 -1
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/pr_curve.py +2 -1
- wandb/plot/roc_curve.py +2 -1
- wandb/{plots → plot}/utils.py +13 -25
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +7 -1
- wandb/proto/wandb_internal_codegen.py +3 -29
- wandb/sdk/artifacts/artifact.py +51 -20
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +7 -3
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +18 -27
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -42
- wandb/sdk/artifacts/storage_policy.py +2 -12
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +86 -38
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +2 -5
- wandb/sdk/internal/file_stream.py +5 -18
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_api.py +1 -129
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +159 -45
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/progress.py +0 -28
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/__init__.py +9 -1
- wandb/sdk/launch/_launch.py +9 -24
- wandb/sdk/launch/_launch_add.py +1 -3
- wandb/sdk/launch/_project_spec.py +188 -241
- wandb/sdk/launch/agent/agent.py +115 -48
- wandb/sdk/launch/agent/config.py +80 -14
- wandb/sdk/launch/builder/abstract.py +69 -1
- wandb/sdk/launch/builder/build.py +156 -555
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +8 -23
- wandb/sdk/launch/builder/kaniko_builder.py +161 -159
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +68 -63
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +217 -0
- wandb/sdk/launch/inputs/manage.py +95 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +9 -10
- wandb/sdk/launch/runner/local_container.py +2 -3
- wandb/sdk/launch/runner/local_process.py +8 -29
- wandb/sdk/launch/runner/sagemaker_runner.py +21 -20
- wandb/sdk/launch/runner/vertex_runner.py +8 -7
- wandb/sdk/launch/sweeps/scheduler.py +7 -4
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +33 -140
- wandb/sdk/lib/_settings_toposort_generated.py +1 -5
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +23 -2
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/run_moment.py +7 -1
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +19 -16
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_init.py +16 -63
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +164 -90
- wandb/sdk/wandb_settings.py +2 -48
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +11 -12
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +67 -54
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/METADATA +67 -70
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/RECORD +178 -188
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb/catboost/__init__.py +0 -9
- wandb/fastai/__init__.py +0 -9
- wandb/keras/__init__.py +0 -18
- wandb/lightgbm/__init__.py +0 -9
- wandb/plots/__init__.py +0 -6
- wandb/plots/explain_text.py +0 -36
- wandb/plots/heatmap.py +0 -81
- wandb/plots/named_entity.py +0 -43
- wandb/plots/part_of_speech.py +0 -50
- wandb/plots/plot_definitions.py +0 -768
- wandb/plots/precision_recall.py +0 -121
- wandb/plots/roc.py +0 -103
- wandb/sacred/__init__.py +0 -3
- wandb/xgboost/__init__.py +0 -9
- wandb-0.16.5.dist-info/top_level.txt +0 -1
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.5.dist-info → wandb-0.17.0.dist-info/licenses}/LICENSE +0 -0
@@ -1,9 +1,11 @@
|
|
1
1
|
import datetime
|
2
2
|
import io
|
3
3
|
import json
|
4
|
+
import os
|
4
5
|
import re
|
6
|
+
import tempfile
|
5
7
|
import time
|
6
|
-
from typing import Any, Dict, Optional, Tuple
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
7
9
|
|
8
10
|
import wandb
|
9
11
|
from wandb import util
|
@@ -26,7 +28,10 @@ if parse_version(openai.__version__) < parse_version("1.0.1"):
|
|
26
28
|
|
27
29
|
from openai import OpenAI # noqa: E402
|
28
30
|
from openai.types.fine_tuning import FineTuningJob # noqa: E402
|
29
|
-
from openai.types.fine_tuning.fine_tuning_job import
|
31
|
+
from openai.types.fine_tuning.fine_tuning_job import ( # noqa: E402
|
32
|
+
Error,
|
33
|
+
Hyperparameters,
|
34
|
+
)
|
30
35
|
|
31
36
|
np = util.get_module(
|
32
37
|
name="numpy",
|
@@ -59,6 +64,7 @@ class WandbLogger:
|
|
59
64
|
entity: Optional[str] = None,
|
60
65
|
overwrite: bool = False,
|
61
66
|
wait_for_job_success: bool = True,
|
67
|
+
log_datasets: bool = True,
|
62
68
|
**kwargs_wandb_init: Dict[str, Any],
|
63
69
|
) -> str:
|
64
70
|
"""Sync fine-tunes to Weights & Biases.
|
@@ -150,6 +156,7 @@ class WandbLogger:
|
|
150
156
|
entity,
|
151
157
|
overwrite,
|
152
158
|
show_individual_warnings,
|
159
|
+
log_datasets,
|
153
160
|
**kwargs_wandb_init,
|
154
161
|
)
|
155
162
|
|
@@ -160,11 +167,14 @@ class WandbLogger:
|
|
160
167
|
|
161
168
|
@classmethod
|
162
169
|
def _wait_for_job_success(cls, fine_tune: FineTuningJob) -> FineTuningJob:
|
163
|
-
wandb.termlog("Waiting for the OpenAI fine-tuning job to
|
170
|
+
wandb.termlog("Waiting for the OpenAI fine-tuning job to finish training...")
|
171
|
+
wandb.termlog(
|
172
|
+
"To avoid blocking, you can call `WandbLogger.sync` with `wait_for_job_success=False` after OpenAI training completes."
|
173
|
+
)
|
164
174
|
while True:
|
165
175
|
if fine_tune.status == "succeeded":
|
166
176
|
wandb.termlog(
|
167
|
-
"Fine-tuning finished, logging metrics, model metadata, and
|
177
|
+
"Fine-tuning finished, logging metrics, model metadata, and run metadata to Weights & Biases"
|
168
178
|
)
|
169
179
|
return fine_tune
|
170
180
|
if fine_tune.status == "failed":
|
@@ -190,6 +200,7 @@ class WandbLogger:
|
|
190
200
|
entity: Optional[str],
|
191
201
|
overwrite: bool,
|
192
202
|
show_individual_warnings: bool,
|
203
|
+
log_datasets: bool,
|
193
204
|
**kwargs_wandb_init: Dict[str, Any],
|
194
205
|
):
|
195
206
|
fine_tune_id = fine_tune.id
|
@@ -209,7 +220,7 @@ class WandbLogger:
|
|
209
220
|
# check results are present
|
210
221
|
try:
|
211
222
|
results_id = fine_tune.result_files[0]
|
212
|
-
results = cls.openai_client.files.
|
223
|
+
results = cls.openai_client.files.content(file_id=results_id).text
|
213
224
|
except openai.NotFoundError:
|
214
225
|
if show_individual_warnings:
|
215
226
|
wandb.termwarn(
|
@@ -233,7 +244,7 @@ class WandbLogger:
|
|
233
244
|
cls._run.summary["fine_tuned_model"] = fine_tuned_model
|
234
245
|
|
235
246
|
# training/validation files and fine-tune details
|
236
|
-
cls._log_artifacts(fine_tune, project, entity)
|
247
|
+
cls._log_artifacts(fine_tune, project, entity, log_datasets, overwrite)
|
237
248
|
|
238
249
|
# mark run as complete
|
239
250
|
cls._run.summary["status"] = "succeeded"
|
@@ -249,7 +260,7 @@ class WandbLogger:
|
|
249
260
|
else:
|
250
261
|
raise Exception(
|
251
262
|
"It appears you are not currently logged in to Weights & Biases. "
|
252
|
-
"Please run `wandb login` in your terminal. "
|
263
|
+
"Please run `wandb login` in your terminal or `wandb.login()` in a notebook."
|
253
264
|
"When prompted, you can obtain your API key by visiting wandb.ai/authorize."
|
254
265
|
)
|
255
266
|
|
@@ -286,15 +297,9 @@ class WandbLogger:
|
|
286
297
|
config["finished_at"]
|
287
298
|
).strftime("%Y-%m-%d %H:%M:%S")
|
288
299
|
if config.get("hyperparameters"):
|
289
|
-
hyperparameters =
|
290
|
-
|
291
|
-
|
292
|
-
# If unpacking fails, log the object which will render as string
|
293
|
-
config["hyperparameters"] = hyperparameters
|
294
|
-
else:
|
295
|
-
# nested rendering on hyperparameters
|
296
|
-
config["hyperparameters"] = hyperparams
|
297
|
-
|
300
|
+
config["hyperparameters"] = cls.sanitize(config["hyperparameters"])
|
301
|
+
if config.get("error"):
|
302
|
+
config["error"] = cls.sanitize(config["error"])
|
298
303
|
return config
|
299
304
|
|
300
305
|
@classmethod
|
@@ -305,30 +310,53 @@ class WandbLogger:
|
|
305
310
|
try:
|
306
311
|
hyperparams["n_epochs"] = hyperparameters.n_epochs
|
307
312
|
hyperparams["batch_size"] = hyperparameters.batch_size
|
308
|
-
hyperparams[
|
309
|
-
|
310
|
-
|
313
|
+
hyperparams["learning_rate_multiplier"] = (
|
314
|
+
hyperparameters.learning_rate_multiplier
|
315
|
+
)
|
311
316
|
except Exception:
|
312
317
|
# If unpacking fails, return the object to be logged as config
|
313
318
|
return None
|
314
319
|
|
315
320
|
return hyperparams
|
316
321
|
|
322
|
+
@staticmethod
|
323
|
+
def sanitize(input: Any) -> Union[Dict, List, str]:
|
324
|
+
valid_types = [bool, int, float, str]
|
325
|
+
if isinstance(input, (Hyperparameters, Error)):
|
326
|
+
return dict(input)
|
327
|
+
if isinstance(input, dict):
|
328
|
+
return {
|
329
|
+
k: v if type(v) in valid_types else str(v) for k, v in input.items()
|
330
|
+
}
|
331
|
+
elif isinstance(input, list):
|
332
|
+
return [v if type(v) in valid_types else str(v) for v in input]
|
333
|
+
else:
|
334
|
+
return str(input)
|
335
|
+
|
317
336
|
@classmethod
|
318
337
|
def _log_artifacts(
|
319
|
-
cls,
|
338
|
+
cls,
|
339
|
+
fine_tune: FineTuningJob,
|
340
|
+
project: str,
|
341
|
+
entity: Optional[str],
|
342
|
+
log_datasets: bool,
|
343
|
+
overwrite: bool,
|
320
344
|
) -> None:
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
fine_tune.
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
345
|
+
if log_datasets:
|
346
|
+
wandb.termlog("Logging training/validation files...")
|
347
|
+
# training/validation files
|
348
|
+
training_file = fine_tune.training_file if fine_tune.training_file else None
|
349
|
+
validation_file = (
|
350
|
+
fine_tune.validation_file if fine_tune.validation_file else None
|
351
|
+
)
|
352
|
+
for file, prefix, artifact_type in (
|
353
|
+
(training_file, "train", "training_files"),
|
354
|
+
(validation_file, "valid", "validation_files"),
|
355
|
+
):
|
356
|
+
if file is not None:
|
357
|
+
cls._log_artifact_inputs(
|
358
|
+
file, prefix, artifact_type, project, entity, overwrite
|
359
|
+
)
|
332
360
|
|
333
361
|
# fine-tune details
|
334
362
|
fine_tune_id = fine_tune.id
|
@@ -337,9 +365,14 @@ class WandbLogger:
|
|
337
365
|
type="model",
|
338
366
|
metadata=dict(fine_tune),
|
339
367
|
)
|
368
|
+
|
340
369
|
with artifact.new_file("model_metadata.json", mode="w", encoding="utf-8") as f:
|
341
370
|
dict_fine_tune = dict(fine_tune)
|
342
|
-
dict_fine_tune["hyperparameters"] =
|
371
|
+
dict_fine_tune["hyperparameters"] = cls.sanitize(
|
372
|
+
dict_fine_tune["hyperparameters"]
|
373
|
+
)
|
374
|
+
dict_fine_tune["error"] = cls.sanitize(dict_fine_tune["error"])
|
375
|
+
dict_fine_tune = cls.sanitize(dict_fine_tune)
|
343
376
|
json.dump(dict_fine_tune, f, indent=2)
|
344
377
|
cls._run.log_artifact(
|
345
378
|
artifact,
|
@@ -354,6 +387,7 @@ class WandbLogger:
|
|
354
387
|
artifact_type: str,
|
355
388
|
project: str,
|
356
389
|
entity: Optional[str],
|
390
|
+
overwrite: bool,
|
357
391
|
) -> None:
|
358
392
|
# get input artifact
|
359
393
|
artifact_name = f"{prefix}-{file_id}"
|
@@ -366,23 +400,26 @@ class WandbLogger:
|
|
366
400
|
artifact = cls._get_wandb_artifact(artifact_path)
|
367
401
|
|
368
402
|
# create artifact if file not already logged previously
|
369
|
-
if artifact is None:
|
403
|
+
if artifact is None or overwrite:
|
370
404
|
# get file content
|
371
405
|
try:
|
372
|
-
file_content = cls.openai_client.files.
|
406
|
+
file_content = cls.openai_client.files.content(file_id=file_id)
|
373
407
|
except openai.NotFoundError:
|
374
408
|
wandb.termerror(
|
375
|
-
f"File {file_id} could not be retrieved. Make sure you
|
409
|
+
f"File {file_id} could not be retrieved. Make sure you have OpenAI permissions to download training/validation files"
|
376
410
|
)
|
377
411
|
return
|
378
412
|
|
379
413
|
artifact = wandb.Artifact(artifact_name, type=artifact_type)
|
380
|
-
with
|
381
|
-
|
414
|
+
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
|
415
|
+
tmp_file.write(file_content.content)
|
416
|
+
tmp_file_path = tmp_file.name
|
417
|
+
artifact.add_file(tmp_file_path, file_id)
|
418
|
+
os.unlink(tmp_file_path)
|
382
419
|
|
383
420
|
# create a Table
|
384
421
|
try:
|
385
|
-
table, n_items = cls._make_table(file_content)
|
422
|
+
table, n_items = cls._make_table(file_content.text)
|
386
423
|
# Add table to the artifact.
|
387
424
|
artifact.add(table, file_id)
|
388
425
|
# Add the same table to the workspace.
|
@@ -390,9 +427,9 @@ class WandbLogger:
|
|
390
427
|
# Update the run config and artifact metadata
|
391
428
|
cls._run.config.update({f"n_{prefix}": n_items})
|
392
429
|
artifact.metadata["items"] = n_items
|
393
|
-
except Exception:
|
430
|
+
except Exception as e:
|
394
431
|
wandb.termerror(
|
395
|
-
f"
|
432
|
+
f"Issue saving {file_id} as a Table to Artifacts, exception:\n '{e}'"
|
396
433
|
)
|
397
434
|
else:
|
398
435
|
# log number of items
|
wandb/old/summary.py
CHANGED
@@ -61,7 +61,7 @@ class SummarySubDict:
|
|
61
61
|
This should only be implemented by the "_root" child class.
|
62
62
|
|
63
63
|
We pass the child_dict so the item can be set on it or not as
|
64
|
-
appropriate. Returning None for a
|
64
|
+
appropriate. Returning None for a nonexistent path wouldn't be
|
65
65
|
distinguishable from that path being set to the value None.
|
66
66
|
"""
|
67
67
|
raise NotImplementedError
|
wandb/plot/confusion_matrix.py
CHANGED
@@ -48,7 +48,7 @@ def confusion_matrix(
|
|
48
48
|
|
49
49
|
assert (probs is None or preds is None) and not (
|
50
50
|
probs is None and preds is None
|
51
|
-
), "Must provide
|
51
|
+
), "Must provide probabilities or predictions but not both to confusion matrix"
|
52
52
|
|
53
53
|
if probs is not None:
|
54
54
|
preds = np.argmax(probs, axis=1).tolist()
|
wandb/plot/pr_curve.py
CHANGED
wandb/plot/roc_curve.py
CHANGED
wandb/{plots → plot}/utils.py
RENAMED
@@ -1,21 +1,9 @@
|
|
1
|
-
from
|
1
|
+
from typing import Iterable, Sequence
|
2
2
|
|
3
3
|
import wandb
|
4
4
|
from wandb import util
|
5
|
-
from wandb.sdk.lib import deprecate
|
6
5
|
|
7
6
|
|
8
|
-
def deprecation_notice() -> None:
|
9
|
-
deprecate.deprecate(
|
10
|
-
field_name=deprecate.Deprecated.plots,
|
11
|
-
warning_message=(
|
12
|
-
"wandb.plots.* functions are deprecated and will be removed in a future release. "
|
13
|
-
"Please use wandb.plot.* instead."
|
14
|
-
),
|
15
|
-
)
|
16
|
-
|
17
|
-
|
18
|
-
# Test assumptions for plotting parameters and datasets
|
19
7
|
def test_missing(**kwargs):
|
20
8
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
21
9
|
pd = util.get_module("pandas", required="Logging dataframes requires pandas")
|
@@ -25,7 +13,7 @@ def test_missing(**kwargs):
|
|
25
13
|
for k, v in kwargs.items():
|
26
14
|
# Missing/empty params/datapoint arrays
|
27
15
|
if v is None:
|
28
|
-
wandb.termerror("
|
16
|
+
wandb.termerror("{} is None. Please try again.".format(k))
|
29
17
|
test_passed = False
|
30
18
|
if (k == "X") or (k == "X_test"):
|
31
19
|
if isinstance(v, scipy.sparse.csr.csr_matrix):
|
@@ -64,8 +52,8 @@ def test_missing(**kwargs):
|
|
64
52
|
)
|
65
53
|
if non_nums > 0:
|
66
54
|
wandb.termerror(
|
67
|
-
"
|
68
|
-
|
55
|
+
f"{k} contains values that are not numbers. Please vectorize, "
|
56
|
+
f"label encode or one hot encode {k} and call the plotting function again."
|
69
57
|
)
|
70
58
|
test_passed = False
|
71
59
|
return test_passed
|
@@ -73,8 +61,8 @@ def test_missing(**kwargs):
|
|
73
61
|
|
74
62
|
def test_fitted(model):
|
75
63
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
76
|
-
|
77
|
-
|
64
|
+
_ = util.get_module("pandas", required="Logging dataframes requires pandas")
|
65
|
+
_ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
|
78
66
|
scikit_utils = util.get_module(
|
79
67
|
"sklearn.utils",
|
80
68
|
required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
|
@@ -120,7 +108,7 @@ def test_fitted(model):
|
|
120
108
|
|
121
109
|
|
122
110
|
def encode_labels(df):
|
123
|
-
|
111
|
+
_ = util.get_module("pandas", required="Logging dataframes requires pandas")
|
124
112
|
preprocessing = util.get_module(
|
125
113
|
"sklearn.preprocessing",
|
126
114
|
"roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
|
@@ -137,7 +125,7 @@ def encode_labels(df):
|
|
137
125
|
def test_types(**kwargs):
|
138
126
|
np = util.get_module("numpy", required="Logging plots requires numpy")
|
139
127
|
pd = util.get_module("pandas", required="Logging dataframes requires pandas")
|
140
|
-
|
128
|
+
_ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
|
141
129
|
|
142
130
|
base = util.get_module(
|
143
131
|
"sklearn.base",
|
@@ -171,25 +159,25 @@ def test_types(**kwargs):
|
|
171
159
|
list,
|
172
160
|
),
|
173
161
|
):
|
174
|
-
wandb.termerror("
|
162
|
+
wandb.termerror("{} is not an array. Please try again.".format(k))
|
175
163
|
test_passed = False
|
176
164
|
# check for classifier types
|
177
165
|
if k == "model":
|
178
166
|
if (not base.is_classifier(v)) and (not base.is_regressor(v)):
|
179
167
|
wandb.termerror(
|
180
|
-
"
|
168
|
+
"{} is not a classifier or regressor. Please try again.".format(k)
|
181
169
|
)
|
182
170
|
test_passed = False
|
183
171
|
elif k == "clf" or k == "binary_clf":
|
184
172
|
if not (base.is_classifier(v)):
|
185
|
-
wandb.termerror("
|
173
|
+
wandb.termerror("{} is not a classifier. Please try again.".format(k))
|
186
174
|
test_passed = False
|
187
175
|
elif k == "regressor":
|
188
176
|
if not base.is_regressor(v):
|
189
|
-
wandb.termerror("
|
177
|
+
wandb.termerror("{} is not a regressor. Please try again.".format(k))
|
190
178
|
test_passed = False
|
191
179
|
elif k == "clusterer":
|
192
180
|
if not (getattr(v, "_estimator_type", None) == "clusterer"):
|
193
|
-
wandb.termerror("
|
181
|
+
wandb.termerror("{} is not a clusterer. Please try again.".format(k))
|
194
182
|
test_passed = False
|
195
183
|
return test_passed
|