wandb 0.18.2__py3-none-musllinux_1_2_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package_readme.md +89 -0
- wandb/__init__.py +245 -0
- wandb/__init__.pyi +1139 -0
- wandb/__main__.py +3 -0
- wandb/_globals.py +19 -0
- wandb/agents/__init__.py +0 -0
- wandb/agents/pyagent.py +363 -0
- wandb/analytics/__init__.py +3 -0
- wandb/analytics/sentry.py +266 -0
- wandb/apis/__init__.py +48 -0
- wandb/apis/attrs.py +40 -0
- wandb/apis/importers/__init__.py +1 -0
- wandb/apis/importers/internals/internal.py +385 -0
- wandb/apis/importers/internals/protocols.py +99 -0
- wandb/apis/importers/internals/util.py +78 -0
- wandb/apis/importers/mlflow.py +254 -0
- wandb/apis/importers/validation.py +108 -0
- wandb/apis/importers/wandb.py +1603 -0
- wandb/apis/internal.py +232 -0
- wandb/apis/normalize.py +89 -0
- wandb/apis/paginator.py +81 -0
- wandb/apis/public/__init__.py +34 -0
- wandb/apis/public/api.py +1305 -0
- wandb/apis/public/artifacts.py +1090 -0
- wandb/apis/public/const.py +4 -0
- wandb/apis/public/files.py +195 -0
- wandb/apis/public/history.py +149 -0
- wandb/apis/public/jobs.py +659 -0
- wandb/apis/public/projects.py +154 -0
- wandb/apis/public/query_generator.py +166 -0
- wandb/apis/public/reports.py +469 -0
- wandb/apis/public/runs.py +914 -0
- wandb/apis/public/sweeps.py +240 -0
- wandb/apis/public/teams.py +198 -0
- wandb/apis/public/users.py +136 -0
- wandb/apis/reports/__init__.py +1 -0
- wandb/apis/reports/v1/__init__.py +8 -0
- wandb/apis/reports/v2/__init__.py +8 -0
- wandb/apis/workspaces/__init__.py +8 -0
- wandb/beta/workflows.py +288 -0
- wandb/bin/nvidia_gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/__init__.py +0 -0
- wandb/cli/cli.py +3004 -0
- wandb/data_types.py +63 -0
- wandb/docker/__init__.py +342 -0
- wandb/docker/auth.py +436 -0
- wandb/docker/wandb-entrypoint.sh +33 -0
- wandb/docker/www_authenticate.py +94 -0
- wandb/env.py +514 -0
- wandb/errors/__init__.py +17 -0
- wandb/errors/errors.py +37 -0
- wandb/errors/term.py +103 -0
- wandb/errors/util.py +57 -0
- wandb/errors/warnings.py +2 -0
- wandb/filesync/__init__.py +0 -0
- wandb/filesync/dir_watcher.py +403 -0
- wandb/filesync/stats.py +100 -0
- wandb/filesync/step_checksum.py +142 -0
- wandb/filesync/step_prepare.py +179 -0
- wandb/filesync/step_upload.py +290 -0
- wandb/filesync/upload_job.py +142 -0
- wandb/integration/__init__.py +0 -0
- wandb/integration/catboost/__init__.py +5 -0
- wandb/integration/catboost/catboost.py +178 -0
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/diffusers/__init__.py +3 -0
- wandb/integration/diffusers/autologger.py +76 -0
- wandb/integration/diffusers/pipeline_resolver.py +50 -0
- wandb/integration/diffusers/resolvers/__init__.py +9 -0
- wandb/integration/diffusers/resolvers/multimodal.py +882 -0
- wandb/integration/diffusers/resolvers/utils.py +102 -0
- wandb/integration/fastai/__init__.py +249 -0
- wandb/integration/gym/__init__.py +105 -0
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/keras/__init__.py +11 -0
- wandb/integration/keras/callbacks/__init__.py +5 -0
- wandb/integration/keras/callbacks/metrics_logger.py +136 -0
- wandb/integration/keras/callbacks/model_checkpoint.py +195 -0
- wandb/integration/keras/callbacks/tables_builder.py +226 -0
- wandb/integration/keras/keras.py +1091 -0
- wandb/integration/kfp/__init__.py +6 -0
- wandb/integration/kfp/helpers.py +28 -0
- wandb/integration/kfp/kfp_patch.py +324 -0
- wandb/integration/kfp/wandb_logging.py +182 -0
- wandb/integration/langchain/__init__.py +3 -0
- wandb/integration/langchain/wandb_tracer.py +48 -0
- wandb/integration/lightgbm/__init__.py +239 -0
- wandb/integration/lightning/__init__.py +0 -0
- wandb/integration/lightning/fabric/__init__.py +3 -0
- wandb/integration/lightning/fabric/logger.py +762 -0
- wandb/integration/magic.py +556 -0
- wandb/integration/metaflow/__init__.py +3 -0
- wandb/integration/metaflow/metaflow.py +383 -0
- wandb/integration/openai/__init__.py +3 -0
- wandb/integration/openai/fine_tuning.py +480 -0
- wandb/integration/openai/openai.py +22 -0
- wandb/integration/openai/resolver.py +240 -0
- wandb/integration/prodigy/__init__.py +3 -0
- wandb/integration/prodigy/prodigy.py +299 -0
- wandb/integration/sacred/__init__.py +117 -0
- wandb/integration/sagemaker/__init__.py +12 -0
- wandb/integration/sagemaker/auth.py +28 -0
- wandb/integration/sagemaker/config.py +49 -0
- wandb/integration/sagemaker/files.py +3 -0
- wandb/integration/sagemaker/resources.py +34 -0
- wandb/integration/sb3/__init__.py +3 -0
- wandb/integration/sb3/sb3.py +153 -0
- wandb/integration/sklearn/__init__.py +37 -0
- wandb/integration/sklearn/calculate/__init__.py +32 -0
- wandb/integration/sklearn/calculate/calibration_curves.py +125 -0
- wandb/integration/sklearn/calculate/class_proportions.py +68 -0
- wandb/integration/sklearn/calculate/confusion_matrix.py +93 -0
- wandb/integration/sklearn/calculate/decision_boundaries.py +40 -0
- wandb/integration/sklearn/calculate/elbow_curve.py +55 -0
- wandb/integration/sklearn/calculate/feature_importances.py +67 -0
- wandb/integration/sklearn/calculate/learning_curve.py +64 -0
- wandb/integration/sklearn/calculate/outlier_candidates.py +69 -0
- wandb/integration/sklearn/calculate/residuals.py +86 -0
- wandb/integration/sklearn/calculate/silhouette.py +118 -0
- wandb/integration/sklearn/calculate/summary_metrics.py +62 -0
- wandb/integration/sklearn/plot/__init__.py +35 -0
- wandb/integration/sklearn/plot/classifier.py +329 -0
- wandb/integration/sklearn/plot/clusterer.py +146 -0
- wandb/integration/sklearn/plot/regressor.py +121 -0
- wandb/integration/sklearn/plot/shared.py +91 -0
- wandb/integration/sklearn/utils.py +183 -0
- wandb/integration/tensorboard/__init__.py +10 -0
- wandb/integration/tensorboard/log.py +355 -0
- wandb/integration/tensorboard/monkeypatch.py +185 -0
- wandb/integration/tensorflow/__init__.py +5 -0
- wandb/integration/tensorflow/estimator_hook.py +54 -0
- wandb/integration/torch/__init__.py +0 -0
- wandb/integration/torch/wandb_torch.py +554 -0
- wandb/integration/ultralytics/__init__.py +11 -0
- wandb/integration/ultralytics/bbox_utils.py +208 -0
- wandb/integration/ultralytics/callback.py +524 -0
- wandb/integration/ultralytics/classification_utils.py +83 -0
- wandb/integration/ultralytics/mask_utils.py +202 -0
- wandb/integration/ultralytics/pose_utils.py +103 -0
- wandb/integration/xgboost/__init__.py +11 -0
- wandb/integration/xgboost/xgboost.py +189 -0
- wandb/integration/yolov8/__init__.py +0 -0
- wandb/integration/yolov8/yolov8.py +284 -0
- wandb/jupyter.py +515 -0
- wandb/magic.py +3 -0
- wandb/mpmain/__init__.py +0 -0
- wandb/mpmain/__main__.py +1 -0
- wandb/old/__init__.py +0 -0
- wandb/old/core.py +53 -0
- wandb/old/settings.py +173 -0
- wandb/old/summary.py +440 -0
- wandb/plot/__init__.py +19 -0
- wandb/plot/bar.py +45 -0
- wandb/plot/confusion_matrix.py +100 -0
- wandb/plot/histogram.py +39 -0
- wandb/plot/line.py +43 -0
- wandb/plot/line_series.py +88 -0
- wandb/plot/pr_curve.py +136 -0
- wandb/plot/roc_curve.py +118 -0
- wandb/plot/scatter.py +32 -0
- wandb/plot/utils.py +183 -0
- wandb/plot/viz.py +123 -0
- wandb/proto/__init__.py +0 -0
- wandb/proto/v3/__init__.py +0 -0
- wandb/proto/v3/wandb_base_pb2.py +55 -0
- wandb/proto/v3/wandb_internal_pb2.py +1608 -0
- wandb/proto/v3/wandb_server_pb2.py +208 -0
- wandb/proto/v3/wandb_settings_pb2.py +112 -0
- wandb/proto/v3/wandb_telemetry_pb2.py +106 -0
- wandb/proto/v4/__init__.py +0 -0
- wandb/proto/v4/wandb_base_pb2.py +30 -0
- wandb/proto/v4/wandb_internal_pb2.py +360 -0
- wandb/proto/v4/wandb_server_pb2.py +63 -0
- wandb/proto/v4/wandb_settings_pb2.py +45 -0
- wandb/proto/v4/wandb_telemetry_pb2.py +41 -0
- wandb/proto/v5/wandb_base_pb2.py +31 -0
- wandb/proto/v5/wandb_internal_pb2.py +361 -0
- wandb/proto/v5/wandb_server_pb2.py +64 -0
- wandb/proto/v5/wandb_settings_pb2.py +46 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +42 -0
- wandb/proto/wandb_base_pb2.py +10 -0
- wandb/proto/wandb_deprecated.py +53 -0
- wandb/proto/wandb_generate_deprecated.py +34 -0
- wandb/proto/wandb_generate_proto.py +49 -0
- wandb/proto/wandb_internal_pb2.py +16 -0
- wandb/proto/wandb_server_pb2.py +10 -0
- wandb/proto/wandb_settings_pb2.py +10 -0
- wandb/proto/wandb_telemetry_pb2.py +10 -0
- wandb/py.typed +0 -0
- wandb/sdk/__init__.py +37 -0
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/_validators.py +90 -0
- wandb/sdk/artifacts/artifact.py +2389 -0
- wandb/sdk/artifacts/artifact_download_logger.py +43 -0
- wandb/sdk/artifacts/artifact_file_cache.py +253 -0
- wandb/sdk/artifacts/artifact_instance_cache.py +17 -0
- wandb/sdk/artifacts/artifact_manifest.py +74 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +249 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +92 -0
- wandb/sdk/artifacts/artifact_saver.py +269 -0
- wandb/sdk/artifacts/artifact_state.py +11 -0
- wandb/sdk/artifacts/artifact_ttl.py +7 -0
- wandb/sdk/artifacts/exceptions.py +57 -0
- wandb/sdk/artifacts/staging.py +25 -0
- wandb/sdk/artifacts/storage_handler.py +62 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +208 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +228 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +114 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +141 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +56 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +300 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +72 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +135 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +74 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
- wandb/sdk/artifacts/storage_policies/register.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +378 -0
- wandb/sdk/artifacts/storage_policy.py +72 -0
- wandb/sdk/backend/__init__.py +0 -0
- wandb/sdk/backend/backend.py +222 -0
- wandb/sdk/data_types/__init__.py +0 -0
- wandb/sdk/data_types/_dtypes.py +914 -0
- wandb/sdk/data_types/_private.py +10 -0
- wandb/sdk/data_types/audio.py +165 -0
- wandb/sdk/data_types/base_types/__init__.py +0 -0
- wandb/sdk/data_types/base_types/json_metadata.py +55 -0
- wandb/sdk/data_types/base_types/media.py +315 -0
- wandb/sdk/data_types/base_types/wb_value.py +272 -0
- wandb/sdk/data_types/bokeh.py +70 -0
- wandb/sdk/data_types/graph.py +405 -0
- wandb/sdk/data_types/helper_types/__init__.py +0 -0
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +295 -0
- wandb/sdk/data_types/helper_types/classes.py +159 -0
- wandb/sdk/data_types/helper_types/image_mask.py +235 -0
- wandb/sdk/data_types/histogram.py +96 -0
- wandb/sdk/data_types/html.py +115 -0
- wandb/sdk/data_types/image.py +845 -0
- wandb/sdk/data_types/molecule.py +241 -0
- wandb/sdk/data_types/object_3d.py +474 -0
- wandb/sdk/data_types/plotly.py +82 -0
- wandb/sdk/data_types/saved_model.py +446 -0
- wandb/sdk/data_types/table.py +1204 -0
- wandb/sdk/data_types/trace_tree.py +438 -0
- wandb/sdk/data_types/utils.py +229 -0
- wandb/sdk/data_types/video.py +247 -0
- wandb/sdk/integration_utils/__init__.py +0 -0
- wandb/sdk/integration_utils/auto_logging.py +239 -0
- wandb/sdk/integration_utils/data_logging.py +475 -0
- wandb/sdk/interface/__init__.py +0 -0
- wandb/sdk/interface/constants.py +4 -0
- wandb/sdk/interface/interface.py +972 -0
- wandb/sdk/interface/interface_queue.py +59 -0
- wandb/sdk/interface/interface_relay.py +53 -0
- wandb/sdk/interface/interface_shared.py +537 -0
- wandb/sdk/interface/interface_sock.py +61 -0
- wandb/sdk/interface/message_future.py +27 -0
- wandb/sdk/interface/message_future_poll.py +50 -0
- wandb/sdk/interface/router.py +118 -0
- wandb/sdk/interface/router_queue.py +44 -0
- wandb/sdk/interface/router_relay.py +39 -0
- wandb/sdk/interface/router_sock.py +36 -0
- wandb/sdk/interface/summary_record.py +67 -0
- wandb/sdk/internal/__init__.py +0 -0
- wandb/sdk/internal/context.py +89 -0
- wandb/sdk/internal/datastore.py +297 -0
- wandb/sdk/internal/file_pusher.py +181 -0
- wandb/sdk/internal/file_stream.py +695 -0
- wandb/sdk/internal/flow_control.py +263 -0
- wandb/sdk/internal/handler.py +901 -0
- wandb/sdk/internal/internal.py +417 -0
- wandb/sdk/internal/internal_api.py +4358 -0
- wandb/sdk/internal/internal_util.py +100 -0
- wandb/sdk/internal/job_builder.py +629 -0
- wandb/sdk/internal/profiler.py +78 -0
- wandb/sdk/internal/progress.py +83 -0
- wandb/sdk/internal/run.py +25 -0
- wandb/sdk/internal/sample.py +70 -0
- wandb/sdk/internal/sender.py +1686 -0
- wandb/sdk/internal/sender_config.py +197 -0
- wandb/sdk/internal/settings_static.py +90 -0
- wandb/sdk/internal/system/__init__.py +0 -0
- wandb/sdk/internal/system/assets/__init__.py +27 -0
- wandb/sdk/internal/system/assets/aggregators.py +37 -0
- wandb/sdk/internal/system/assets/asset_registry.py +20 -0
- wandb/sdk/internal/system/assets/cpu.py +163 -0
- wandb/sdk/internal/system/assets/disk.py +210 -0
- wandb/sdk/internal/system/assets/gpu.py +416 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +239 -0
- wandb/sdk/internal/system/assets/gpu_apple.py +177 -0
- wandb/sdk/internal/system/assets/interfaces.py +207 -0
- wandb/sdk/internal/system/assets/ipu.py +177 -0
- wandb/sdk/internal/system/assets/memory.py +166 -0
- wandb/sdk/internal/system/assets/network.py +125 -0
- wandb/sdk/internal/system/assets/open_metrics.py +299 -0
- wandb/sdk/internal/system/assets/tpu.py +154 -0
- wandb/sdk/internal/system/assets/trainium.py +399 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +249 -0
- wandb/sdk/internal/system/system_monitor.py +229 -0
- wandb/sdk/internal/tb_watcher.py +518 -0
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/internal/writer.py +206 -0
- wandb/sdk/launch/__init__.py +14 -0
- wandb/sdk/launch/_launch.py +330 -0
- wandb/sdk/launch/_launch_add.py +255 -0
- wandb/sdk/launch/_project_spec.py +566 -0
- wandb/sdk/launch/agent/__init__.py +5 -0
- wandb/sdk/launch/agent/agent.py +924 -0
- wandb/sdk/launch/agent/config.py +296 -0
- wandb/sdk/launch/agent/job_status_tracker.py +53 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/__init__.py +0 -0
- wandb/sdk/launch/builder/abstract.py +156 -0
- wandb/sdk/launch/builder/build.py +297 -0
- wandb/sdk/launch/builder/context_manager.py +235 -0
- wandb/sdk/launch/builder/docker_builder.py +177 -0
- wandb/sdk/launch/builder/kaniko_builder.py +595 -0
- wandb/sdk/launch/builder/noop.py +58 -0
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +188 -0
- wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
- wandb/sdk/launch/create_job.py +528 -0
- wandb/sdk/launch/environment/abstract.py +29 -0
- wandb/sdk/launch/environment/aws_environment.py +322 -0
- wandb/sdk/launch/environment/azure_environment.py +105 -0
- wandb/sdk/launch/environment/gcp_environment.py +335 -0
- wandb/sdk/launch/environment/local_environment.py +66 -0
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/git_reference.py +109 -0
- wandb/sdk/launch/inputs/files.py +148 -0
- wandb/sdk/launch/inputs/internal.py +315 -0
- wandb/sdk/launch/inputs/manage.py +113 -0
- wandb/sdk/launch/inputs/schema.py +39 -0
- wandb/sdk/launch/loader.py +249 -0
- wandb/sdk/launch/registry/abstract.py +48 -0
- wandb/sdk/launch/registry/anon.py +29 -0
- wandb/sdk/launch/registry/azure_container_registry.py +124 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +192 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +219 -0
- wandb/sdk/launch/registry/local_registry.py +67 -0
- wandb/sdk/launch/runner/__init__.py +0 -0
- wandb/sdk/launch/runner/abstract.py +195 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +474 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +963 -0
- wandb/sdk/launch/runner/local_container.py +301 -0
- wandb/sdk/launch/runner/local_process.py +78 -0
- wandb/sdk/launch/runner/sagemaker_runner.py +426 -0
- wandb/sdk/launch/runner/vertex_runner.py +230 -0
- wandb/sdk/launch/sweeps/__init__.py +39 -0
- wandb/sdk/launch/sweeps/scheduler.py +742 -0
- wandb/sdk/launch/sweeps/scheduler_sweep.py +91 -0
- wandb/sdk/launch/sweeps/utils.py +316 -0
- wandb/sdk/launch/utils.py +746 -0
- wandb/sdk/launch/wandb_reference.py +138 -0
- wandb/sdk/lib/__init__.py +5 -0
- wandb/sdk/lib/_settings_toposort_generate.py +159 -0
- wandb/sdk/lib/_settings_toposort_generated.py +250 -0
- wandb/sdk/lib/_wburls_generate.py +25 -0
- wandb/sdk/lib/_wburls_generated.py +22 -0
- wandb/sdk/lib/apikey.py +273 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/config_util.py +101 -0
- wandb/sdk/lib/credentials.py +141 -0
- wandb/sdk/lib/deprecate.py +42 -0
- wandb/sdk/lib/disabled.py +29 -0
- wandb/sdk/lib/exit_hooks.py +54 -0
- wandb/sdk/lib/file_stream_utils.py +118 -0
- wandb/sdk/lib/filenames.py +64 -0
- wandb/sdk/lib/filesystem.py +372 -0
- wandb/sdk/lib/fsm.py +174 -0
- wandb/sdk/lib/gitlib.py +239 -0
- wandb/sdk/lib/gql_request.py +65 -0
- wandb/sdk/lib/handler_util.py +21 -0
- wandb/sdk/lib/hashutil.py +84 -0
- wandb/sdk/lib/import_hooks.py +275 -0
- wandb/sdk/lib/ipython.py +146 -0
- wandb/sdk/lib/json_util.py +80 -0
- wandb/sdk/lib/lazyloader.py +63 -0
- wandb/sdk/lib/mailbox.py +460 -0
- wandb/sdk/lib/module.py +69 -0
- wandb/sdk/lib/paths.py +106 -0
- wandb/sdk/lib/preinit.py +42 -0
- wandb/sdk/lib/printer.py +313 -0
- wandb/sdk/lib/proto_util.py +90 -0
- wandb/sdk/lib/redirect.py +845 -0
- wandb/sdk/lib/reporting.py +99 -0
- wandb/sdk/lib/retry.py +289 -0
- wandb/sdk/lib/run_moment.py +78 -0
- wandb/sdk/lib/runid.py +12 -0
- wandb/sdk/lib/server.py +52 -0
- wandb/sdk/lib/service_connection.py +216 -0
- wandb/sdk/lib/service_token.py +94 -0
- wandb/sdk/lib/sock_client.py +295 -0
- wandb/sdk/lib/sparkline.py +45 -0
- wandb/sdk/lib/telemetry.py +100 -0
- wandb/sdk/lib/timed_input.py +133 -0
- wandb/sdk/lib/timer.py +19 -0
- wandb/sdk/lib/tracelog.py +255 -0
- wandb/sdk/lib/wburls.py +46 -0
- wandb/sdk/service/__init__.py +0 -0
- wandb/sdk/service/_startup_debug.py +22 -0
- wandb/sdk/service/port_file.py +53 -0
- wandb/sdk/service/server.py +116 -0
- wandb/sdk/service/server_sock.py +276 -0
- wandb/sdk/service/service.py +242 -0
- wandb/sdk/service/streams.py +417 -0
- wandb/sdk/verify/__init__.py +0 -0
- wandb/sdk/verify/verify.py +501 -0
- wandb/sdk/wandb_alerts.py +12 -0
- wandb/sdk/wandb_config.py +322 -0
- wandb/sdk/wandb_helper.py +54 -0
- wandb/sdk/wandb_init.py +1266 -0
- wandb/sdk/wandb_login.py +349 -0
- wandb/sdk/wandb_metric.py +110 -0
- wandb/sdk/wandb_require.py +97 -0
- wandb/sdk/wandb_require_helpers.py +44 -0
- wandb/sdk/wandb_run.py +4236 -0
- wandb/sdk/wandb_settings.py +2001 -0
- wandb/sdk/wandb_setup.py +409 -0
- wandb/sdk/wandb_summary.py +150 -0
- wandb/sdk/wandb_sweep.py +119 -0
- wandb/sdk/wandb_sync.py +81 -0
- wandb/sdk/wandb_watch.py +144 -0
- wandb/sklearn.py +35 -0
- wandb/sync/__init__.py +3 -0
- wandb/sync/sync.py +443 -0
- wandb/trigger.py +29 -0
- wandb/util.py +1956 -0
- wandb/vendor/__init__.py +0 -0
- wandb/vendor/gql-0.2.0/setup.py +40 -0
- wandb/vendor/gql-0.2.0/tests/__init__.py +0 -0
- wandb/vendor/gql-0.2.0/tests/starwars/__init__.py +0 -0
- wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py +96 -0
- wandb/vendor/gql-0.2.0/tests/starwars/schema.py +146 -0
- wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py +293 -0
- wandb/vendor/gql-0.2.0/tests/starwars/test_query.py +355 -0
- wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py +171 -0
- wandb/vendor/gql-0.2.0/tests/test_client.py +31 -0
- wandb/vendor/gql-0.2.0/tests/test_transport.py +89 -0
- wandb/vendor/gql-0.2.0/wandb_gql/__init__.py +4 -0
- wandb/vendor/gql-0.2.0/wandb_gql/client.py +75 -0
- wandb/vendor/gql-0.2.0/wandb_gql/dsl.py +152 -0
- wandb/vendor/gql-0.2.0/wandb_gql/gql.py +10 -0
- wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py +0 -0
- wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py +6 -0
- wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py +15 -0
- wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py +46 -0
- wandb/vendor/gql-0.2.0/wandb_gql/utils.py +21 -0
- wandb/vendor/graphql-core-1.1/setup.py +86 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py +287 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py +6 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/error/base.py +42 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/error/format_error.py +11 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/error/located_error.py +29 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/error/syntax_error.py +36 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__init__.py +26 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py +311 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executor.py +398 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py +0 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py +53 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py +22 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py +32 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py +7 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py +35 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py +6 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py +0 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py +66 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py +252 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py +151 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py +7 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py +57 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py +145 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py +60 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py +0 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py +1349 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py +19 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py +435 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py +30 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py +779 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py +193 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py +18 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py +222 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py +82 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__init__.py +0 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py +17 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/contain_subset.py +28 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/default_ordered_dict.py +40 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/ordereddict.py +8 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/pair_set.py +43 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/version.py +78 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/type/__init__.py +67 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py +619 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py +132 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py +440 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py +131 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py +100 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py +145 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py +0 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py +9 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py +65 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py +49 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py +24 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py +75 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py +291 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_client_schema.py +250 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/concat_ast.py +9 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/extend_schema.py +357 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_field_def.py +27 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_operation_ast.py +21 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/introspection_query.py +90 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_literal_value.py +67 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_value.py +66 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/quoted_or_list.py +21 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/schema_printer.py +168 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/suggestion_list.py +56 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_comparators.py +69 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_from_ast.py +21 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_info.py +149 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/utils/value_from_ast.py +69 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py +4 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py +79 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py +24 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py +8 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py +44 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py +113 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py +33 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py +70 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py +97 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py +19 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py +43 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py +23 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py +59 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py +36 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py +38 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py +37 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py +529 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py +44 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py +46 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py +33 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py +32 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py +28 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py +33 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py +31 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py +27 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py +21 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py +53 -0
- wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py +158 -0
- wandb/vendor/promise-2.3.0/conftest.py +30 -0
- wandb/vendor/promise-2.3.0/setup.py +64 -0
- wandb/vendor/promise-2.3.0/tests/__init__.py +0 -0
- wandb/vendor/promise-2.3.0/tests/conftest.py +8 -0
- wandb/vendor/promise-2.3.0/tests/test_awaitable.py +32 -0
- wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py +47 -0
- wandb/vendor/promise-2.3.0/tests/test_benchmark.py +116 -0
- wandb/vendor/promise-2.3.0/tests/test_complex_threads.py +23 -0
- wandb/vendor/promise-2.3.0/tests/test_dataloader.py +452 -0
- wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py +99 -0
- wandb/vendor/promise-2.3.0/tests/test_dataloader_extra.py +65 -0
- wandb/vendor/promise-2.3.0/tests/test_extra.py +670 -0
- wandb/vendor/promise-2.3.0/tests/test_issues.py +132 -0
- wandb/vendor/promise-2.3.0/tests/test_promise_list.py +70 -0
- wandb/vendor/promise-2.3.0/tests/test_spec.py +584 -0
- wandb/vendor/promise-2.3.0/tests/test_thread_safety.py +115 -0
- wandb/vendor/promise-2.3.0/tests/utils.py +3 -0
- wandb/vendor/promise-2.3.0/wandb_promise/__init__.py +38 -0
- wandb/vendor/promise-2.3.0/wandb_promise/async_.py +135 -0
- wandb/vendor/promise-2.3.0/wandb_promise/compat.py +32 -0
- wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py +326 -0
- wandb/vendor/promise-2.3.0/wandb_promise/iterate_promise.py +12 -0
- wandb/vendor/promise-2.3.0/wandb_promise/promise.py +848 -0
- wandb/vendor/promise-2.3.0/wandb_promise/promise_list.py +151 -0
- wandb/vendor/promise-2.3.0/wandb_promise/pyutils/__init__.py +0 -0
- wandb/vendor/promise-2.3.0/wandb_promise/pyutils/version.py +83 -0
- wandb/vendor/promise-2.3.0/wandb_promise/schedulers/__init__.py +0 -0
- wandb/vendor/promise-2.3.0/wandb_promise/schedulers/asyncio.py +22 -0
- wandb/vendor/promise-2.3.0/wandb_promise/schedulers/gevent.py +21 -0
- wandb/vendor/promise-2.3.0/wandb_promise/schedulers/immediate.py +27 -0
- wandb/vendor/promise-2.3.0/wandb_promise/schedulers/thread.py +18 -0
- wandb/vendor/promise-2.3.0/wandb_promise/utils.py +56 -0
- wandb/vendor/pygments/__init__.py +90 -0
- wandb/vendor/pygments/cmdline.py +568 -0
- wandb/vendor/pygments/console.py +74 -0
- wandb/vendor/pygments/filter.py +74 -0
- wandb/vendor/pygments/filters/__init__.py +350 -0
- wandb/vendor/pygments/formatter.py +95 -0
- wandb/vendor/pygments/formatters/__init__.py +153 -0
- wandb/vendor/pygments/formatters/_mapping.py +85 -0
- wandb/vendor/pygments/formatters/bbcode.py +109 -0
- wandb/vendor/pygments/formatters/html.py +851 -0
- wandb/vendor/pygments/formatters/img.py +600 -0
- wandb/vendor/pygments/formatters/irc.py +182 -0
- wandb/vendor/pygments/formatters/latex.py +482 -0
- wandb/vendor/pygments/formatters/other.py +160 -0
- wandb/vendor/pygments/formatters/rtf.py +147 -0
- wandb/vendor/pygments/formatters/svg.py +153 -0
- wandb/vendor/pygments/formatters/terminal.py +136 -0
- wandb/vendor/pygments/formatters/terminal256.py +309 -0
- wandb/vendor/pygments/lexer.py +871 -0
- wandb/vendor/pygments/lexers/__init__.py +329 -0
- wandb/vendor/pygments/lexers/_asy_builtins.py +1645 -0
- wandb/vendor/pygments/lexers/_cl_builtins.py +232 -0
- wandb/vendor/pygments/lexers/_cocoa_builtins.py +72 -0
- wandb/vendor/pygments/lexers/_csound_builtins.py +1346 -0
- wandb/vendor/pygments/lexers/_lasso_builtins.py +5327 -0
- wandb/vendor/pygments/lexers/_lua_builtins.py +295 -0
- wandb/vendor/pygments/lexers/_mapping.py +500 -0
- wandb/vendor/pygments/lexers/_mql_builtins.py +1172 -0
- wandb/vendor/pygments/lexers/_openedge_builtins.py +2547 -0
- wandb/vendor/pygments/lexers/_php_builtins.py +4756 -0
- wandb/vendor/pygments/lexers/_postgres_builtins.py +621 -0
- wandb/vendor/pygments/lexers/_scilab_builtins.py +3094 -0
- wandb/vendor/pygments/lexers/_sourcemod_builtins.py +1163 -0
- wandb/vendor/pygments/lexers/_stan_builtins.py +532 -0
- wandb/vendor/pygments/lexers/_stata_builtins.py +419 -0
- wandb/vendor/pygments/lexers/_tsql_builtins.py +1004 -0
- wandb/vendor/pygments/lexers/_vim_builtins.py +1939 -0
- wandb/vendor/pygments/lexers/actionscript.py +240 -0
- wandb/vendor/pygments/lexers/agile.py +24 -0
- wandb/vendor/pygments/lexers/algebra.py +221 -0
- wandb/vendor/pygments/lexers/ambient.py +76 -0
- wandb/vendor/pygments/lexers/ampl.py +87 -0
- wandb/vendor/pygments/lexers/apl.py +101 -0
- wandb/vendor/pygments/lexers/archetype.py +318 -0
- wandb/vendor/pygments/lexers/asm.py +641 -0
- wandb/vendor/pygments/lexers/automation.py +374 -0
- wandb/vendor/pygments/lexers/basic.py +500 -0
- wandb/vendor/pygments/lexers/bibtex.py +160 -0
- wandb/vendor/pygments/lexers/business.py +612 -0
- wandb/vendor/pygments/lexers/c_cpp.py +252 -0
- wandb/vendor/pygments/lexers/c_like.py +541 -0
- wandb/vendor/pygments/lexers/capnproto.py +78 -0
- wandb/vendor/pygments/lexers/chapel.py +102 -0
- wandb/vendor/pygments/lexers/clean.py +288 -0
- wandb/vendor/pygments/lexers/compiled.py +34 -0
- wandb/vendor/pygments/lexers/configs.py +833 -0
- wandb/vendor/pygments/lexers/console.py +114 -0
- wandb/vendor/pygments/lexers/crystal.py +393 -0
- wandb/vendor/pygments/lexers/csound.py +366 -0
- wandb/vendor/pygments/lexers/css.py +689 -0
- wandb/vendor/pygments/lexers/d.py +251 -0
- wandb/vendor/pygments/lexers/dalvik.py +125 -0
- wandb/vendor/pygments/lexers/data.py +555 -0
- wandb/vendor/pygments/lexers/diff.py +165 -0
- wandb/vendor/pygments/lexers/dotnet.py +691 -0
- wandb/vendor/pygments/lexers/dsls.py +878 -0
- wandb/vendor/pygments/lexers/dylan.py +289 -0
- wandb/vendor/pygments/lexers/ecl.py +125 -0
- wandb/vendor/pygments/lexers/eiffel.py +65 -0
- wandb/vendor/pygments/lexers/elm.py +121 -0
- wandb/vendor/pygments/lexers/erlang.py +533 -0
- wandb/vendor/pygments/lexers/esoteric.py +277 -0
- wandb/vendor/pygments/lexers/ezhil.py +69 -0
- wandb/vendor/pygments/lexers/factor.py +344 -0
- wandb/vendor/pygments/lexers/fantom.py +250 -0
- wandb/vendor/pygments/lexers/felix.py +273 -0
- wandb/vendor/pygments/lexers/forth.py +177 -0
- wandb/vendor/pygments/lexers/fortran.py +205 -0
- wandb/vendor/pygments/lexers/foxpro.py +428 -0
- wandb/vendor/pygments/lexers/functional.py +21 -0
- wandb/vendor/pygments/lexers/go.py +101 -0
- wandb/vendor/pygments/lexers/grammar_notation.py +213 -0
- wandb/vendor/pygments/lexers/graph.py +80 -0
- wandb/vendor/pygments/lexers/graphics.py +553 -0
- wandb/vendor/pygments/lexers/haskell.py +843 -0
- wandb/vendor/pygments/lexers/haxe.py +936 -0
- wandb/vendor/pygments/lexers/hdl.py +382 -0
- wandb/vendor/pygments/lexers/hexdump.py +103 -0
- wandb/vendor/pygments/lexers/html.py +602 -0
- wandb/vendor/pygments/lexers/idl.py +270 -0
- wandb/vendor/pygments/lexers/igor.py +288 -0
- wandb/vendor/pygments/lexers/inferno.py +96 -0
- wandb/vendor/pygments/lexers/installers.py +322 -0
- wandb/vendor/pygments/lexers/int_fiction.py +1343 -0
- wandb/vendor/pygments/lexers/iolang.py +63 -0
- wandb/vendor/pygments/lexers/j.py +146 -0
- wandb/vendor/pygments/lexers/javascript.py +1525 -0
- wandb/vendor/pygments/lexers/julia.py +333 -0
- wandb/vendor/pygments/lexers/jvm.py +1573 -0
- wandb/vendor/pygments/lexers/lisp.py +2621 -0
- wandb/vendor/pygments/lexers/make.py +202 -0
- wandb/vendor/pygments/lexers/markup.py +595 -0
- wandb/vendor/pygments/lexers/math.py +21 -0
- wandb/vendor/pygments/lexers/matlab.py +663 -0
- wandb/vendor/pygments/lexers/ml.py +769 -0
- wandb/vendor/pygments/lexers/modeling.py +358 -0
- wandb/vendor/pygments/lexers/modula2.py +1561 -0
- wandb/vendor/pygments/lexers/monte.py +204 -0
- wandb/vendor/pygments/lexers/ncl.py +894 -0
- wandb/vendor/pygments/lexers/nimrod.py +159 -0
- wandb/vendor/pygments/lexers/nit.py +64 -0
- wandb/vendor/pygments/lexers/nix.py +136 -0
- wandb/vendor/pygments/lexers/oberon.py +105 -0
- wandb/vendor/pygments/lexers/objective.py +504 -0
- wandb/vendor/pygments/lexers/ooc.py +85 -0
- wandb/vendor/pygments/lexers/other.py +41 -0
- wandb/vendor/pygments/lexers/parasail.py +79 -0
- wandb/vendor/pygments/lexers/parsers.py +835 -0
- wandb/vendor/pygments/lexers/pascal.py +644 -0
- wandb/vendor/pygments/lexers/pawn.py +199 -0
- wandb/vendor/pygments/lexers/perl.py +620 -0
- wandb/vendor/pygments/lexers/php.py +267 -0
- wandb/vendor/pygments/lexers/praat.py +294 -0
- wandb/vendor/pygments/lexers/prolog.py +306 -0
- wandb/vendor/pygments/lexers/python.py +939 -0
- wandb/vendor/pygments/lexers/qvt.py +152 -0
- wandb/vendor/pygments/lexers/r.py +453 -0
- wandb/vendor/pygments/lexers/rdf.py +270 -0
- wandb/vendor/pygments/lexers/rebol.py +431 -0
- wandb/vendor/pygments/lexers/resource.py +85 -0
- wandb/vendor/pygments/lexers/rnc.py +67 -0
- wandb/vendor/pygments/lexers/roboconf.py +82 -0
- wandb/vendor/pygments/lexers/robotframework.py +560 -0
- wandb/vendor/pygments/lexers/ruby.py +519 -0
- wandb/vendor/pygments/lexers/rust.py +220 -0
- wandb/vendor/pygments/lexers/sas.py +228 -0
- wandb/vendor/pygments/lexers/scripting.py +1222 -0
- wandb/vendor/pygments/lexers/shell.py +794 -0
- wandb/vendor/pygments/lexers/smalltalk.py +195 -0
- wandb/vendor/pygments/lexers/smv.py +79 -0
- wandb/vendor/pygments/lexers/snobol.py +83 -0
- wandb/vendor/pygments/lexers/special.py +103 -0
- wandb/vendor/pygments/lexers/sql.py +681 -0
- wandb/vendor/pygments/lexers/stata.py +108 -0
- wandb/vendor/pygments/lexers/supercollider.py +90 -0
- wandb/vendor/pygments/lexers/tcl.py +145 -0
- wandb/vendor/pygments/lexers/templates.py +2283 -0
- wandb/vendor/pygments/lexers/testing.py +207 -0
- wandb/vendor/pygments/lexers/text.py +25 -0
- wandb/vendor/pygments/lexers/textedit.py +169 -0
- wandb/vendor/pygments/lexers/textfmts.py +297 -0
- wandb/vendor/pygments/lexers/theorem.py +458 -0
- wandb/vendor/pygments/lexers/trafficscript.py +54 -0
- wandb/vendor/pygments/lexers/typoscript.py +226 -0
- wandb/vendor/pygments/lexers/urbi.py +133 -0
- wandb/vendor/pygments/lexers/varnish.py +190 -0
- wandb/vendor/pygments/lexers/verification.py +111 -0
- wandb/vendor/pygments/lexers/web.py +24 -0
- wandb/vendor/pygments/lexers/webmisc.py +988 -0
- wandb/vendor/pygments/lexers/whiley.py +116 -0
- wandb/vendor/pygments/lexers/x10.py +69 -0
- wandb/vendor/pygments/modeline.py +44 -0
- wandb/vendor/pygments/plugin.py +68 -0
- wandb/vendor/pygments/regexopt.py +92 -0
- wandb/vendor/pygments/scanner.py +105 -0
- wandb/vendor/pygments/sphinxext.py +158 -0
- wandb/vendor/pygments/style.py +155 -0
- wandb/vendor/pygments/styles/__init__.py +80 -0
- wandb/vendor/pygments/styles/abap.py +29 -0
- wandb/vendor/pygments/styles/algol.py +63 -0
- wandb/vendor/pygments/styles/algol_nu.py +63 -0
- wandb/vendor/pygments/styles/arduino.py +98 -0
- wandb/vendor/pygments/styles/autumn.py +65 -0
- wandb/vendor/pygments/styles/borland.py +51 -0
- wandb/vendor/pygments/styles/bw.py +49 -0
- wandb/vendor/pygments/styles/colorful.py +81 -0
- wandb/vendor/pygments/styles/default.py +73 -0
- wandb/vendor/pygments/styles/emacs.py +72 -0
- wandb/vendor/pygments/styles/friendly.py +72 -0
- wandb/vendor/pygments/styles/fruity.py +42 -0
- wandb/vendor/pygments/styles/igor.py +29 -0
- wandb/vendor/pygments/styles/lovelace.py +97 -0
- wandb/vendor/pygments/styles/manni.py +75 -0
- wandb/vendor/pygments/styles/monokai.py +106 -0
- wandb/vendor/pygments/styles/murphy.py +80 -0
- wandb/vendor/pygments/styles/native.py +65 -0
- wandb/vendor/pygments/styles/paraiso_dark.py +125 -0
- wandb/vendor/pygments/styles/paraiso_light.py +125 -0
- wandb/vendor/pygments/styles/pastie.py +75 -0
- wandb/vendor/pygments/styles/perldoc.py +69 -0
- wandb/vendor/pygments/styles/rainbow_dash.py +89 -0
- wandb/vendor/pygments/styles/rrt.py +33 -0
- wandb/vendor/pygments/styles/sas.py +44 -0
- wandb/vendor/pygments/styles/stata.py +40 -0
- wandb/vendor/pygments/styles/tango.py +141 -0
- wandb/vendor/pygments/styles/trac.py +63 -0
- wandb/vendor/pygments/styles/vim.py +63 -0
- wandb/vendor/pygments/styles/vs.py +38 -0
- wandb/vendor/pygments/styles/xcode.py +51 -0
- wandb/vendor/pygments/token.py +213 -0
- wandb/vendor/pygments/unistring.py +217 -0
- wandb/vendor/pygments/util.py +388 -0
- wandb/vendor/pynvml/__init__.py +0 -0
- wandb/vendor/pynvml/pynvml.py +4779 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/__init__.py +17 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/events.py +615 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/__init__.py +98 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/api.py +369 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents.py +172 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents2.py +239 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify.py +218 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_buffer.py +81 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_c.py +575 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/kqueue.py +730 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/polling.py +145 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/read_directory_changes.py +133 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/winapi.py +348 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/patterns.py +265 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/tricks/__init__.py +174 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/__init__.py +151 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/bricks.py +249 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/compat.py +29 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/decorators.py +198 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/delayed_queue.py +88 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/dirsnapshot.py +293 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/echo.py +157 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/event_backport.py +41 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/importlib2.py +40 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/platform.py +57 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/unicode_paths.py +64 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/win32stat.py +123 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/version.py +28 -0
- wandb/vendor/watchdog_0_9_0/wandb_watchdog/watchmedo.py +577 -0
- wandb/wandb_agent.py +588 -0
- wandb/wandb_controller.py +721 -0
- wandb/wandb_run.py +9 -0
- wandb-0.18.2.dist-info/METADATA +213 -0
- wandb-0.18.2.dist-info/RECORD +827 -0
- wandb-0.18.2.dist-info/WHEEL +5 -0
- wandb-0.18.2.dist-info/entry_points.txt +3 -0
- wandb-0.18.2.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1091 @@
|
|
1
|
+
"""keras init."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import operator
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import sys
|
8
|
+
from itertools import chain
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
import tensorflow as tf
|
12
|
+
import tensorflow.keras.backend as K # noqa: N812
|
13
|
+
|
14
|
+
import wandb
|
15
|
+
from wandb.sdk.integration_utils.data_logging import ValidationDataLogger
|
16
|
+
from wandb.sdk.lib.deprecate import Deprecated, deprecate
|
17
|
+
from wandb.util import add_import_hook
|
18
|
+
|
19
|
+
|
20
|
+
def _check_keras_version():
    """Warn if the installed ``keras`` is older than the minimum supported 2.4.0."""
    from keras import __version__ as installed

    from wandb.util import parse_version

    minimum = parse_version("2.4.0")
    if parse_version(installed) >= minimum:
        return
    wandb.termwarn(
        f"Keras version {installed} is not fully supported. Required keras >= 2.4.0"
    )
|
29
|
+
|
30
|
+
|
31
|
+
def _can_compute_flops() -> bool:
    """Return True when the installed TensorFlow is 2.0.0 or newer.

    FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1.
    """
    from wandb.util import parse_version

    return parse_version(tf.__version__) >= parse_version("2.0.0")
|
39
|
+
|
40
|
+
|
41
|
+
# If keras has not been imported yet, register the version check as an import
# hook for "keras"; otherwise run it right away.
if "keras" not in sys.modules:
    add_import_hook("keras", _check_keras_version)
else:
    _check_keras_version()


logger = logging.getLogger(__name__)
|
48
|
+
|
49
|
+
|
50
|
+
def is_dataset(data):
    """Return True if ``data`` is a ``tf.data`` dataset.

    Instances of ``DatasetV2`` (and of ``DatasetV1`` when that class exists)
    qualify. Returns False if TensorFlow's ``dataset_ops`` module cannot be
    loaded or has no ``DatasetV2``.
    """
    dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
    if not dataset_ops or not hasattr(dataset_ops, "DatasetV2"):
        return False
    candidates = [dataset_ops.DatasetV2]
    if hasattr(dataset_ops, "DatasetV1"):
        candidates.append(dataset_ops.DatasetV1)
    return isinstance(data, tuple(candidates))
|
59
|
+
|
60
|
+
|
61
|
+
def is_generator_like(data):
    """Return True if ``data`` is a generator, Keras ``Sequence``, or TF iterator.

    Any object exposing ``next`` or ``__next__`` counts. So does an instance of
    ``tf.keras.utils.Sequence`` or of whichever TensorFlow iterator classes the
    installed version provides.
    """
    accepted = [tf.keras.utils.Sequence]
    iterator_ops = wandb.util.get_module("tensorflow.python.data.ops.iterator_ops")
    if iterator_ops:
        accepted.append(iterator_ops.Iterator)
        # EagerIterator was in tensorflow < 2
        if hasattr(iterator_ops, "EagerIterator"):
            accepted.append(iterator_ops.EagerIterator)
        elif hasattr(iterator_ops, "IteratorV2"):
            accepted.append(iterator_ops.IteratorV2)
    has_next = hasattr(data, "next") or hasattr(data, "__next__")
    return has_next or isinstance(data, tuple(accepted))
|
74
|
+
|
75
|
+
|
76
|
+
def patch_tf_keras():  # noqa: C901
    """Monkey-patch Keras' training loops so ``WandbCallback`` sees validation data.

    Where the training loops live depends on the TensorFlow version:
    ``keras.engine`` for 2.6.0 <= TF < 2.13.0, otherwise
    ``tensorflow.python.keras.engine``.

    The following are replaced with wrappers:

    * the array fit loop (``training_arrays.fit_loop``);
    * the generator fit loop (``training_generator.fit_generator``);
    * either ``training_v2.Loop.fit`` (TF 2.1 layout) or ``training.Model.fit``
      (when a ``training_v1`` module exists).

    Each wrapper passes the validation data to every ``WandbCallback`` among the
    ``callbacks`` keyword argument, then delegates to the original function.
    Every patched target is recorded in ``wandb.patched["keras"]``.

    If the needed engine modules cannot be imported, an error is shown via
    ``wandb.termerror``, the exception is logged, and nothing is patched.
    """
    from tensorflow.python.eager import context

    from wandb.util import parse_version

    def _report_failure():
        # Called only from inside an ``except`` clause, so ``logger.exception``
        # still sees the active exception and logs its traceback.
        wandb.termerror("Unable to patch Tensorflow/Keras")
        logger.exception("exception while trying to patch_tf_keras")

    if (
        parse_version("2.6.0")
        <= parse_version(tf.__version__)
        < parse_version("2.13.0")
    ):
        keras_engine = "keras.engine"
        try:
            from keras.engine import training
            from keras.engine import training_arrays_v1 as training_arrays
            from keras.engine import training_generator_v1 as training_generator
        except (ImportError, AttributeError):
            _report_failure()
            return
    else:
        keras_engine = "tensorflow.python.keras.engine"

        # Guarded like every other engine import in this function. Previously a
        # failure here (e.g. if ``tensorflow.python.keras`` is unavailable in
        # the installed TensorFlow) escaped as an ImportError at module import.
        try:
            from tensorflow.python.keras.engine import training
        except (ImportError, AttributeError):
            _report_failure()
            return

        try:
            from tensorflow.python.keras.engine import (
                training_arrays_v1 as training_arrays,
            )
            from tensorflow.python.keras.engine import (
                training_generator_v1 as training_generator,
            )
        except (ImportError, AttributeError):
            # Fall back to the non-"_v1" module names.
            try:
                from tensorflow.python.keras.engine import (
                    training_arrays,
                    training_generator,
                )
            except (ImportError, AttributeError):
                _report_failure()
                return

    # Tensorflow 2.1
    training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")
    # Tensorflow 2.2
    training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")

    # Original "v2" fit; only used by ``new_v2``, which is only installed when
    # one of the two modules above was found.
    old_v2 = None
    if training_v2_1:
        old_v2 = training_v2_1.Loop.fit
    elif training_v2_2:
        old_v2 = training.Model.fit

    old_arrays = training_arrays.fit_loop
    old_generator = training_generator.fit_generator

    def set_wandb_attrs(cbk, val_data):
        """Attach ``val_data`` to ``cbk`` if it is a ``WandbCallback``.

        * Generator-like data becomes ``cbk.generator``.
        * An eager ``tf.data`` dataset becomes an iterator in ``cbk.generator``;
          in graph mode a warning is emitted instead.
        * A tuple of ``tf.Tensor`` is wrapped in an endless generator that
          evaluates it in the Keras session.
        * Anything else is stored as ``cbk.validation_data``.
        """
        if isinstance(cbk, WandbCallback):
            if is_generator_like(val_data):
                cbk.generator = val_data
            elif is_dataset(val_data):
                if context.executing_eagerly():
                    cbk.generator = iter(val_data)
                else:
                    wandb.termwarn(
                        "Found a validation dataset in graph mode, can't patch Keras."
                    )
            elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
                # Graph mode dataset generator
                def gen():
                    while True:
                        yield K.get_session().run(val_data)

                cbk.generator = gen()
            else:
                cbk.validation_data = val_data

    def new_arrays(*args, **kwargs):
        """Wrapper around the array fit loop; forwards val inputs/targets first."""
        # ``callbacks`` may be passed explicitly as None; treat it as "no callbacks"
        # instead of failing when iterating it.
        cbks = kwargs.get("callbacks") or []
        val_inputs = kwargs.get("val_inputs")
        val_targets = kwargs.get("val_targets")
        # TODO: these could be generators, why index 0?
        if val_inputs and val_targets:
            for cbk in cbks:
                set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
        return old_arrays(*args, **kwargs)

    def new_generator(*args, **kwargs):
        """Wrapper around the generator fit loop; forwards ``validation_data`` first."""
        cbks = kwargs.get("callbacks") or []
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_generator(*args, **kwargs)

    def new_v2(*args, **kwargs):
        """Wrapper around the "v2" fit method; forwards ``validation_data`` first."""
        cbks = kwargs.get("callbacks") or []
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_v2(*args, **kwargs)

    # Keep the originals reachable on the patched modules, then install wrappers.
    training_arrays.orig_fit_loop = old_arrays
    training_arrays.fit_loop = new_arrays
    training_generator.orig_fit_generator = old_generator
    training_generator.fit_generator = new_generator
    wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
    wandb.patched["keras"].append(
        [f"{keras_engine}.training_generator", "fit_generator"]
    )

    if training_v2_1:
        training_v2_1.Loop.fit = new_v2
        wandb.patched["keras"].append(
            ["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
        )
    elif training_v2_2:
        training.Model.fit = new_v2
        wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])
|
195
|
+
|
196
|
+
|
197
|
+
def _array_has_dtype(array):
    """Return True if ``array`` has a ``dtype`` attribute."""
    sentinel = object()
    return getattr(array, "dtype", sentinel) is not sentinel
|
199
|
+
|
200
|
+
|
201
|
+
def _update_if_numeric(metrics, key, values):
    """Set ``metrics[key]`` to a ``wandb.Histogram`` of ``values`` if they are numeric.

    If ``values`` has no ``dtype``, or its dtype is not numeric, a warning naming
    ``key`` is emitted and ``metrics`` is left unchanged.
    """
    if _array_has_dtype(values):
        if is_numeric_array(values):
            metrics[key] = wandb.Histogram(values)
        else:
            _warn_not_logging_non_numeric(key)
    else:
        _warn_not_logging(key)
|
211
|
+
|
212
|
+
|
213
|
+
def is_numeric_array(array):
    """Return True if ``array.dtype`` is a subtype of ``np.number``."""
    dtype = array.dtype
    return np.issubdtype(dtype, np.number)
|
215
|
+
|
216
|
+
|
217
|
+
def _warn_not_logging_non_numeric(name):
    """Warn that layer ``name`` is skipped because its values are not numeric."""
    message = f"Non-numeric values found in layer: {name}, not logging this layer"
    wandb.termwarn(message, repeat=False)
|
222
|
+
|
223
|
+
|
224
|
+
def _warn_not_logging(name):
    """Warn that layer ``name`` is skipped because its values have no dtype."""
    message = f"Layer {name} has undetermined datatype not logging this layer"
    wandb.termwarn(message, repeat=False)
|
229
|
+
|
230
|
+
|
231
|
+
# Module-level handle on TensorFlow's logger.
tf_logger = tf.get_logger()

# Install the Keras training-loop patches as soon as this module is imported.
patch_tf_keras()


### For gradient logging ###
|
237
|
+
|
238
|
+
|
239
|
+
def _get_custom_optimizer_parent_class():
    """Return the Keras optimizer base class that ``_CustomOptimizer`` subclasses.

    TensorFlow 2.9.0 and newer use ``tf.keras.optimizers.legacy.Optimizer``;
    older versions use ``tf.keras.optimizers.Optimizer``.
    """
    from wandb.util import parse_version

    if parse_version(tf.__version__) < parse_version("2.9.0"):
        return tf.keras.optimizers.Optimizer
    return tf.keras.optimizers.legacy.Optimizer


_custom_optimizer_parent_class = _get_custom_optimizer_parent_class()
|
251
|
+
|
252
|
+
|
253
|
+
class _CustomOptimizer(_custom_optimizer_parent_class):
    """Optimizer whose dense update overwrites each variable with its gradient.

    After a training step with this optimizer, the trainable variables hold the
    gradients of that step rather than updated weights.
    """

    def __init__(self):
        super().__init__(name="CustomOptimizer")
        # Wrap both apply hooks in tf.function (dense first, then sparse).
        for hook_name in ("_resource_apply_dense", "_resource_apply_sparse"):
            setattr(self, hook_name, tf.function(getattr(self, hook_name)))

    def _resource_apply_dense(self, grad, var):
        """Assign the dense gradient ``grad`` to ``var``."""
        var.assign(grad)

    def _resource_apply_sparse(self, grad, var, indices):
        """Do nothing.

        This needs to be implemented to prevent a NotImplementedError when using
        Lookup layers.
        """

    def get_config(self):
        """Return the parent optimizer's config unchanged."""
        return super().get_config()
|
269
|
+
|
270
|
+
|
271
|
+
class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
    """Sums per-batch gradients during ``fit()`` of a model compiled with ``_CustomOptimizer``.

    That optimizer leaves each batch's gradients in the trainable variables.
    After every batch this callback adds them to running totals and restores
    the weights captured in ``set_model``.
    """

    def set_model(self, model):
        super().set_model(model)
        # Snapshot of all weights, restored after every batch.
        self.og_weights = model.get_weights()
        # One zero-initialized accumulator per trainable weight.
        self.grads = [
            np.zeros(tuple(weight.shape)) for weight in model.trainable_weights
        ]

    def on_batch_end(self, batch, logs=None):
        batch_values = (weight.numpy() for weight in self.model.trainable_weights)
        for total, value in zip(self.grads, batch_values):
            total += value  # in-place: updates the array stored in self.grads
        self.model.set_weights(self.og_weights)

    def get_grads(self):
        """Return copies of the accumulated gradient arrays."""
        return list(map(np.ndarray.copy, self.grads))
|
286
|
+
|
287
|
+
|
288
|
+
###
|
289
|
+
|
290
|
+
|
291
|
+
class WandbCallback(tf.keras.callbacks.Callback):
|
292
|
+
"""`WandbCallback` automatically integrates keras with wandb.
|
293
|
+
|
294
|
+
Example:
|
295
|
+
```python
|
296
|
+
model.fit(
|
297
|
+
X_train,
|
298
|
+
y_train,
|
299
|
+
validation_data=(X_test, y_test),
|
300
|
+
callbacks=[WandbCallback()],
|
301
|
+
)
|
302
|
+
```
|
303
|
+
|
304
|
+
`WandbCallback` will automatically log history data from any
|
305
|
+
metrics collected by keras: loss and anything passed into `keras_model.compile()`.
|
306
|
+
|
307
|
+
`WandbCallback` will set summary metrics for the run associated with the "best" training
|
308
|
+
step, where "best" is defined by the `monitor` and `mode` attributes. This defaults
|
309
|
+
to the epoch with the minimum `val_loss`. `WandbCallback` will by default save the model
|
310
|
+
associated with the best `epoch`.
|
311
|
+
|
312
|
+
`WandbCallback` can optionally log gradient and parameter histograms.
|
313
|
+
|
314
|
+
`WandbCallback` can optionally save training and validation data for wandb to visualize.
|
315
|
+
|
316
|
+
Arguments:
|
317
|
+
monitor: (str) name of metric to monitor. Defaults to `val_loss`.
|
318
|
+
mode: (str) one of {`auto`, `min`, `max`}.
|
319
|
+
`min` - save model when monitor is minimized
|
320
|
+
`max` - save model when monitor is maximized
|
321
|
+
`auto` - try to guess when to save the model (default).
|
322
|
+
save_model:
|
323
|
+
True - save a model when monitor beats all previous epochs
|
324
|
+
False - don't save models
|
325
|
+
save_graph: (boolean) if True save model graph to wandb (defaults to True).
|
326
|
+
save_weights_only: (boolean) if True, then only the model's weights will be
|
327
|
+
saved (`model.save_weights(filepath)`), else the full model
|
328
|
+
is saved (`model.save(filepath)`).
|
329
|
+
log_weights: (boolean) if True save histograms of the model's layer's weights.
|
330
|
+
log_gradients: (boolean) if True log histograms of the training gradients
|
331
|
+
training_data: (tuple) Same format `(X,y)` as passed to `model.fit`. This is needed
|
332
|
+
for calculating gradients - this is mandatory if `log_gradients` is `True`.
|
333
|
+
validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`. A set of data
|
334
|
+
for wandb to visualize. If this is set, every epoch, wandb will
|
335
|
+
make a small number of predictions and save the results for later visualization. In case
|
336
|
+
you are working with image data, please also set `input_type` and `output_type` in order
|
337
|
+
to log correctly.
|
338
|
+
generator: (generator) a generator that returns validation data for wandb to visualize. This
|
339
|
+
generator should return tuples `(X,y)`. Either `validation_data` or `generator` should
|
340
|
+
be set for wandb to visualize specific data examples. In case you are working with image data,
|
341
|
+
please also set `input_type` and `output_type` in order to log correctly.
|
342
|
+
validation_steps: (int) if `validation_data` is a generator, how many
|
343
|
+
steps to run the generator for the full validation set.
|
344
|
+
labels: (list) If you are visualizing your data with wandb this list of labels
|
345
|
+
will convert numeric output to understandable string if you are building a
|
346
|
+
multiclass classifier. If you are making a binary classifier you can pass in
|
347
|
+
a list of two labels ["label for false", "label for true"]. If `validation_data`
|
348
|
+
and generator are both false, this won't do anything.
|
349
|
+
predictions: (int) the number of predictions to make for visualization each epoch, max
|
350
|
+
is 100.
|
351
|
+
input_type: (string) type of the model input to help visualization. can be one of:
|
352
|
+
(`image`, `images`, `segmentation_mask`, `auto`).
|
353
|
+
output_type: (string) type of the model output to help visualization. can be one of:
|
354
|
+
(`image`, `images`, `segmentation_mask`, `label`).
|
355
|
+
log_evaluation: (boolean) if True, save a Table containing validation data and the
|
356
|
+
model's predictions at each epoch. See `validation_indexes`,
|
357
|
+
`validation_row_processor`, and `prediction_row_processor` for additional details.
|
358
|
+
class_colors: ([float, float, float]) if the input or output is a segmentation mask,
|
359
|
+
an array containing an rgb tuple (range 0-1) for each class.
|
360
|
+
log_batch_frequency: (integer) if None, callback will log every epoch.
|
361
|
+
If set to integer, callback will log training metrics every `log_batch_frequency`
|
362
|
+
batches.
|
363
|
+
log_best_prefix: (string) if None, no extra summary metrics will be saved.
|
364
|
+
If set to a string, the monitored metric and epoch will be prepended with this value
|
365
|
+
and stored as summary metrics.
|
366
|
+
validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate
|
367
|
+
with each validation example. If log_evaluation is True and `validation_indexes` is provided,
|
368
|
+
then a Table of validation data will not be created and instead each prediction will
|
369
|
+
be associated with the row represented by the `TableLinkMixin`. The most common way to obtain
|
370
|
+
such keys is to use `Table.get_index()`, which returns a list of row keys.
|
371
|
+
validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data.
|
372
|
+
The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input,
|
373
|
+
then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the
|
374
|
+
input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else,
|
375
|
+
it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray,
|
376
|
+
but you wish to visualize the data as an Image, then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}`
|
377
|
+
as the processor. Ignored if log_evaluation is False or `validation_indexes` are present.
|
378
|
+
prediction_row_processor: (Callable) same as `validation_row_processor`, but applied to the model's output. `row["output"]` will contain
|
379
|
+
the results of the model output.
|
380
|
+
infer_missing_processors: (bool) Determines if `validation_row_processor` and `prediction_row_processor`
|
381
|
+
should be inferred if missing. Defaults to True. If `labels` are provided, we will attempt to infer classification-type
|
382
|
+
processors where appropriate.
|
383
|
+
log_evaluation_frequency: (int) Determines the frequency at which evaluation results will be logged. Default 0 (only at the end of training).
|
384
|
+
Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False.
|
385
|
+
compute_flops: (bool) Compute the FLOPs of your Keras Sequential or Functional model in GigaFLOPs unit.
|
386
|
+
"""
|
387
|
+
|
388
|
+
def __init__(
    self,
    monitor="val_loss",
    verbose=0,
    mode="auto",
    save_weights_only=False,
    log_weights=False,
    log_gradients=False,
    save_model=True,
    training_data=None,
    validation_data=None,
    labels=None,
    predictions=36,
    generator=None,
    input_type=None,
    output_type=None,
    log_evaluation=False,
    validation_steps=None,
    class_colors=None,
    log_batch_frequency=None,
    log_best_prefix="best_",
    save_graph=True,
    validation_indexes=None,
    validation_row_processor=None,
    prediction_row_processor=None,
    infer_missing_processors=True,
    log_evaluation_frequency=0,
    compute_flops=False,
    **kwargs,
):
    """Configure the callback for the active W&B run.

    The arguments are described in the class docstring. ``data_type`` may still
    be passed through ``**kwargs`` as a deprecated alias for ``input_type``.

    Raises:
        wandb.Error: if ``wandb.init()`` has not been called.
        Exception: if ``log_gradients`` is requested on TensorFlow < 2.
        ValueError: if ``log_gradients`` is requested without ``training_data``,
            or ``training_data`` is a list/tuple whose length is not 2.
    """
    if wandb.run is None:
        raise wandb.Error("You must call wandb.init() before WandbCallback()")

    deprecate(
        field_name=Deprecated.keras_callback,
        warning_message=(
            "WandbCallback is deprecated and will be removed in a future release. "
            "Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
            "callbacks instead. "
            "See https://docs.wandb.ai/guides/integrations/keras for more information."
        ),
    )

    with wandb.wandb_lib.telemetry.context(run=wandb.run) as tel:
        tel.feature.keras = True
    self.validation_data = None
    # This is kept around for legacy reasons: generator-like validation data is
    # routed to ``generator`` instead of ``validation_data``.
    if validation_data is not None:
        if is_generator_like(validation_data):
            generator = validation_data
        else:
            self.validation_data = validation_data
    if labels is None:
        labels = []
    self.labels = labels
    # Cap the number of visualized predictions at 100.
    self.predictions = min(predictions, 100)

    self.monitor = monitor
    self.verbose = verbose
    self.save_weights_only = save_weights_only
    self.save_graph = save_graph

    wandb.save("model-best.h5")
    self.filepath = os.path.join(wandb.run.dir, "model-best.h5")
    self.save_model = save_model
    if save_model:
        deprecate(
            field_name=Deprecated.keras_callback__save_model,
            warning_message=(
                "The save_model argument by default saves the model in the HDF5 format that cannot save "
                "custom objects like subclassed models and custom layers. This behavior will be deprecated "
                "in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
                "as W&B files and the SavedModel as W&B Artifacts."
            ),
        )

    self.save_model_as_artifact = True
    self.log_weights = log_weights
    self.log_gradients = log_gradients
    self.training_data = training_data
    self.generator = generator
    self._graph_rendered = False

    data_type = kwargs.get("data_type", None)
    if data_type is not None:
        deprecate(
            field_name=Deprecated.keras_callback__data_type,
            warning_message=(
                "The data_type argument of wandb.keras.WandbCallback is deprecated "
                "and will be removed in a future release. Please use input_type instead.\n"
                "Setting input_type = data_type."
            ),
        )
        input_type = data_type
    self.input_type = input_type
    self.output_type = output_type
    self.log_evaluation = log_evaluation
    self.validation_steps = validation_steps
    self.class_colors = np.array(class_colors) if class_colors is not None else None
    self.log_batch_frequency = log_batch_frequency
    self.log_best_prefix = log_best_prefix
    self.compute_flops = compute_flops

    self._prediction_batch_size = None

    if self.log_gradients:
        if int(tf.__version__.split(".")[0]) < 2:
            raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
        if self.training_data is None:
            raise ValueError(
                "training_data argument is required for gradient logging."
            )
        if isinstance(self.training_data, (list, tuple)):
            if len(self.training_data) != 2:
                raise ValueError("training data must be a tuple of length two")
            self._training_data_x, self._training_data_y = self.training_data
        else:
            # generator, tf.data.Dataset etc
            self._training_data_x = self.training_data
            self._training_data_y = None

    # From Keras. Unknown modes fall back to "auto"; warn via wandb like the
    # rest of this module instead of printing to stdout.
    if mode not in ["auto", "min", "max"]:
        wandb.termwarn(f"WandbCallback mode {mode} is unknown, fallback to auto mode.")
        mode = "auto"

    if mode == "auto":
        # Accuracy-like metrics are maximized; everything else is minimized.
        maximize = "acc" in self.monitor or self.monitor.startswith("fmeasure")
    else:
        maximize = mode == "max"
    if maximize:
        self.monitor_op = operator.gt
        self.best = float("-inf")
    else:
        self.monitor_op = operator.lt
        self.best = float("inf")

    # Get the previous best metric for resumed runs. With log_best_prefix=None
    # no "best" summary metric is stored, so don't look up a meaningless
    # "None<monitor>" key.
    if self.log_best_prefix is not None:
        previous_best = wandb.run.summary.get(f"{self.log_best_prefix}{self.monitor}")
        if previous_best is not None:
            self.best = previous_best

    self._validation_data_logger = None
    self._validation_indexes = validation_indexes
    self._validation_row_processor = validation_row_processor
    self._prediction_row_processor = prediction_row_processor
    self._infer_missing_processors = infer_missing_processors
    self._log_evaluation_frequency = log_evaluation_frequency
    self._model_trained_since_last_eval = False
|
540
|
+
|
541
|
+
def _build_grad_accumulator_model(self):
    """Create an internal wrapper model used only to collect gradients.

    The wrapper calls ``self.model`` on its own inputs and is compiled with the
    user's loss and a ``_CustomOptimizer``. A fresh ``_GradAccumulatorCallback``
    is created alongside it; both are stored on ``self`` for ``_log_gradients``.
    """
    model_inputs = self.model.inputs
    accumulator = tf.keras.models.Model(model_inputs, self.model(model_inputs))
    accumulator.compile(loss=self.model.loss, optimizer=_CustomOptimizer())

    # Tag the wrapper so wandb's magic integration doesn't treat it as a user model.
    accumulator._wandb_internal_model = True

    self._grad_accumulator_model = accumulator
    self._grad_accumulator_callback = _GradAccumulatorCallback()
|
552
|
+
|
553
|
+
def _implements_train_batch_hooks(self):
|
554
|
+
return self.log_batch_frequency is not None
|
555
|
+
|
556
|
+
def _implements_test_batch_hooks(self):
|
557
|
+
return self.log_batch_frequency is not None
|
558
|
+
|
559
|
+
def _implements_predict_batch_hooks(self):
|
560
|
+
return self.log_batch_frequency is not None
|
561
|
+
|
562
|
+
def set_params(self, params):
    """Store the ``params`` object on the callback, unchanged."""
    setattr(self, "params", params)
|
564
|
+
|
565
|
+
def set_model(self, model):
    """Attach ``model`` and infer data types from it where possible.

    * If ``input_type`` is ``"auto"`` and the model has one input, the input
      type is guessed from that input's shape (``risky=True``).
    * If an input type is known, no ``output_type`` was given and the model has
      one output, the output type is guessed from that output's shape.
    * When gradient logging is enabled, the gradient accumulator is built.
    """
    super().set_model(model)

    if self.input_type == "auto" and len(model.inputs) == 1:
        first_input_shape = model.inputs[0].shape
        self.input_type = wandb.util.guess_data_type(first_input_shape, risky=True)

    should_guess_output = (
        self.input_type and self.output_type is None and len(model.outputs) == 1
    )
    if should_guess_output:
        self.output_type = wandb.util.guess_data_type(model.outputs[0].shape)

    if self.log_gradients:
        self._build_grad_accumulator_model()
|
575
|
+
|
576
|
+
def _attempt_evaluation_log(self, commit=True):
|
577
|
+
if self.log_evaluation and self._validation_data_logger:
|
578
|
+
try:
|
579
|
+
if not self.model:
|
580
|
+
wandb.termwarn("WandbCallback unable to read model from trainer")
|
581
|
+
else:
|
582
|
+
self._validation_data_logger.log_predictions(
|
583
|
+
predictions=self._validation_data_logger.make_predictions(
|
584
|
+
self.model.predict
|
585
|
+
),
|
586
|
+
commit=commit,
|
587
|
+
)
|
588
|
+
self._model_trained_since_last_eval = False
|
589
|
+
except Exception as e:
|
590
|
+
wandb.termwarn("Error during prediction logging for epoch: " + str(e))
|
591
|
+
|
592
|
+
def on_epoch_end(self, epoch, logs=None):
    """Log end-of-epoch metrics and extras, and track the best model.

    Per epoch this optionally logs:
    * weight and gradient histograms;
    * example images, when the input or output type is image-like;
    * validation predictions, every ``_log_evaluation_frequency`` epochs.

    It then commits ``epoch`` together with the Keras ``logs``. If the
    monitored metric improved according to ``monitor_op``:
    * the best value and epoch are written to the run summary, when
      ``log_best_prefix`` is set;
    * the model is saved and, optionally, uploaded as an artifact.

    Args:
        epoch: epoch index supplied by Keras.
        logs: metric dict supplied by Keras; treated as empty when ``None``.
    """
    if logs is None:
        logs = {}
    if self.log_weights:
        wandb.log(self._log_weights(), commit=False)

    if self.log_gradients:
        wandb.log(self._log_gradients(), commit=False)

    image_types = ("image", "images", "segmentation_mask")
    if self.input_type in image_types or self.output_type in image_types:
        if self.generator:
            self.validation_data = next(self.generator)
        if self.validation_data is None:
            wandb.termwarn(
                "No validation_data set, pass a generator to the callback."
            )
        elif self.validation_data and len(self.validation_data) > 0:
            wandb.log(
                {"examples": self._log_images(num_images=self.predictions)},
                commit=False,
            )

    if (
        self._log_evaluation_frequency > 0
        and epoch % self._log_evaluation_frequency == 0
    ):
        self._attempt_evaluation_log(commit=False)

    wandb.log({"epoch": epoch}, commit=False)
    wandb.log(logs, commit=True)

    self.current = logs.get(self.monitor)
    # Compare against None rather than relying on truthiness: a monitored
    # value of exactly 0 (e.g. a perfect loss) is a real measurement and must
    # still be eligible to become the new best.
    if self.current is not None and self.monitor_op(self.current, self.best):
        if self.log_best_prefix:
            wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
                self.current
            )
            wandb.run.summary[f"{self.log_best_prefix}epoch"] = epoch
            if self.verbose and not self.save_model:
                print(
                    "Epoch %05d: %s improved from %0.5f to %0.5f"
                    % (epoch, self.monitor, self.best, self.current)
                )
        if self.save_model:
            self._save_model(epoch)

        if self.save_model and self.save_model_as_artifact:
            self._save_model_as_artifact(epoch)

        self.best = self.current
|
646
|
+
|
647
|
+
# Legacy hook name that Keras used before tensorflow.keras.
def on_batch_begin(self, batch, logs=None):
    """Intentionally a no-op; batch-start events are not logged."""
|
650
|
+
|
651
|
+
# Legacy hook name that Keras used before tensorflow.keras.
def on_batch_end(self, batch, logs=None):
    """Capture the model graph once, then log batch metrics at the configured rate.

    Args:
        batch: batch index supplied by Keras.
        logs: batch metric dict supplied by Keras; logged and committed every
            ``log_batch_frequency`` batches.
    """
    if self.save_graph and not self._graph_rendered:
        # Deferred from train-begin because the model may not be built yet there.
        wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
        self._graph_rendered = True

    frequency = self.log_batch_frequency
    if frequency and batch % frequency == 0:
        wandb.log(logs, commit=True)
|
660
|
+
|
661
|
+
def on_train_batch_begin(self, batch, logs=None):
    """Mark that the model has trained since the last evaluation log.

    ``on_train_end`` reads this flag to decide whether a final evaluation log
    is needed.
    """
    trained_flag = True
    self._model_trained_since_last_eval = trained_flag
|
663
|
+
|
664
|
+
def on_train_batch_end(self, batch, logs=None):
    """Capture the model graph once, then log training-batch metrics periodically.

    Args:
        batch: batch index supplied by Keras.
        logs: batch metric dict supplied by Keras; logged and committed every
            ``log_batch_frequency`` batches.
    """
    if self.save_graph and not self._graph_rendered:
        # Deferred from train-begin because the model may not be built yet there.
        wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
        self._graph_rendered = True

    frequency = self.log_batch_frequency
    if frequency and batch % frequency == 0:
        wandb.log(logs, commit=True)
|
672
|
+
|
673
|
+
def on_test_begin(self, logs=None):
    """Intentionally a no-op; the start of evaluation is not logged."""
|
675
|
+
|
676
|
+
def on_test_end(self, logs=None):
    """Intentionally a no-op; the end of evaluation is not logged."""
|
678
|
+
|
679
|
+
def on_test_batch_begin(self, batch, logs=None):
    """Intentionally a no-op; evaluation batch starts are not logged."""
|
681
|
+
|
682
|
+
def on_test_batch_end(self, batch, logs=None):
    """Intentionally a no-op; evaluation batch ends are not logged."""
|
684
|
+
|
685
|
+
def on_train_begin(self, logs=None):
    """Prepare validation-prediction logging and optionally record model FLOPs.

    When ``log_evaluation`` is on, validation data is taken from
    ``self.validation_data`` or else drained from ``self.generator`` for
    ``validation_steps`` batches. A ``ValidationDataLogger`` is then built from
    it. Any failure in that setup is reported as a warning and skipped.

    When ``compute_flops`` is on and ``_can_compute_flops()`` allows it, the
    GFLOPs figure is stored in the run summary. A failure there is warned
    about and logged via ``logger.exception``.
    """
    if self.log_evaluation:
        try:
            validation_data = None
            if self.validation_data:
                validation_data = self.validation_data
            elif not self.generator:
                wandb.termwarn(
                    "WandbCallback is unable to read validation_data from trainer "
                    "and therefore cannot log validation data. Ensure Keras is properly "
                    "patched by calling `from wandb.keras import WandbCallback` at the top of your script."
                )
            elif not self.validation_steps:
                wandb.termwarn(
                    "WandbCallback is unable to log validation data. "
                    "When using a generator for validation_data, you must pass validation_steps"
                )
            else:
                # Concatenate validation_steps batches from the generator.
                inputs, targets = None, None
                for _ in range(self.validation_steps):
                    batch_x, batch_y = next(self.generator)
                    if inputs is None:
                        inputs, targets = batch_x, batch_y
                    else:
                        inputs = np.append(inputs, batch_x, axis=0)
                        targets = np.append(targets, batch_y, axis=0)
                validation_data = (inputs, targets)

            if validation_data:
                self._validation_data_logger = ValidationDataLogger(
                    inputs=validation_data[0],
                    targets=validation_data[1],
                    indexes=self._validation_indexes,
                    validation_row_processor=self._validation_row_processor,
                    prediction_row_processor=self._prediction_row_processor,
                    class_labels=self.labels,
                    infer_missing_processors=self._infer_missing_processors,
                )
        except Exception as e:
            wandb.termwarn(
                "Error initializing ValidationDataLogger in WandbCallback. "
                f"Skipping logging validation data. Error: {str(e)}"
            )

    if self.compute_flops and _can_compute_flops():
        try:
            wandb.summary["GFLOPs"] = self.get_flops()
        except Exception as e:
            wandb.termwarn("Unable to compute FLOPs for this model.")
            logger.exception(e)
|
738
|
+
|
739
|
+
def on_train_end(self, logs=None):
    """Run a final evaluation log, but only if training happened since the last one."""
    if not self._model_trained_since_last_eval:
        return
    self._attempt_evaluation_log()
|
742
|
+
|
743
|
+
def on_predict_begin(self, logs=None):
    """Intentionally a no-op; the start of prediction is not logged."""
|
745
|
+
|
746
|
+
def on_predict_end(self, logs=None):
    """Intentionally a no-op; the end of prediction is not logged."""
|
748
|
+
|
749
|
+
def on_predict_batch_begin(self, batch, logs=None):
    """Intentionally a no-op; prediction batch starts are not logged."""
|
751
|
+
|
752
|
+
def on_predict_batch_end(self, batch, logs=None):
    """Intentionally a no-op; prediction batch ends are not logged."""
|
754
|
+
|
755
|
+
def _logits_to_captions(self, logits):
|
756
|
+
if logits[0].shape[-1] == 1:
|
757
|
+
# Scalar output from the model
|
758
|
+
# TODO: handle validation_y
|
759
|
+
if len(self.labels) == 2:
|
760
|
+
# User has named true and false
|
761
|
+
captions = [
|
762
|
+
self.labels[1] if logits[0] > 0.5 else self.labels[0]
|
763
|
+
for logit in logits
|
764
|
+
]
|
765
|
+
else:
|
766
|
+
if len(self.labels) != 0:
|
767
|
+
wandb.termwarn(
|
768
|
+
"keras model is producing a single output, "
|
769
|
+
'so labels should be a length two array: ["False label", "True label"].'
|
770
|
+
)
|
771
|
+
captions = [logit[0] for logit in logits]
|
772
|
+
else:
|
773
|
+
# Vector output from the model
|
774
|
+
# TODO: handle validation_y
|
775
|
+
labels = np.argmax(np.stack(logits), axis=1)
|
776
|
+
|
777
|
+
if len(self.labels) > 0:
|
778
|
+
# User has named the categories in self.labels
|
779
|
+
captions = []
|
780
|
+
for label in labels:
|
781
|
+
try:
|
782
|
+
captions.append(self.labels[label])
|
783
|
+
except IndexError:
|
784
|
+
captions.append(label)
|
785
|
+
else:
|
786
|
+
captions = labels
|
787
|
+
return captions
|
788
|
+
|
789
|
+
def _masks_to_pixels(self, masks):
|
790
|
+
# if its a binary mask, just return it as grayscale instead of picking the argmax
|
791
|
+
if len(masks[0].shape) == 2 or masks[0].shape[-1] == 1:
|
792
|
+
return masks
|
793
|
+
class_colors = (
|
794
|
+
self.class_colors
|
795
|
+
if self.class_colors is not None
|
796
|
+
else np.array(wandb.util.class_colors(masks[0].shape[2]))
|
797
|
+
)
|
798
|
+
imgs = class_colors[np.argmax(masks, axis=-1)]
|
799
|
+
return imgs
|
800
|
+
|
801
|
+
def _log_images(self, num_images=36):
|
802
|
+
validation_X = self.validation_data[0] # noqa: N806
|
803
|
+
validation_y = self.validation_data[1]
|
804
|
+
|
805
|
+
validation_length = len(validation_X)
|
806
|
+
|
807
|
+
if validation_length > num_images:
|
808
|
+
# pick some data at random
|
809
|
+
indices = np.random.choice(validation_length, num_images, replace=False)
|
810
|
+
else:
|
811
|
+
indices = range(validation_length)
|
812
|
+
|
813
|
+
test_data = []
|
814
|
+
test_output = []
|
815
|
+
for i in indices:
|
816
|
+
test_example = validation_X[i]
|
817
|
+
test_data.append(test_example)
|
818
|
+
test_output.append(validation_y[i])
|
819
|
+
|
820
|
+
if self.model.stateful:
|
821
|
+
predictions = self.model.predict(np.stack(test_data), batch_size=1)
|
822
|
+
self.model.reset_states()
|
823
|
+
else:
|
824
|
+
predictions = self.model.predict(
|
825
|
+
np.stack(test_data), batch_size=self._prediction_batch_size
|
826
|
+
)
|
827
|
+
if len(predictions) != len(test_data):
|
828
|
+
self._prediction_batch_size = 1
|
829
|
+
predictions = self.model.predict(
|
830
|
+
np.stack(test_data), batch_size=self._prediction_batch_size
|
831
|
+
)
|
832
|
+
|
833
|
+
if self.input_type == "label":
|
834
|
+
if self.output_type in ("image", "images", "segmentation_mask"):
|
835
|
+
captions = self._logits_to_captions(test_data)
|
836
|
+
output_image_data = (
|
837
|
+
self._masks_to_pixels(predictions)
|
838
|
+
if self.output_type == "segmentation_mask"
|
839
|
+
else predictions
|
840
|
+
)
|
841
|
+
reference_image_data = (
|
842
|
+
self._masks_to_pixels(test_output)
|
843
|
+
if self.output_type == "segmentation_mask"
|
844
|
+
else test_output
|
845
|
+
)
|
846
|
+
output_images = [
|
847
|
+
wandb.Image(data, caption=captions[i], grouping=2)
|
848
|
+
for i, data in enumerate(output_image_data)
|
849
|
+
]
|
850
|
+
reference_images = [
|
851
|
+
wandb.Image(data, caption=captions[i])
|
852
|
+
for i, data in enumerate(reference_image_data)
|
853
|
+
]
|
854
|
+
return list(chain.from_iterable(zip(output_images, reference_images)))
|
855
|
+
elif self.input_type in ("image", "images", "segmentation_mask"):
|
856
|
+
input_image_data = (
|
857
|
+
self._masks_to_pixels(test_data)
|
858
|
+
if self.input_type == "segmentation_mask"
|
859
|
+
else test_data
|
860
|
+
)
|
861
|
+
if self.output_type == "label":
|
862
|
+
# we just use the predicted label as the caption for now
|
863
|
+
captions = self._logits_to_captions(predictions)
|
864
|
+
return [
|
865
|
+
wandb.Image(data, caption=captions[i])
|
866
|
+
for i, data in enumerate(test_data)
|
867
|
+
]
|
868
|
+
elif self.output_type in ("image", "images", "segmentation_mask"):
|
869
|
+
output_image_data = (
|
870
|
+
self._masks_to_pixels(predictions)
|
871
|
+
if self.output_type == "segmentation_mask"
|
872
|
+
else predictions
|
873
|
+
)
|
874
|
+
reference_image_data = (
|
875
|
+
self._masks_to_pixels(test_output)
|
876
|
+
if self.output_type == "segmentation_mask"
|
877
|
+
else test_output
|
878
|
+
)
|
879
|
+
input_images = [
|
880
|
+
wandb.Image(data, grouping=3)
|
881
|
+
for i, data in enumerate(input_image_data)
|
882
|
+
]
|
883
|
+
output_images = [
|
884
|
+
wandb.Image(data) for i, data in enumerate(output_image_data)
|
885
|
+
]
|
886
|
+
reference_images = [
|
887
|
+
wandb.Image(data) for i, data in enumerate(reference_image_data)
|
888
|
+
]
|
889
|
+
return list(
|
890
|
+
chain.from_iterable(
|
891
|
+
zip(input_images, output_images, reference_images)
|
892
|
+
)
|
893
|
+
)
|
894
|
+
else:
|
895
|
+
# unknown output, just log the input images
|
896
|
+
return [wandb.Image(img) for img in test_data]
|
897
|
+
elif self.output_type in ("image", "images", "segmentation_mask"):
|
898
|
+
# unknown input, just log the predicted and reference outputs without captions
|
899
|
+
output_image_data = (
|
900
|
+
self._masks_to_pixels(predictions)
|
901
|
+
if self.output_type == "segmentation_mask"
|
902
|
+
else predictions
|
903
|
+
)
|
904
|
+
reference_image_data = (
|
905
|
+
self._masks_to_pixels(test_output)
|
906
|
+
if self.output_type == "segmentation_mask"
|
907
|
+
else test_output
|
908
|
+
)
|
909
|
+
output_images = [
|
910
|
+
wandb.Image(data, grouping=2)
|
911
|
+
for i, data in enumerate(output_image_data)
|
912
|
+
]
|
913
|
+
reference_images = [
|
914
|
+
wandb.Image(data) for i, data in enumerate(reference_image_data)
|
915
|
+
]
|
916
|
+
return list(chain.from_iterable(zip(output_images, reference_images)))
|
917
|
+
|
918
|
+
def _log_weights(self):
|
919
|
+
metrics = {}
|
920
|
+
for layer in self.model.layers:
|
921
|
+
weights = layer.get_weights()
|
922
|
+
if len(weights) == 1:
|
923
|
+
_update_if_numeric(
|
924
|
+
metrics, "parameters/" + layer.name + ".weights", weights[0]
|
925
|
+
)
|
926
|
+
elif len(weights) == 2:
|
927
|
+
_update_if_numeric(
|
928
|
+
metrics, "parameters/" + layer.name + ".weights", weights[0]
|
929
|
+
)
|
930
|
+
_update_if_numeric(
|
931
|
+
metrics, "parameters/" + layer.name + ".bias", weights[1]
|
932
|
+
)
|
933
|
+
return metrics
|
934
|
+
|
935
|
+
def _log_gradients(self):
    """Compute gradient histograms for the model's trainable weights.

    Calls ``fit`` on the internal gradient-accumulator model over the stored
    training data, with the TensorFlow logger temporarily raised to ERROR.
    Each trainable weight of ``self.model`` is then paired (via ``zip``) with
    the gradient collected by the accumulator callback.

    Returns:
        dict mapping ``gradients/<weight name>.gradient`` to ``wandb.Histogram``.
    """
    # Suppress callback warnings from the grad accumulator. Restore the original
    # level in ``finally`` so a failing fit() doesn't leave TF logging silenced.
    og_level = tf_logger.level
    tf_logger.setLevel("ERROR")
    try:
        self._grad_accumulator_model.fit(
            self._training_data_x,
            self._training_data_y,
            verbose=0,
            callbacks=[self._grad_accumulator_callback],
        )
    finally:
        tf_logger.setLevel(og_level)

    weights = self.model.trainable_weights
    grads = self._grad_accumulator_callback.grads
    metrics = {}
    for weight, grad in zip(weights, grads):
        # Drop the ":<index>" suffix TF appends to variable names.
        metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
            wandb.Histogram(grad)
        )
    return metrics
|
955
|
+
|
956
|
+
def _log_dataframe(self):
    """Build a wandb dataframe of validation inputs, targets and predictions.

    Data comes from ``self.validation_data`` when set. Otherwise
    ``validation_steps`` batches are pulled from ``self.generator``, predicted
    batch by batch, and concatenated. The dataframe helper is chosen from
    ``input_type`` / ``output_type`` (image -> label or image -> segmentation
    mask).

    Returns:
        The dataframe produced by wandb, or ``None`` when a generator is used
        without ``validation_steps`` or the type combination is unsupported.
    """
    x = y_true = y_pred = None

    if self.validation_data:
        x, y_true = self.validation_data[0], self.validation_data[1]
        y_pred = self.model.predict(x)
    elif self.generator:
        if not self.validation_steps:
            wandb.termwarn(
                "when using a generator for validation data with dataframes, "
                "you must pass validation_steps. skipping"
            )
            return None

        for _ in range(self.validation_steps):
            batch_x, batch_true = next(self.generator)
            batch_pred = self.model.predict(batch_x)
            if x is None:
                x, y_true, y_pred = batch_x, batch_true, batch_pred
                continue
            x = np.append(x, batch_x, axis=0)
            y_true = np.append(y_true, batch_true, axis=0)
            y_pred = np.append(y_pred, batch_pred, axis=0)

    image_inputs = self.input_type in ("image", "images")
    if image_inputs and self.output_type == "label":
        return wandb.image_categorizer_dataframe(
            x=x, y_true=y_true, y_pred=y_pred, labels=self.labels
        )
    if image_inputs and self.output_type == "segmentation_mask":
        return wandb.image_segmentation_dataframe(
            x=x,
            y_true=y_true,
            y_pred=y_pred,
            labels=self.labels,
            class_colors=self.class_colors,
        )
    wandb.termwarn(
        f"unknown dataframe type for input_type={self.input_type} and output_type={self.output_type}"
    )
    return None
|
1002
|
+
|
1003
|
+
def _save_model(self, epoch):
    """Save the model (or only its weights) to ``self.filepath``.

    Skipped entirely for disabled runs. Weights only are saved when
    ``save_weights_only`` is set, otherwise the full model. ImportError,
    RuntimeError, TypeError and AttributeError raised while saving are
    reported on the terminal and logged, not propagated.

    Args:
        epoch: epoch index, used only in the verbose progress message.
    """
    if wandb.run.disabled:
        return
    if self.verbose > 0:
        print(
            "Epoch %05d: %s improved from %0.5f to %0.5f,"
            " saving model to %s"
            % (epoch, self.monitor, self.best, self.current, self.filepath)
        )

    try:
        if self.save_weights_only:
            self.model.save_weights(self.filepath, overwrite=True)
        else:
            self.model.save(self.filepath, overwrite=True)
    # Was getting `RuntimeError: Unable to create link` in TF 1.13.1
    # also saw `TypeError: can't pickle _thread.RLock objects`
    except (ImportError, RuntimeError, TypeError, AttributeError) as e:
        wandb.termerror(
            "Can't save model in the h5py format. The model will be saved as "
            "a W&B Artifact in the 'tf' format."
        )
        logger.exception(e)
|
1026
|
+
|
1027
|
+
def _save_model_as_artifact(self, epoch):
    """Upload the model to W&B as a ``model`` artifact.

    The model is saved with ``save_format="tf"`` to ``self.filepath`` minus
    its last three characters. That directory is added to an artifact named
    after the run and logged with aliases ``latest`` and ``epoch_<epoch>``,
    then deleted locally. Skipped for disabled runs.
    """
    if wandb.run.disabled:
        return

    # Strips the last 3 characters of filepath — presumably a ".h5" suffix; confirm.
    saved_model_dir = self.filepath[:-3]

    # Save the model in the SavedModel format.
    # TODO: Replace this manual artifact creation with the `log_model` method
    # after `log_model` is released from beta.
    self.model.save(saved_model_dir, overwrite=True, save_format="tf")

    # Log the model as artifact.
    artifact_name = wandb.util.make_artifact_name_safe(f"model-{wandb.run.name}")
    artifact = wandb.Artifact(artifact_name, type="model")
    artifact.add_dir(saved_model_dir)
    wandb.run.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])

    # Remove the SavedModel from wandb dir as we don't want to log it to save memory.
    shutil.rmtree(saved_model_dir)
|
1044
|
+
|
1045
|
+
def get_flops(self) -> float:
    """Calculate FLOPS [GFLOPs] for a tf.keras.Model or tf.keras.Sequential model in inference mode.

    It uses tf.compat.v1.profiler under the hood. The model is traced for a
    single sample, frozen to a constant graph, and profiled for float ops.

    Returns:
        The profiler's ``total_float_ops`` divided by 1e9 and then by 2.

    Raises:
        wandb.Error: if no model is attached to the callback.
        ValueError: if the attached model is not a ``tf.keras`` model.
    """
    # ``model`` may exist on the callback but still be None (Keras callbacks
    # typically initialise it that way), so ``hasattr`` alone is not enough.
    if getattr(self, "model", None) is None:
        raise wandb.Error("self.model must be set before using this method.")

    if not isinstance(
        self.model, (tf.keras.models.Sequential, tf.keras.models.Model)
    ):
        raise ValueError(
            "Calculating FLOPS is only supported for "
            "`tf.keras.Model` and `tf.keras.Sequential` instances."
        )

    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2_as_graph,
    )

    # Compute FLOPs for one sample
    batch_size = 1
    # Build each spec's shape as a plain list so this works whether
    # ``inp.shape`` is a TensorShape or a tuple (``list + tuple`` raises).
    inputs = [
        tf.TensorSpec([batch_size, *inp.shape[1:]], inp.dtype)
        for inp in self.model.inputs
    ]

    # convert tf.keras model into frozen graph to count FLOPs about operations used at inference
    real_model = tf.function(self.model).get_concrete_function(inputs)
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)

    # Calculate FLOPs with tf.profiler
    run_meta = tf.compat.v1.RunMetadata()
    opts = (
        tf.compat.v1.profiler.ProfileOptionBuilder(
            tf.compat.v1.profiler.ProfileOptionBuilder().float_operation()
        )
        .with_empty_output()
        .build()
    )

    flops = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
    )

    # convert to GFLOPs
    # NOTE(review): the extra "/ 2" presumably reports multiply-accumulates
    # rather than raw float ops — confirm this is the intended unit.
    return (flops.total_float_ops / 1e9) / 2
|