wandb-0.18.2-py3-none-musllinux_1_2_x86_64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (827)
  1. package_readme.md +89 -0
  2. wandb/__init__.py +245 -0
  3. wandb/__init__.pyi +1139 -0
  4. wandb/__main__.py +3 -0
  5. wandb/_globals.py +19 -0
  6. wandb/agents/__init__.py +0 -0
  7. wandb/agents/pyagent.py +363 -0
  8. wandb/analytics/__init__.py +3 -0
  9. wandb/analytics/sentry.py +266 -0
  10. wandb/apis/__init__.py +48 -0
  11. wandb/apis/attrs.py +40 -0
  12. wandb/apis/importers/__init__.py +1 -0
  13. wandb/apis/importers/internals/internal.py +385 -0
  14. wandb/apis/importers/internals/protocols.py +99 -0
  15. wandb/apis/importers/internals/util.py +78 -0
  16. wandb/apis/importers/mlflow.py +254 -0
  17. wandb/apis/importers/validation.py +108 -0
  18. wandb/apis/importers/wandb.py +1603 -0
  19. wandb/apis/internal.py +232 -0
  20. wandb/apis/normalize.py +89 -0
  21. wandb/apis/paginator.py +81 -0
  22. wandb/apis/public/__init__.py +34 -0
  23. wandb/apis/public/api.py +1305 -0
  24. wandb/apis/public/artifacts.py +1090 -0
  25. wandb/apis/public/const.py +4 -0
  26. wandb/apis/public/files.py +195 -0
  27. wandb/apis/public/history.py +149 -0
  28. wandb/apis/public/jobs.py +659 -0
  29. wandb/apis/public/projects.py +154 -0
  30. wandb/apis/public/query_generator.py +166 -0
  31. wandb/apis/public/reports.py +469 -0
  32. wandb/apis/public/runs.py +914 -0
  33. wandb/apis/public/sweeps.py +240 -0
  34. wandb/apis/public/teams.py +198 -0
  35. wandb/apis/public/users.py +136 -0
  36. wandb/apis/reports/__init__.py +1 -0
  37. wandb/apis/reports/v1/__init__.py +8 -0
  38. wandb/apis/reports/v2/__init__.py +8 -0
  39. wandb/apis/workspaces/__init__.py +8 -0
  40. wandb/beta/workflows.py +288 -0
  41. wandb/bin/nvidia_gpu_stats +0 -0
  42. wandb/bin/wandb-core +0 -0
  43. wandb/cli/__init__.py +0 -0
  44. wandb/cli/cli.py +3004 -0
  45. wandb/data_types.py +63 -0
  46. wandb/docker/__init__.py +342 -0
  47. wandb/docker/auth.py +436 -0
  48. wandb/docker/wandb-entrypoint.sh +33 -0
  49. wandb/docker/www_authenticate.py +94 -0
  50. wandb/env.py +514 -0
  51. wandb/errors/__init__.py +17 -0
  52. wandb/errors/errors.py +37 -0
  53. wandb/errors/term.py +103 -0
  54. wandb/errors/util.py +57 -0
  55. wandb/errors/warnings.py +2 -0
  56. wandb/filesync/__init__.py +0 -0
  57. wandb/filesync/dir_watcher.py +403 -0
  58. wandb/filesync/stats.py +100 -0
  59. wandb/filesync/step_checksum.py +142 -0
  60. wandb/filesync/step_prepare.py +179 -0
  61. wandb/filesync/step_upload.py +290 -0
  62. wandb/filesync/upload_job.py +142 -0
  63. wandb/integration/__init__.py +0 -0
  64. wandb/integration/catboost/__init__.py +5 -0
  65. wandb/integration/catboost/catboost.py +178 -0
  66. wandb/integration/cohere/__init__.py +3 -0
  67. wandb/integration/cohere/cohere.py +21 -0
  68. wandb/integration/cohere/resolver.py +347 -0
  69. wandb/integration/diffusers/__init__.py +3 -0
  70. wandb/integration/diffusers/autologger.py +76 -0
  71. wandb/integration/diffusers/pipeline_resolver.py +50 -0
  72. wandb/integration/diffusers/resolvers/__init__.py +9 -0
  73. wandb/integration/diffusers/resolvers/multimodal.py +882 -0
  74. wandb/integration/diffusers/resolvers/utils.py +102 -0
  75. wandb/integration/fastai/__init__.py +249 -0
  76. wandb/integration/gym/__init__.py +105 -0
  77. wandb/integration/huggingface/__init__.py +3 -0
  78. wandb/integration/huggingface/huggingface.py +18 -0
  79. wandb/integration/huggingface/resolver.py +213 -0
  80. wandb/integration/keras/__init__.py +11 -0
  81. wandb/integration/keras/callbacks/__init__.py +5 -0
  82. wandb/integration/keras/callbacks/metrics_logger.py +136 -0
  83. wandb/integration/keras/callbacks/model_checkpoint.py +195 -0
  84. wandb/integration/keras/callbacks/tables_builder.py +226 -0
  85. wandb/integration/keras/keras.py +1091 -0
  86. wandb/integration/kfp/__init__.py +6 -0
  87. wandb/integration/kfp/helpers.py +28 -0
  88. wandb/integration/kfp/kfp_patch.py +324 -0
  89. wandb/integration/kfp/wandb_logging.py +182 -0
  90. wandb/integration/langchain/__init__.py +3 -0
  91. wandb/integration/langchain/wandb_tracer.py +48 -0
  92. wandb/integration/lightgbm/__init__.py +239 -0
  93. wandb/integration/lightning/__init__.py +0 -0
  94. wandb/integration/lightning/fabric/__init__.py +3 -0
  95. wandb/integration/lightning/fabric/logger.py +762 -0
  96. wandb/integration/magic.py +556 -0
  97. wandb/integration/metaflow/__init__.py +3 -0
  98. wandb/integration/metaflow/metaflow.py +383 -0
  99. wandb/integration/openai/__init__.py +3 -0
  100. wandb/integration/openai/fine_tuning.py +480 -0
  101. wandb/integration/openai/openai.py +22 -0
  102. wandb/integration/openai/resolver.py +240 -0
  103. wandb/integration/prodigy/__init__.py +3 -0
  104. wandb/integration/prodigy/prodigy.py +299 -0
  105. wandb/integration/sacred/__init__.py +117 -0
  106. wandb/integration/sagemaker/__init__.py +12 -0
  107. wandb/integration/sagemaker/auth.py +28 -0
  108. wandb/integration/sagemaker/config.py +49 -0
  109. wandb/integration/sagemaker/files.py +3 -0
  110. wandb/integration/sagemaker/resources.py +34 -0
  111. wandb/integration/sb3/__init__.py +3 -0
  112. wandb/integration/sb3/sb3.py +153 -0
  113. wandb/integration/sklearn/__init__.py +37 -0
  114. wandb/integration/sklearn/calculate/__init__.py +32 -0
  115. wandb/integration/sklearn/calculate/calibration_curves.py +125 -0
  116. wandb/integration/sklearn/calculate/class_proportions.py +68 -0
  117. wandb/integration/sklearn/calculate/confusion_matrix.py +93 -0
  118. wandb/integration/sklearn/calculate/decision_boundaries.py +40 -0
  119. wandb/integration/sklearn/calculate/elbow_curve.py +55 -0
  120. wandb/integration/sklearn/calculate/feature_importances.py +67 -0
  121. wandb/integration/sklearn/calculate/learning_curve.py +64 -0
  122. wandb/integration/sklearn/calculate/outlier_candidates.py +69 -0
  123. wandb/integration/sklearn/calculate/residuals.py +86 -0
  124. wandb/integration/sklearn/calculate/silhouette.py +118 -0
  125. wandb/integration/sklearn/calculate/summary_metrics.py +62 -0
  126. wandb/integration/sklearn/plot/__init__.py +35 -0
  127. wandb/integration/sklearn/plot/classifier.py +329 -0
  128. wandb/integration/sklearn/plot/clusterer.py +146 -0
  129. wandb/integration/sklearn/plot/regressor.py +121 -0
  130. wandb/integration/sklearn/plot/shared.py +91 -0
  131. wandb/integration/sklearn/utils.py +183 -0
  132. wandb/integration/tensorboard/__init__.py +10 -0
  133. wandb/integration/tensorboard/log.py +355 -0
  134. wandb/integration/tensorboard/monkeypatch.py +185 -0
  135. wandb/integration/tensorflow/__init__.py +5 -0
  136. wandb/integration/tensorflow/estimator_hook.py +54 -0
  137. wandb/integration/torch/__init__.py +0 -0
  138. wandb/integration/torch/wandb_torch.py +554 -0
  139. wandb/integration/ultralytics/__init__.py +11 -0
  140. wandb/integration/ultralytics/bbox_utils.py +208 -0
  141. wandb/integration/ultralytics/callback.py +524 -0
  142. wandb/integration/ultralytics/classification_utils.py +83 -0
  143. wandb/integration/ultralytics/mask_utils.py +202 -0
  144. wandb/integration/ultralytics/pose_utils.py +103 -0
  145. wandb/integration/xgboost/__init__.py +11 -0
  146. wandb/integration/xgboost/xgboost.py +189 -0
  147. wandb/integration/yolov8/__init__.py +0 -0
  148. wandb/integration/yolov8/yolov8.py +284 -0
  149. wandb/jupyter.py +515 -0
  150. wandb/magic.py +3 -0
  151. wandb/mpmain/__init__.py +0 -0
  152. wandb/mpmain/__main__.py +1 -0
  153. wandb/old/__init__.py +0 -0
  154. wandb/old/core.py +53 -0
  155. wandb/old/settings.py +173 -0
  156. wandb/old/summary.py +440 -0
  157. wandb/plot/__init__.py +19 -0
  158. wandb/plot/bar.py +45 -0
  159. wandb/plot/confusion_matrix.py +100 -0
  160. wandb/plot/histogram.py +39 -0
  161. wandb/plot/line.py +43 -0
  162. wandb/plot/line_series.py +88 -0
  163. wandb/plot/pr_curve.py +136 -0
  164. wandb/plot/roc_curve.py +118 -0
  165. wandb/plot/scatter.py +32 -0
  166. wandb/plot/utils.py +183 -0
  167. wandb/plot/viz.py +123 -0
  168. wandb/proto/__init__.py +0 -0
  169. wandb/proto/v3/__init__.py +0 -0
  170. wandb/proto/v3/wandb_base_pb2.py +55 -0
  171. wandb/proto/v3/wandb_internal_pb2.py +1608 -0
  172. wandb/proto/v3/wandb_server_pb2.py +208 -0
  173. wandb/proto/v3/wandb_settings_pb2.py +112 -0
  174. wandb/proto/v3/wandb_telemetry_pb2.py +106 -0
  175. wandb/proto/v4/__init__.py +0 -0
  176. wandb/proto/v4/wandb_base_pb2.py +30 -0
  177. wandb/proto/v4/wandb_internal_pb2.py +360 -0
  178. wandb/proto/v4/wandb_server_pb2.py +63 -0
  179. wandb/proto/v4/wandb_settings_pb2.py +45 -0
  180. wandb/proto/v4/wandb_telemetry_pb2.py +41 -0
  181. wandb/proto/v5/wandb_base_pb2.py +31 -0
  182. wandb/proto/v5/wandb_internal_pb2.py +361 -0
  183. wandb/proto/v5/wandb_server_pb2.py +64 -0
  184. wandb/proto/v5/wandb_settings_pb2.py +46 -0
  185. wandb/proto/v5/wandb_telemetry_pb2.py +42 -0
  186. wandb/proto/wandb_base_pb2.py +10 -0
  187. wandb/proto/wandb_deprecated.py +53 -0
  188. wandb/proto/wandb_generate_deprecated.py +34 -0
  189. wandb/proto/wandb_generate_proto.py +49 -0
  190. wandb/proto/wandb_internal_pb2.py +16 -0
  191. wandb/proto/wandb_server_pb2.py +10 -0
  192. wandb/proto/wandb_settings_pb2.py +10 -0
  193. wandb/proto/wandb_telemetry_pb2.py +10 -0
  194. wandb/py.typed +0 -0
  195. wandb/sdk/__init__.py +37 -0
  196. wandb/sdk/artifacts/__init__.py +0 -0
  197. wandb/sdk/artifacts/_validators.py +90 -0
  198. wandb/sdk/artifacts/artifact.py +2389 -0
  199. wandb/sdk/artifacts/artifact_download_logger.py +43 -0
  200. wandb/sdk/artifacts/artifact_file_cache.py +253 -0
  201. wandb/sdk/artifacts/artifact_instance_cache.py +17 -0
  202. wandb/sdk/artifacts/artifact_manifest.py +74 -0
  203. wandb/sdk/artifacts/artifact_manifest_entry.py +249 -0
  204. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  205. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +92 -0
  206. wandb/sdk/artifacts/artifact_saver.py +269 -0
  207. wandb/sdk/artifacts/artifact_state.py +11 -0
  208. wandb/sdk/artifacts/artifact_ttl.py +7 -0
  209. wandb/sdk/artifacts/exceptions.py +57 -0
  210. wandb/sdk/artifacts/staging.py +25 -0
  211. wandb/sdk/artifacts/storage_handler.py +62 -0
  212. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  213. wandb/sdk/artifacts/storage_handlers/azure_handler.py +208 -0
  214. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +228 -0
  215. wandb/sdk/artifacts/storage_handlers/http_handler.py +114 -0
  216. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +141 -0
  217. wandb/sdk/artifacts/storage_handlers/multi_handler.py +56 -0
  218. wandb/sdk/artifacts/storage_handlers/s3_handler.py +300 -0
  219. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +72 -0
  220. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +135 -0
  221. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +74 -0
  222. wandb/sdk/artifacts/storage_layout.py +6 -0
  223. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  224. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  225. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +378 -0
  226. wandb/sdk/artifacts/storage_policy.py +72 -0
  227. wandb/sdk/backend/__init__.py +0 -0
  228. wandb/sdk/backend/backend.py +222 -0
  229. wandb/sdk/data_types/__init__.py +0 -0
  230. wandb/sdk/data_types/_dtypes.py +914 -0
  231. wandb/sdk/data_types/_private.py +10 -0
  232. wandb/sdk/data_types/audio.py +165 -0
  233. wandb/sdk/data_types/base_types/__init__.py +0 -0
  234. wandb/sdk/data_types/base_types/json_metadata.py +55 -0
  235. wandb/sdk/data_types/base_types/media.py +315 -0
  236. wandb/sdk/data_types/base_types/wb_value.py +272 -0
  237. wandb/sdk/data_types/bokeh.py +70 -0
  238. wandb/sdk/data_types/graph.py +405 -0
  239. wandb/sdk/data_types/helper_types/__init__.py +0 -0
  240. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +295 -0
  241. wandb/sdk/data_types/helper_types/classes.py +159 -0
  242. wandb/sdk/data_types/helper_types/image_mask.py +235 -0
  243. wandb/sdk/data_types/histogram.py +96 -0
  244. wandb/sdk/data_types/html.py +115 -0
  245. wandb/sdk/data_types/image.py +845 -0
  246. wandb/sdk/data_types/molecule.py +241 -0
  247. wandb/sdk/data_types/object_3d.py +474 -0
  248. wandb/sdk/data_types/plotly.py +82 -0
  249. wandb/sdk/data_types/saved_model.py +446 -0
  250. wandb/sdk/data_types/table.py +1204 -0
  251. wandb/sdk/data_types/trace_tree.py +438 -0
  252. wandb/sdk/data_types/utils.py +229 -0
  253. wandb/sdk/data_types/video.py +247 -0
  254. wandb/sdk/integration_utils/__init__.py +0 -0
  255. wandb/sdk/integration_utils/auto_logging.py +239 -0
  256. wandb/sdk/integration_utils/data_logging.py +475 -0
  257. wandb/sdk/interface/__init__.py +0 -0
  258. wandb/sdk/interface/constants.py +4 -0
  259. wandb/sdk/interface/interface.py +972 -0
  260. wandb/sdk/interface/interface_queue.py +59 -0
  261. wandb/sdk/interface/interface_relay.py +53 -0
  262. wandb/sdk/interface/interface_shared.py +537 -0
  263. wandb/sdk/interface/interface_sock.py +61 -0
  264. wandb/sdk/interface/message_future.py +27 -0
  265. wandb/sdk/interface/message_future_poll.py +50 -0
  266. wandb/sdk/interface/router.py +118 -0
  267. wandb/sdk/interface/router_queue.py +44 -0
  268. wandb/sdk/interface/router_relay.py +39 -0
  269. wandb/sdk/interface/router_sock.py +36 -0
  270. wandb/sdk/interface/summary_record.py +67 -0
  271. wandb/sdk/internal/__init__.py +0 -0
  272. wandb/sdk/internal/context.py +89 -0
  273. wandb/sdk/internal/datastore.py +297 -0
  274. wandb/sdk/internal/file_pusher.py +181 -0
  275. wandb/sdk/internal/file_stream.py +695 -0
  276. wandb/sdk/internal/flow_control.py +263 -0
  277. wandb/sdk/internal/handler.py +901 -0
  278. wandb/sdk/internal/internal.py +417 -0
  279. wandb/sdk/internal/internal_api.py +4358 -0
  280. wandb/sdk/internal/internal_util.py +100 -0
  281. wandb/sdk/internal/job_builder.py +629 -0
  282. wandb/sdk/internal/profiler.py +78 -0
  283. wandb/sdk/internal/progress.py +83 -0
  284. wandb/sdk/internal/run.py +25 -0
  285. wandb/sdk/internal/sample.py +70 -0
  286. wandb/sdk/internal/sender.py +1686 -0
  287. wandb/sdk/internal/sender_config.py +197 -0
  288. wandb/sdk/internal/settings_static.py +90 -0
  289. wandb/sdk/internal/system/__init__.py +0 -0
  290. wandb/sdk/internal/system/assets/__init__.py +27 -0
  291. wandb/sdk/internal/system/assets/aggregators.py +37 -0
  292. wandb/sdk/internal/system/assets/asset_registry.py +20 -0
  293. wandb/sdk/internal/system/assets/cpu.py +163 -0
  294. wandb/sdk/internal/system/assets/disk.py +210 -0
  295. wandb/sdk/internal/system/assets/gpu.py +416 -0
  296. wandb/sdk/internal/system/assets/gpu_amd.py +239 -0
  297. wandb/sdk/internal/system/assets/gpu_apple.py +177 -0
  298. wandb/sdk/internal/system/assets/interfaces.py +207 -0
  299. wandb/sdk/internal/system/assets/ipu.py +177 -0
  300. wandb/sdk/internal/system/assets/memory.py +166 -0
  301. wandb/sdk/internal/system/assets/network.py +125 -0
  302. wandb/sdk/internal/system/assets/open_metrics.py +299 -0
  303. wandb/sdk/internal/system/assets/tpu.py +154 -0
  304. wandb/sdk/internal/system/assets/trainium.py +399 -0
  305. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  306. wandb/sdk/internal/system/system_info.py +249 -0
  307. wandb/sdk/internal/system/system_monitor.py +229 -0
  308. wandb/sdk/internal/tb_watcher.py +518 -0
  309. wandb/sdk/internal/thread_local_settings.py +18 -0
  310. wandb/sdk/internal/writer.py +206 -0
  311. wandb/sdk/launch/__init__.py +14 -0
  312. wandb/sdk/launch/_launch.py +330 -0
  313. wandb/sdk/launch/_launch_add.py +255 -0
  314. wandb/sdk/launch/_project_spec.py +566 -0
  315. wandb/sdk/launch/agent/__init__.py +5 -0
  316. wandb/sdk/launch/agent/agent.py +924 -0
  317. wandb/sdk/launch/agent/config.py +296 -0
  318. wandb/sdk/launch/agent/job_status_tracker.py +53 -0
  319. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  320. wandb/sdk/launch/builder/__init__.py +0 -0
  321. wandb/sdk/launch/builder/abstract.py +156 -0
  322. wandb/sdk/launch/builder/build.py +297 -0
  323. wandb/sdk/launch/builder/context_manager.py +235 -0
  324. wandb/sdk/launch/builder/docker_builder.py +177 -0
  325. wandb/sdk/launch/builder/kaniko_builder.py +595 -0
  326. wandb/sdk/launch/builder/noop.py +58 -0
  327. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +188 -0
  328. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  329. wandb/sdk/launch/create_job.py +528 -0
  330. wandb/sdk/launch/environment/abstract.py +29 -0
  331. wandb/sdk/launch/environment/aws_environment.py +322 -0
  332. wandb/sdk/launch/environment/azure_environment.py +105 -0
  333. wandb/sdk/launch/environment/gcp_environment.py +335 -0
  334. wandb/sdk/launch/environment/local_environment.py +66 -0
  335. wandb/sdk/launch/errors.py +19 -0
  336. wandb/sdk/launch/git_reference.py +109 -0
  337. wandb/sdk/launch/inputs/files.py +148 -0
  338. wandb/sdk/launch/inputs/internal.py +315 -0
  339. wandb/sdk/launch/inputs/manage.py +113 -0
  340. wandb/sdk/launch/inputs/schema.py +39 -0
  341. wandb/sdk/launch/loader.py +249 -0
  342. wandb/sdk/launch/registry/abstract.py +48 -0
  343. wandb/sdk/launch/registry/anon.py +29 -0
  344. wandb/sdk/launch/registry/azure_container_registry.py +124 -0
  345. wandb/sdk/launch/registry/elastic_container_registry.py +192 -0
  346. wandb/sdk/launch/registry/google_artifact_registry.py +219 -0
  347. wandb/sdk/launch/registry/local_registry.py +67 -0
  348. wandb/sdk/launch/runner/__init__.py +0 -0
  349. wandb/sdk/launch/runner/abstract.py +195 -0
  350. wandb/sdk/launch/runner/kubernetes_monitor.py +474 -0
  351. wandb/sdk/launch/runner/kubernetes_runner.py +963 -0
  352. wandb/sdk/launch/runner/local_container.py +301 -0
  353. wandb/sdk/launch/runner/local_process.py +78 -0
  354. wandb/sdk/launch/runner/sagemaker_runner.py +426 -0
  355. wandb/sdk/launch/runner/vertex_runner.py +230 -0
  356. wandb/sdk/launch/sweeps/__init__.py +39 -0
  357. wandb/sdk/launch/sweeps/scheduler.py +742 -0
  358. wandb/sdk/launch/sweeps/scheduler_sweep.py +91 -0
  359. wandb/sdk/launch/sweeps/utils.py +316 -0
  360. wandb/sdk/launch/utils.py +746 -0
  361. wandb/sdk/launch/wandb_reference.py +138 -0
  362. wandb/sdk/lib/__init__.py +5 -0
  363. wandb/sdk/lib/_settings_toposort_generate.py +159 -0
  364. wandb/sdk/lib/_settings_toposort_generated.py +250 -0
  365. wandb/sdk/lib/_wburls_generate.py +25 -0
  366. wandb/sdk/lib/_wburls_generated.py +22 -0
  367. wandb/sdk/lib/apikey.py +273 -0
  368. wandb/sdk/lib/capped_dict.py +26 -0
  369. wandb/sdk/lib/config_util.py +101 -0
  370. wandb/sdk/lib/credentials.py +141 -0
  371. wandb/sdk/lib/deprecate.py +42 -0
  372. wandb/sdk/lib/disabled.py +29 -0
  373. wandb/sdk/lib/exit_hooks.py +54 -0
  374. wandb/sdk/lib/file_stream_utils.py +118 -0
  375. wandb/sdk/lib/filenames.py +64 -0
  376. wandb/sdk/lib/filesystem.py +372 -0
  377. wandb/sdk/lib/fsm.py +174 -0
  378. wandb/sdk/lib/gitlib.py +239 -0
  379. wandb/sdk/lib/gql_request.py +65 -0
  380. wandb/sdk/lib/handler_util.py +21 -0
  381. wandb/sdk/lib/hashutil.py +84 -0
  382. wandb/sdk/lib/import_hooks.py +275 -0
  383. wandb/sdk/lib/ipython.py +146 -0
  384. wandb/sdk/lib/json_util.py +80 -0
  385. wandb/sdk/lib/lazyloader.py +63 -0
  386. wandb/sdk/lib/mailbox.py +460 -0
  387. wandb/sdk/lib/module.py +69 -0
  388. wandb/sdk/lib/paths.py +106 -0
  389. wandb/sdk/lib/preinit.py +42 -0
  390. wandb/sdk/lib/printer.py +313 -0
  391. wandb/sdk/lib/proto_util.py +90 -0
  392. wandb/sdk/lib/redirect.py +845 -0
  393. wandb/sdk/lib/reporting.py +99 -0
  394. wandb/sdk/lib/retry.py +289 -0
  395. wandb/sdk/lib/run_moment.py +78 -0
  396. wandb/sdk/lib/runid.py +12 -0
  397. wandb/sdk/lib/server.py +52 -0
  398. wandb/sdk/lib/service_connection.py +216 -0
  399. wandb/sdk/lib/service_token.py +94 -0
  400. wandb/sdk/lib/sock_client.py +295 -0
  401. wandb/sdk/lib/sparkline.py +45 -0
  402. wandb/sdk/lib/telemetry.py +100 -0
  403. wandb/sdk/lib/timed_input.py +133 -0
  404. wandb/sdk/lib/timer.py +19 -0
  405. wandb/sdk/lib/tracelog.py +255 -0
  406. wandb/sdk/lib/wburls.py +46 -0
  407. wandb/sdk/service/__init__.py +0 -0
  408. wandb/sdk/service/_startup_debug.py +22 -0
  409. wandb/sdk/service/port_file.py +53 -0
  410. wandb/sdk/service/server.py +116 -0
  411. wandb/sdk/service/server_sock.py +276 -0
  412. wandb/sdk/service/service.py +242 -0
  413. wandb/sdk/service/streams.py +417 -0
  414. wandb/sdk/verify/__init__.py +0 -0
  415. wandb/sdk/verify/verify.py +501 -0
  416. wandb/sdk/wandb_alerts.py +12 -0
  417. wandb/sdk/wandb_config.py +322 -0
  418. wandb/sdk/wandb_helper.py +54 -0
  419. wandb/sdk/wandb_init.py +1266 -0
  420. wandb/sdk/wandb_login.py +349 -0
  421. wandb/sdk/wandb_metric.py +110 -0
  422. wandb/sdk/wandb_require.py +97 -0
  423. wandb/sdk/wandb_require_helpers.py +44 -0
  424. wandb/sdk/wandb_run.py +4236 -0
  425. wandb/sdk/wandb_settings.py +2001 -0
  426. wandb/sdk/wandb_setup.py +409 -0
  427. wandb/sdk/wandb_summary.py +150 -0
  428. wandb/sdk/wandb_sweep.py +119 -0
  429. wandb/sdk/wandb_sync.py +81 -0
  430. wandb/sdk/wandb_watch.py +144 -0
  431. wandb/sklearn.py +35 -0
  432. wandb/sync/__init__.py +3 -0
  433. wandb/sync/sync.py +443 -0
  434. wandb/trigger.py +29 -0
  435. wandb/util.py +1956 -0
  436. wandb/vendor/__init__.py +0 -0
  437. wandb/vendor/gql-0.2.0/setup.py +40 -0
  438. wandb/vendor/gql-0.2.0/tests/__init__.py +0 -0
  439. wandb/vendor/gql-0.2.0/tests/starwars/__init__.py +0 -0
  440. wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py +96 -0
  441. wandb/vendor/gql-0.2.0/tests/starwars/schema.py +146 -0
  442. wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py +293 -0
  443. wandb/vendor/gql-0.2.0/tests/starwars/test_query.py +355 -0
  444. wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py +171 -0
  445. wandb/vendor/gql-0.2.0/tests/test_client.py +31 -0
  446. wandb/vendor/gql-0.2.0/tests/test_transport.py +89 -0
  447. wandb/vendor/gql-0.2.0/wandb_gql/__init__.py +4 -0
  448. wandb/vendor/gql-0.2.0/wandb_gql/client.py +75 -0
  449. wandb/vendor/gql-0.2.0/wandb_gql/dsl.py +152 -0
  450. wandb/vendor/gql-0.2.0/wandb_gql/gql.py +10 -0
  451. wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py +0 -0
  452. wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py +6 -0
  453. wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py +15 -0
  454. wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py +46 -0
  455. wandb/vendor/gql-0.2.0/wandb_gql/utils.py +21 -0
  456. wandb/vendor/graphql-core-1.1/setup.py +86 -0
  457. wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py +287 -0
  458. wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py +6 -0
  459. wandb/vendor/graphql-core-1.1/wandb_graphql/error/base.py +42 -0
  460. wandb/vendor/graphql-core-1.1/wandb_graphql/error/format_error.py +11 -0
  461. wandb/vendor/graphql-core-1.1/wandb_graphql/error/located_error.py +29 -0
  462. wandb/vendor/graphql-core-1.1/wandb_graphql/error/syntax_error.py +36 -0
  463. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__init__.py +26 -0
  464. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py +311 -0
  465. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executor.py +398 -0
  466. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py +0 -0
  467. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py +53 -0
  468. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py +22 -0
  469. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py +32 -0
  470. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py +7 -0
  471. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py +35 -0
  472. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py +6 -0
  473. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py +0 -0
  474. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py +66 -0
  475. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py +252 -0
  476. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py +151 -0
  477. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py +7 -0
  478. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py +57 -0
  479. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py +145 -0
  480. wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py +60 -0
  481. wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py +0 -0
  482. wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py +1349 -0
  483. wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py +19 -0
  484. wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py +435 -0
  485. wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py +30 -0
  486. wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py +779 -0
  487. wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py +193 -0
  488. wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py +18 -0
  489. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py +222 -0
  490. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py +82 -0
  491. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__init__.py +0 -0
  492. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py +17 -0
  493. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/contain_subset.py +28 -0
  494. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/default_ordered_dict.py +40 -0
  495. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/ordereddict.py +8 -0
  496. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/pair_set.py +43 -0
  497. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/version.py +78 -0
  498. wandb/vendor/graphql-core-1.1/wandb_graphql/type/__init__.py +67 -0
  499. wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py +619 -0
  500. wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py +132 -0
  501. wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py +440 -0
  502. wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py +131 -0
  503. wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py +100 -0
  504. wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py +145 -0
  505. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py +0 -0
  506. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py +9 -0
  507. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py +65 -0
  508. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py +49 -0
  509. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py +24 -0
  510. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py +75 -0
  511. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py +291 -0
  512. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_client_schema.py +250 -0
  513. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/concat_ast.py +9 -0
  514. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/extend_schema.py +357 -0
  515. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_field_def.py +27 -0
  516. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_operation_ast.py +21 -0
  517. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/introspection_query.py +90 -0
  518. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_literal_value.py +67 -0
  519. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_value.py +66 -0
  520. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/quoted_or_list.py +21 -0
  521. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/schema_printer.py +168 -0
  522. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/suggestion_list.py +56 -0
  523. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_comparators.py +69 -0
  524. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_from_ast.py +21 -0
  525. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_info.py +149 -0
  526. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/value_from_ast.py +69 -0
  527. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py +4 -0
  528. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py +79 -0
  529. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py +24 -0
  530. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py +8 -0
  531. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py +44 -0
  532. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py +113 -0
  533. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py +33 -0
  534. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py +70 -0
  535. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py +97 -0
  536. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py +19 -0
  537. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py +43 -0
  538. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py +23 -0
  539. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py +59 -0
  540. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py +36 -0
  541. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py +38 -0
  542. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py +37 -0
  543. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py +529 -0
  544. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py +44 -0
  545. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py +46 -0
  546. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py +33 -0
  547. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py +32 -0
  548. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py +28 -0
  549. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py +33 -0
  550. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py +31 -0
  551. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py +27 -0
  552. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py +21 -0
  553. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py +53 -0
  554. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py +158 -0
  555. wandb/vendor/promise-2.3.0/conftest.py +30 -0
  556. wandb/vendor/promise-2.3.0/setup.py +64 -0
  557. wandb/vendor/promise-2.3.0/tests/__init__.py +0 -0
  558. wandb/vendor/promise-2.3.0/tests/conftest.py +8 -0
  559. wandb/vendor/promise-2.3.0/tests/test_awaitable.py +32 -0
  560. wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py +47 -0
  561. wandb/vendor/promise-2.3.0/tests/test_benchmark.py +116 -0
  562. wandb/vendor/promise-2.3.0/tests/test_complex_threads.py +23 -0
  563. wandb/vendor/promise-2.3.0/tests/test_dataloader.py +452 -0
  564. wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py +99 -0
  565. wandb/vendor/promise-2.3.0/tests/test_dataloader_extra.py +65 -0
  566. wandb/vendor/promise-2.3.0/tests/test_extra.py +670 -0
  567. wandb/vendor/promise-2.3.0/tests/test_issues.py +132 -0
  568. wandb/vendor/promise-2.3.0/tests/test_promise_list.py +70 -0
  569. wandb/vendor/promise-2.3.0/tests/test_spec.py +584 -0
  570. wandb/vendor/promise-2.3.0/tests/test_thread_safety.py +115 -0
  571. wandb/vendor/promise-2.3.0/tests/utils.py +3 -0
  572. wandb/vendor/promise-2.3.0/wandb_promise/__init__.py +38 -0
  573. wandb/vendor/promise-2.3.0/wandb_promise/async_.py +135 -0
  574. wandb/vendor/promise-2.3.0/wandb_promise/compat.py +32 -0
  575. wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py +326 -0
  576. wandb/vendor/promise-2.3.0/wandb_promise/iterate_promise.py +12 -0
  577. wandb/vendor/promise-2.3.0/wandb_promise/promise.py +848 -0
  578. wandb/vendor/promise-2.3.0/wandb_promise/promise_list.py +151 -0
  579. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/__init__.py +0 -0
  580. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/version.py +83 -0
  581. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/__init__.py +0 -0
  582. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/asyncio.py +22 -0
  583. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/gevent.py +21 -0
  584. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/immediate.py +27 -0
  585. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/thread.py +18 -0
  586. wandb/vendor/promise-2.3.0/wandb_promise/utils.py +56 -0
  587. wandb/vendor/pygments/__init__.py +90 -0
  588. wandb/vendor/pygments/cmdline.py +568 -0
  589. wandb/vendor/pygments/console.py +74 -0
  590. wandb/vendor/pygments/filter.py +74 -0
  591. wandb/vendor/pygments/filters/__init__.py +350 -0
  592. wandb/vendor/pygments/formatter.py +95 -0
  593. wandb/vendor/pygments/formatters/__init__.py +153 -0
  594. wandb/vendor/pygments/formatters/_mapping.py +85 -0
  595. wandb/vendor/pygments/formatters/bbcode.py +109 -0
  596. wandb/vendor/pygments/formatters/html.py +851 -0
  597. wandb/vendor/pygments/formatters/img.py +600 -0
  598. wandb/vendor/pygments/formatters/irc.py +182 -0
  599. wandb/vendor/pygments/formatters/latex.py +482 -0
  600. wandb/vendor/pygments/formatters/other.py +160 -0
  601. wandb/vendor/pygments/formatters/rtf.py +147 -0
  602. wandb/vendor/pygments/formatters/svg.py +153 -0
  603. wandb/vendor/pygments/formatters/terminal.py +136 -0
  604. wandb/vendor/pygments/formatters/terminal256.py +309 -0
  605. wandb/vendor/pygments/lexer.py +871 -0
  606. wandb/vendor/pygments/lexers/__init__.py +329 -0
  607. wandb/vendor/pygments/lexers/_asy_builtins.py +1645 -0
  608. wandb/vendor/pygments/lexers/_cl_builtins.py +232 -0
  609. wandb/vendor/pygments/lexers/_cocoa_builtins.py +72 -0
  610. wandb/vendor/pygments/lexers/_csound_builtins.py +1346 -0
  611. wandb/vendor/pygments/lexers/_lasso_builtins.py +5327 -0
  612. wandb/vendor/pygments/lexers/_lua_builtins.py +295 -0
  613. wandb/vendor/pygments/lexers/_mapping.py +500 -0
  614. wandb/vendor/pygments/lexers/_mql_builtins.py +1172 -0
  615. wandb/vendor/pygments/lexers/_openedge_builtins.py +2547 -0
  616. wandb/vendor/pygments/lexers/_php_builtins.py +4756 -0
  617. wandb/vendor/pygments/lexers/_postgres_builtins.py +621 -0
  618. wandb/vendor/pygments/lexers/_scilab_builtins.py +3094 -0
  619. wandb/vendor/pygments/lexers/_sourcemod_builtins.py +1163 -0
  620. wandb/vendor/pygments/lexers/_stan_builtins.py +532 -0
  621. wandb/vendor/pygments/lexers/_stata_builtins.py +419 -0
  622. wandb/vendor/pygments/lexers/_tsql_builtins.py +1004 -0
  623. wandb/vendor/pygments/lexers/_vim_builtins.py +1939 -0
  624. wandb/vendor/pygments/lexers/actionscript.py +240 -0
  625. wandb/vendor/pygments/lexers/agile.py +24 -0
  626. wandb/vendor/pygments/lexers/algebra.py +221 -0
  627. wandb/vendor/pygments/lexers/ambient.py +76 -0
  628. wandb/vendor/pygments/lexers/ampl.py +87 -0
  629. wandb/vendor/pygments/lexers/apl.py +101 -0
  630. wandb/vendor/pygments/lexers/archetype.py +318 -0
  631. wandb/vendor/pygments/lexers/asm.py +641 -0
  632. wandb/vendor/pygments/lexers/automation.py +374 -0
  633. wandb/vendor/pygments/lexers/basic.py +500 -0
  634. wandb/vendor/pygments/lexers/bibtex.py +160 -0
  635. wandb/vendor/pygments/lexers/business.py +612 -0
  636. wandb/vendor/pygments/lexers/c_cpp.py +252 -0
  637. wandb/vendor/pygments/lexers/c_like.py +541 -0
  638. wandb/vendor/pygments/lexers/capnproto.py +78 -0
  639. wandb/vendor/pygments/lexers/chapel.py +102 -0
  640. wandb/vendor/pygments/lexers/clean.py +288 -0
  641. wandb/vendor/pygments/lexers/compiled.py +34 -0
  642. wandb/vendor/pygments/lexers/configs.py +833 -0
  643. wandb/vendor/pygments/lexers/console.py +114 -0
  644. wandb/vendor/pygments/lexers/crystal.py +393 -0
  645. wandb/vendor/pygments/lexers/csound.py +366 -0
  646. wandb/vendor/pygments/lexers/css.py +689 -0
  647. wandb/vendor/pygments/lexers/d.py +251 -0
  648. wandb/vendor/pygments/lexers/dalvik.py +125 -0
  649. wandb/vendor/pygments/lexers/data.py +555 -0
  650. wandb/vendor/pygments/lexers/diff.py +165 -0
  651. wandb/vendor/pygments/lexers/dotnet.py +691 -0
  652. wandb/vendor/pygments/lexers/dsls.py +878 -0
  653. wandb/vendor/pygments/lexers/dylan.py +289 -0
  654. wandb/vendor/pygments/lexers/ecl.py +125 -0
  655. wandb/vendor/pygments/lexers/eiffel.py +65 -0
  656. wandb/vendor/pygments/lexers/elm.py +121 -0
  657. wandb/vendor/pygments/lexers/erlang.py +533 -0
  658. wandb/vendor/pygments/lexers/esoteric.py +277 -0
  659. wandb/vendor/pygments/lexers/ezhil.py +69 -0
  660. wandb/vendor/pygments/lexers/factor.py +344 -0
  661. wandb/vendor/pygments/lexers/fantom.py +250 -0
  662. wandb/vendor/pygments/lexers/felix.py +273 -0
  663. wandb/vendor/pygments/lexers/forth.py +177 -0
  664. wandb/vendor/pygments/lexers/fortran.py +205 -0
  665. wandb/vendor/pygments/lexers/foxpro.py +428 -0
  666. wandb/vendor/pygments/lexers/functional.py +21 -0
  667. wandb/vendor/pygments/lexers/go.py +101 -0
  668. wandb/vendor/pygments/lexers/grammar_notation.py +213 -0
  669. wandb/vendor/pygments/lexers/graph.py +80 -0
  670. wandb/vendor/pygments/lexers/graphics.py +553 -0
  671. wandb/vendor/pygments/lexers/haskell.py +843 -0
  672. wandb/vendor/pygments/lexers/haxe.py +936 -0
  673. wandb/vendor/pygments/lexers/hdl.py +382 -0
  674. wandb/vendor/pygments/lexers/hexdump.py +103 -0
  675. wandb/vendor/pygments/lexers/html.py +602 -0
  676. wandb/vendor/pygments/lexers/idl.py +270 -0
  677. wandb/vendor/pygments/lexers/igor.py +288 -0
  678. wandb/vendor/pygments/lexers/inferno.py +96 -0
  679. wandb/vendor/pygments/lexers/installers.py +322 -0
  680. wandb/vendor/pygments/lexers/int_fiction.py +1343 -0
  681. wandb/vendor/pygments/lexers/iolang.py +63 -0
  682. wandb/vendor/pygments/lexers/j.py +146 -0
  683. wandb/vendor/pygments/lexers/javascript.py +1525 -0
  684. wandb/vendor/pygments/lexers/julia.py +333 -0
  685. wandb/vendor/pygments/lexers/jvm.py +1573 -0
  686. wandb/vendor/pygments/lexers/lisp.py +2621 -0
  687. wandb/vendor/pygments/lexers/make.py +202 -0
  688. wandb/vendor/pygments/lexers/markup.py +595 -0
  689. wandb/vendor/pygments/lexers/math.py +21 -0
  690. wandb/vendor/pygments/lexers/matlab.py +663 -0
  691. wandb/vendor/pygments/lexers/ml.py +769 -0
  692. wandb/vendor/pygments/lexers/modeling.py +358 -0
  693. wandb/vendor/pygments/lexers/modula2.py +1561 -0
  694. wandb/vendor/pygments/lexers/monte.py +204 -0
  695. wandb/vendor/pygments/lexers/ncl.py +894 -0
  696. wandb/vendor/pygments/lexers/nimrod.py +159 -0
  697. wandb/vendor/pygments/lexers/nit.py +64 -0
  698. wandb/vendor/pygments/lexers/nix.py +136 -0
  699. wandb/vendor/pygments/lexers/oberon.py +105 -0
  700. wandb/vendor/pygments/lexers/objective.py +504 -0
  701. wandb/vendor/pygments/lexers/ooc.py +85 -0
  702. wandb/vendor/pygments/lexers/other.py +41 -0
  703. wandb/vendor/pygments/lexers/parasail.py +79 -0
  704. wandb/vendor/pygments/lexers/parsers.py +835 -0
  705. wandb/vendor/pygments/lexers/pascal.py +644 -0
  706. wandb/vendor/pygments/lexers/pawn.py +199 -0
  707. wandb/vendor/pygments/lexers/perl.py +620 -0
  708. wandb/vendor/pygments/lexers/php.py +267 -0
  709. wandb/vendor/pygments/lexers/praat.py +294 -0
  710. wandb/vendor/pygments/lexers/prolog.py +306 -0
  711. wandb/vendor/pygments/lexers/python.py +939 -0
  712. wandb/vendor/pygments/lexers/qvt.py +152 -0
  713. wandb/vendor/pygments/lexers/r.py +453 -0
  714. wandb/vendor/pygments/lexers/rdf.py +270 -0
  715. wandb/vendor/pygments/lexers/rebol.py +431 -0
  716. wandb/vendor/pygments/lexers/resource.py +85 -0
  717. wandb/vendor/pygments/lexers/rnc.py +67 -0
  718. wandb/vendor/pygments/lexers/roboconf.py +82 -0
  719. wandb/vendor/pygments/lexers/robotframework.py +560 -0
  720. wandb/vendor/pygments/lexers/ruby.py +519 -0
  721. wandb/vendor/pygments/lexers/rust.py +220 -0
  722. wandb/vendor/pygments/lexers/sas.py +228 -0
  723. wandb/vendor/pygments/lexers/scripting.py +1222 -0
  724. wandb/vendor/pygments/lexers/shell.py +794 -0
  725. wandb/vendor/pygments/lexers/smalltalk.py +195 -0
  726. wandb/vendor/pygments/lexers/smv.py +79 -0
  727. wandb/vendor/pygments/lexers/snobol.py +83 -0
  728. wandb/vendor/pygments/lexers/special.py +103 -0
  729. wandb/vendor/pygments/lexers/sql.py +681 -0
  730. wandb/vendor/pygments/lexers/stata.py +108 -0
  731. wandb/vendor/pygments/lexers/supercollider.py +90 -0
  732. wandb/vendor/pygments/lexers/tcl.py +145 -0
  733. wandb/vendor/pygments/lexers/templates.py +2283 -0
  734. wandb/vendor/pygments/lexers/testing.py +207 -0
  735. wandb/vendor/pygments/lexers/text.py +25 -0
  736. wandb/vendor/pygments/lexers/textedit.py +169 -0
  737. wandb/vendor/pygments/lexers/textfmts.py +297 -0
  738. wandb/vendor/pygments/lexers/theorem.py +458 -0
  739. wandb/vendor/pygments/lexers/trafficscript.py +54 -0
  740. wandb/vendor/pygments/lexers/typoscript.py +226 -0
  741. wandb/vendor/pygments/lexers/urbi.py +133 -0
  742. wandb/vendor/pygments/lexers/varnish.py +190 -0
  743. wandb/vendor/pygments/lexers/verification.py +111 -0
  744. wandb/vendor/pygments/lexers/web.py +24 -0
  745. wandb/vendor/pygments/lexers/webmisc.py +988 -0
  746. wandb/vendor/pygments/lexers/whiley.py +116 -0
  747. wandb/vendor/pygments/lexers/x10.py +69 -0
  748. wandb/vendor/pygments/modeline.py +44 -0
  749. wandb/vendor/pygments/plugin.py +68 -0
  750. wandb/vendor/pygments/regexopt.py +92 -0
  751. wandb/vendor/pygments/scanner.py +105 -0
  752. wandb/vendor/pygments/sphinxext.py +158 -0
  753. wandb/vendor/pygments/style.py +155 -0
  754. wandb/vendor/pygments/styles/__init__.py +80 -0
  755. wandb/vendor/pygments/styles/abap.py +29 -0
  756. wandb/vendor/pygments/styles/algol.py +63 -0
  757. wandb/vendor/pygments/styles/algol_nu.py +63 -0
  758. wandb/vendor/pygments/styles/arduino.py +98 -0
  759. wandb/vendor/pygments/styles/autumn.py +65 -0
  760. wandb/vendor/pygments/styles/borland.py +51 -0
  761. wandb/vendor/pygments/styles/bw.py +49 -0
  762. wandb/vendor/pygments/styles/colorful.py +81 -0
  763. wandb/vendor/pygments/styles/default.py +73 -0
  764. wandb/vendor/pygments/styles/emacs.py +72 -0
  765. wandb/vendor/pygments/styles/friendly.py +72 -0
  766. wandb/vendor/pygments/styles/fruity.py +42 -0
  767. wandb/vendor/pygments/styles/igor.py +29 -0
  768. wandb/vendor/pygments/styles/lovelace.py +97 -0
  769. wandb/vendor/pygments/styles/manni.py +75 -0
  770. wandb/vendor/pygments/styles/monokai.py +106 -0
  771. wandb/vendor/pygments/styles/murphy.py +80 -0
  772. wandb/vendor/pygments/styles/native.py +65 -0
  773. wandb/vendor/pygments/styles/paraiso_dark.py +125 -0
  774. wandb/vendor/pygments/styles/paraiso_light.py +125 -0
  775. wandb/vendor/pygments/styles/pastie.py +75 -0
  776. wandb/vendor/pygments/styles/perldoc.py +69 -0
  777. wandb/vendor/pygments/styles/rainbow_dash.py +89 -0
  778. wandb/vendor/pygments/styles/rrt.py +33 -0
  779. wandb/vendor/pygments/styles/sas.py +44 -0
  780. wandb/vendor/pygments/styles/stata.py +40 -0
  781. wandb/vendor/pygments/styles/tango.py +141 -0
  782. wandb/vendor/pygments/styles/trac.py +63 -0
  783. wandb/vendor/pygments/styles/vim.py +63 -0
  784. wandb/vendor/pygments/styles/vs.py +38 -0
  785. wandb/vendor/pygments/styles/xcode.py +51 -0
  786. wandb/vendor/pygments/token.py +213 -0
  787. wandb/vendor/pygments/unistring.py +217 -0
  788. wandb/vendor/pygments/util.py +388 -0
  789. wandb/vendor/pynvml/__init__.py +0 -0
  790. wandb/vendor/pynvml/pynvml.py +4779 -0
  791. wandb/vendor/watchdog_0_9_0/wandb_watchdog/__init__.py +17 -0
  792. wandb/vendor/watchdog_0_9_0/wandb_watchdog/events.py +615 -0
  793. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/__init__.py +98 -0
  794. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/api.py +369 -0
  795. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents.py +172 -0
  796. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents2.py +239 -0
  797. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify.py +218 -0
  798. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_buffer.py +81 -0
  799. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_c.py +575 -0
  800. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/kqueue.py +730 -0
  801. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/polling.py +145 -0
  802. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/read_directory_changes.py +133 -0
  803. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/winapi.py +348 -0
  804. wandb/vendor/watchdog_0_9_0/wandb_watchdog/patterns.py +265 -0
  805. wandb/vendor/watchdog_0_9_0/wandb_watchdog/tricks/__init__.py +174 -0
  806. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/__init__.py +151 -0
  807. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/bricks.py +249 -0
  808. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/compat.py +29 -0
  809. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/decorators.py +198 -0
  810. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/delayed_queue.py +88 -0
  811. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/dirsnapshot.py +293 -0
  812. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/echo.py +157 -0
  813. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/event_backport.py +41 -0
  814. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/importlib2.py +40 -0
  815. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/platform.py +57 -0
  816. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/unicode_paths.py +64 -0
  817. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/win32stat.py +123 -0
  818. wandb/vendor/watchdog_0_9_0/wandb_watchdog/version.py +28 -0
  819. wandb/vendor/watchdog_0_9_0/wandb_watchdog/watchmedo.py +577 -0
  820. wandb/wandb_agent.py +588 -0
  821. wandb/wandb_controller.py +721 -0
  822. wandb/wandb_run.py +9 -0
  823. wandb-0.18.2.dist-info/METADATA +213 -0
  824. wandb-0.18.2.dist-info/RECORD +827 -0
  825. wandb-0.18.2.dist-info/WHEEL +5 -0
  826. wandb-0.18.2.dist-info/entry_points.txt +3 -0
  827. wandb-0.18.2.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1091 @@
1
+ """keras init."""
2
+
3
+ import logging
4
+ import operator
5
+ import os
6
+ import shutil
7
+ import sys
8
+ from itertools import chain
9
+
10
+ import numpy as np
11
+ import tensorflow as tf
12
+ import tensorflow.keras.backend as K # noqa: N812
13
+
14
+ import wandb
15
+ from wandb.sdk.integration_utils.data_logging import ValidationDataLogger
16
+ from wandb.sdk.lib.deprecate import Deprecated, deprecate
17
+ from wandb.util import add_import_hook
18
+
19
+
20
def _check_keras_version():
    """Print a W&B warning if the installed Keras is older than 2.4.0."""
    from keras import __version__ as keras_version

    from wandb.util import parse_version

    installed = parse_version(keras_version)
    if installed < parse_version("2.4.0"):
        wandb.termwarn(
            f"Keras version {keras_version} is not fully supported. Required keras >= 2.4.0"
        )
29
+
30
+
31
def _can_compute_flops() -> bool:
    """FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1.

    Returns:
        True when the installed TensorFlow version is at least 2.0.0.
    """
    from wandb.util import parse_version

    return bool(parse_version(tf.__version__) >= parse_version("2.0.0"))
39
+
40
+
41
# If Keras is not imported yet, register the version check with
# `add_import_hook` (presumably run once `keras` gets imported); otherwise
# run the check right away.
if "keras" not in sys.modules:
    add_import_hook("keras", _check_keras_version)
else:
    _check_keras_version()


logger = logging.getLogger(__name__)
48
+
49
+
50
def is_dataset(data):
    """Return whether `data` is a TensorFlow dataset object.

    Checks against `DatasetV2` (and `DatasetV1` when that class exists) from
    `tensorflow.python.data.ops.dataset_ops`. Returns False when
    `wandb.util.get_module` yields a falsy value for that module or the module
    has no `DatasetV2`.
    """
    dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
    if not dataset_ops or not hasattr(dataset_ops, "DatasetV2"):
        return False

    candidates = [dataset_ops.DatasetV2]
    if hasattr(dataset_ops, "DatasetV1"):
        candidates.append(dataset_ops.DatasetV1)
    return isinstance(data, tuple(candidates))
59
+
60
+
61
def is_generator_like(data):
    """Return whether `data` behaves like a generator, Keras Sequence or TF iterator.

    True if `data` has a `next` or `__next__` attribute, or is an instance of
    `tf.keras.utils.Sequence` or (when the module is available) one of the
    iterator classes from `tensorflow.python.data.ops.iterator_ops`.
    """
    accepted = [tf.keras.utils.Sequence]
    iterator_ops = wandb.util.get_module("tensorflow.python.data.ops.iterator_ops")
    if iterator_ops:
        accepted.append(iterator_ops.Iterator)
        # EagerIterator was in tensorflow < 2
        if hasattr(iterator_ops, "EagerIterator"):
            accepted.append(iterator_ops.EagerIterator)
        elif hasattr(iterator_ops, "IteratorV2"):
            accepted.append(iterator_ops.IteratorV2)

    looks_like_iterator = hasattr(data, "next") or hasattr(data, "__next__")
    return looks_like_iterator or isinstance(data, tuple(accepted))
74
+
75
+
76
def patch_tf_keras():  # noqa: C901
    """Monkeypatch Keras fit entry points so `WandbCallback` instances get validation data.

    Depending on the installed TensorFlow version, the Keras training modules
    are imported from `keras.engine` (2.6.0 <= TF < 2.13.0) or from
    `tensorflow.python.keras.engine` (all other versions). The following
    functions are then replaced with wrappers:

    * `training_arrays.fit_loop`
    * `training_generator.fit_generator`
    * `training_v2.Loop.fit` when that module exists, otherwise
      `training.Model.fit` when a `training_v1` module exists.

    Each wrapper looks only at *keyword* arguments (`callbacks`,
    `validation_data`, `val_inputs`/`val_targets`). It hands the validation
    data to every `WandbCallback` among the callbacks, then calls the
    original function unchanged. The original array/generator loops are kept
    as `orig_fit_loop`/`orig_fit_generator`, and the patched targets are
    recorded in `wandb.patched["keras"]`.

    If the required modules cannot be imported, an error is printed and
    logged and nothing is patched.
    """
    from tensorflow.python.eager import context

    from wandb.util import parse_version

    if (
        parse_version("2.6.0")
        <= parse_version(tf.__version__)
        < parse_version("2.13.0")
    ):
        # Training modules come from the standalone `keras` package.
        keras_engine = "keras.engine"
        try:
            from keras.engine import training
            from keras.engine import training_arrays_v1 as training_arrays
            from keras.engine import training_generator_v1 as training_generator
        except (ImportError, AttributeError):
            wandb.termerror("Unable to patch Tensorflow/Keras")
            logger.exception("exception while trying to patch_tf_keras")
            return
    else:
        # Training modules come from the Keras copy bundled in TensorFlow.
        keras_engine = "tensorflow.python.keras.engine"

        from tensorflow.python.keras.engine import training

        try:
            # Prefer the `_v1` module names, fall back to the unsuffixed ones.
            from tensorflow.python.keras.engine import (
                training_arrays_v1 as training_arrays,
            )
            from tensorflow.python.keras.engine import (
                training_generator_v1 as training_generator,
            )
        except (ImportError, AttributeError):
            try:
                from tensorflow.python.keras.engine import (
                    training_arrays,
                    training_generator,
                )
            except (ImportError, AttributeError):
                wandb.termerror("Unable to patch Tensorflow/Keras")
                logger.exception("exception while trying to patch_tf_keras")
                return

    # Tensorflow 2.1
    training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")
    # Tensorflow 2.2
    training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")

    # `old_v2` is bound only when one of these modules exists; `new_v2` is
    # installed under exactly the same conditions further below.
    if training_v2_1:
        old_v2 = training_v2_1.Loop.fit
    elif training_v2_2:
        old_v2 = training.Model.fit

    old_arrays = training_arrays.fit_loop
    old_generator = training_generator.fit_generator

    def set_wandb_attrs(cbk, val_data):
        # Only WandbCallback instances are touched; other callbacks are ignored.
        if isinstance(cbk, WandbCallback):
            if is_generator_like(val_data):
                cbk.generator = val_data
            elif is_dataset(val_data):
                if context.executing_eagerly():
                    cbk.generator = iter(val_data)
                else:
                    wandb.termwarn(
                        "Found a validation dataset in graph mode, can't patch Keras."
                    )
            elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
                # Graph mode dataset generator: every `next()` evaluates the
                # tensors in the current Keras session.
                def gen():
                    while True:
                        yield K.get_session().run(val_data)

                cbk.generator = gen()
            else:
                cbk.validation_data = val_data

    def new_arrays(*args, **kwargs):
        # `callbacks=None` may be passed explicitly; treat it as "no callbacks"
        # instead of failing with "'NoneType' object is not iterable".
        cbks = kwargs.get("callbacks") or []
        val_inputs = kwargs.get("val_inputs")
        val_targets = kwargs.get("val_targets")
        # TODO: these could be generators, why index 0?
        if val_inputs and val_targets:
            for cbk in cbks:
                set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
        return old_arrays(*args, **kwargs)

    def new_generator(*args, **kwargs):
        # See new_arrays: tolerate an explicit `callbacks=None`.
        cbks = kwargs.get("callbacks") or []
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_generator(*args, **kwargs)

    def new_v2(*args, **kwargs):
        # See new_arrays: tolerate an explicit `callbacks=None`.
        cbks = kwargs.get("callbacks") or []
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_v2(*args, **kwargs)

    # Keep references to the originals on the modules, install the wrappers,
    # and record what was patched.
    training_arrays.orig_fit_loop = old_arrays
    training_arrays.fit_loop = new_arrays
    training_generator.orig_fit_generator = old_generator
    training_generator.fit_generator = new_generator
    wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
    wandb.patched["keras"].append(
        [f"{keras_engine}.training_generator", "fit_generator"]
    )

    if training_v2_1:
        training_v2_1.Loop.fit = new_v2
        wandb.patched["keras"].append(
            ["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
        )
    elif training_v2_2:
        training.Model.fit = new_v2
        wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])
195
+
196
+
197
+ def _array_has_dtype(array):
198
+ return hasattr(array, "dtype")
199
+
200
+
201
def _update_if_numeric(metrics, key, values):
    """Store `wandb.Histogram(values)` in `metrics[key]` when `values` is numeric.

    If `values` has no `dtype`, or its dtype is not a NumPy number type, a
    warning naming `key` is emitted and `metrics` is left unchanged.
    """
    if _array_has_dtype(values):
        if is_numeric_array(values):
            metrics[key] = wandb.Histogram(values)
        else:
            _warn_not_logging_non_numeric(key)
    else:
        _warn_not_logging(key)
211
+
212
+
213
def is_numeric_array(array):
    """Return True if `array.dtype` is a NumPy numeric type (int, float, complex)."""
    element_type = array.dtype
    return bool(np.issubdtype(element_type, np.number))
215
+
216
+
217
def _warn_not_logging_non_numeric(name):
    """Warn that layer `name` is skipped because its values are not numeric.

    Passes `repeat=False` to `wandb.termwarn`, presumably so the same message
    is not printed more than once.
    """
    message = f"Non-numeric values found in layer: {name}, not logging this layer"
    wandb.termwarn(message, repeat=False)
222
+
223
+
224
def _warn_not_logging(name):
    """Warn that layer `name` is skipped because its values have no `dtype`.

    Passes `repeat=False` to `wandb.termwarn`, presumably so the same message
    is not printed more than once.
    """
    message = f"Layer {name} has undetermined datatype not logging this layer"
    wandb.termwarn(message, repeat=False)
229
+
230
+
231
# Module-level handle to TensorFlow's logger.
tf_logger = tf.get_logger()

# Install the Keras fit-loop monkeypatches (patch_tf_keras) as an import-time
# side effect of this module.
patch_tf_keras()
234
+
235
+
236
+ ### For gradient logging ###
237
+
238
+
239
def _get_custom_optimizer_parent_class():
    """Return the Keras optimizer base class that `_CustomOptimizer` extends.

    TensorFlow >= 2.9.0 uses `tf.keras.optimizers.legacy.Optimizer`; older
    versions use `tf.keras.optimizers.Optimizer`.
    NOTE(review): presumably the legacy base is chosen because it still honors
    the `_resource_apply_*` hooks that `_CustomOptimizer` overrides.
    """
    from wandb.util import parse_version

    tf_is_2_9_or_newer = parse_version(tf.__version__) >= parse_version("2.9.0")
    if not tf_is_2_9_or_newer:
        return tf.keras.optimizers.Optimizer
    return tf.keras.optimizers.legacy.Optimizer


# Resolved once at import time so it can serve as a base class.
_custom_optimizer_parent_class = _get_custom_optimizer_parent_class()
251
+
252
+
253
class _CustomOptimizer(_custom_optimizer_parent_class):
    """Pseudo-optimizer used for gradient logging.

    Instead of applying an update rule, `_resource_apply_dense` overwrites
    each variable with its gradient. After a training step the trainable
    weights therefore hold the latest dense gradients.
    `_GradAccumulatorCallback` reads those values and restores the original
    weights after every batch.
    """

    def __init__(self):
        super().__init__(name="CustomOptimizer")
        # Wrap both apply hooks in tf.function.
        self._resource_apply_dense = tf.function(self._resource_apply_dense)
        self._resource_apply_sparse = tf.function(self._resource_apply_sparse)

    def _resource_apply_dense(self, grad, var):
        """Replace the variable's value with its gradient (no learning-rate step)."""
        var.assign(grad)

    # this needs to be implemented to prevent a NotImplementedError when
    # using Lookup layers.
    # NOTE(review): sparse gradients are dropped (no-op), so they are not
    # captured by the gradient accumulator.
    def _resource_apply_sparse(self, grad, var, indices):
        pass

    def get_config(self):
        """Return the base optimizer's config unchanged."""
        return super().get_config()
269
+
270
+
271
class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
    """Sum gradients over a fit() call; meant to be paired with `_CustomOptimizer`.

    With `_CustomOptimizer`, every training step overwrites the trainable
    weights with their gradients. After each batch this callback adds those
    values into float64 running sums, then restores the weights that were
    captured when the model was attached.
    """

    def set_model(self, model):
        super().set_model(model)
        # Snapshot of all weights, restored at the end of every batch.
        self.og_weights = model.get_weights()
        # One zero-filled accumulator per trainable weight.
        self.grads = [
            np.zeros(tuple(weight.shape)) for weight in model.trainable_weights
        ]

    def on_batch_end(self, batch, logs=None):
        trainable = self.model.trainable_weights
        for accumulator, weight in zip(self.grads, trainable):
            # In-place add keeps the same ndarray objects in self.grads.
            accumulator += weight.numpy()
        self.model.set_weights(self.og_weights)

    def get_grads(self):
        """Return copies of the accumulated gradient sums."""
        return [accumulator.copy() for accumulator in self.grads]
286
+
287
+
288
+ ###
289
+
290
+
291
+ class WandbCallback(tf.keras.callbacks.Callback):
292
+ """`WandbCallback` automatically integrates keras with wandb.
293
+
294
+ Example:
295
+ ```python
296
+ model.fit(
297
+ X_train,
298
+ y_train,
299
+ validation_data=(X_test, y_test),
300
+ callbacks=[WandbCallback()],
301
+ )
302
+ ```
303
+
304
+ `WandbCallback` will automatically log history data from any
305
+ metrics collected by keras: loss and anything passed into `keras_model.compile()`.
306
+
307
+ `WandbCallback` will set summary metrics for the run associated with the "best" training
308
+ step, where "best" is defined by the `monitor` and `mode` attributes. This defaults
309
+ to the epoch with the minimum `val_loss`. `WandbCallback` will by default save the model
310
+ associated with the best `epoch`.
311
+
312
+ `WandbCallback` can optionally log gradient and parameter histograms.
313
+
314
+ `WandbCallback` can optionally save training and validation data for wandb to visualize.
315
+
316
+ Arguments:
317
+ monitor: (str) name of metric to monitor. Defaults to `val_loss`.
318
+ mode: (str) one of {`auto`, `min`, `max`}.
319
+ `min` - save model when monitor is minimized
320
+ `max` - save model when monitor is maximized
321
+ `auto` - try to guess when to save the model (default).
322
+ save_model:
323
+ True - save a model when monitor beats all previous epochs
324
+ False - don't save models
325
+ save_graph: (boolean) if True save model graph to wandb (default to True).
326
+ save_weights_only: (boolean) if True, then only the model's weights will be
327
+ saved (`model.save_weights(filepath)`), else the full model
328
+ is saved (`model.save(filepath)`).
329
+ log_weights: (boolean) if True save histograms of the model's layer's weights.
330
+ log_gradients: (boolean) if True log histograms of the training gradients
331
+ training_data: (tuple) Same format `(X,y)` as passed to `model.fit`. This is needed
332
+ for calculating gradients - this is mandatory if `log_gradients` is `True`.
333
+ validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`. A set of data
334
+ for wandb to visualize. If this is set, every epoch, wandb will
335
+ make a small number of predictions and save the results for later visualization. In case
336
+ you are working with image data, please also set `input_type` and `output_type` in order
337
+ to log correctly.
338
+ generator: (generator) a generator that returns validation data for wandb to visualize. This
339
+ generator should return tuples `(X,y)`. Either `validate_data` or generator should
340
+ be set for wandb to visualize specific data examples. In case you are working with image data,
341
+ please also set `input_type` and `output_type` in order to log correctly.
342
+ validation_steps: (int) if `validation_data` is a generator, how many
343
+ steps to run the generator for the full validation set.
344
+ labels: (list) If you are visualizing your data with wandb this list of labels
345
+ will convert numeric output to understandable string if you are building a
346
+ multiclass classifier. If you are making a binary classifier you can pass in
347
+ a list of two labels ["label for false", "label for true"]. If `validate_data`
348
+ and generator are both false, this won't do anything.
349
+ predictions: (int) the number of predictions to make for visualization each epoch, max
350
+ is 100.
351
+ input_type: (string) type of the model input to help visualization. can be one of:
352
+ (`image`, `images`, `segmentation_mask`, `auto`).
353
+ output_type: (string) type of the model output to help visualization. can be one of:
354
+ (`image`, `images`, `segmentation_mask`, `label`).
355
+ log_evaluation: (boolean) if True, save a Table containing validation data and the
356
+ model's predictions at each epoch. See `validation_indexes`,
357
+ `validation_row_processor`, and `output_row_processor` for additional details.
358
+ class_colors: ([float, float, float]) if the input or output is a segmentation mask,
359
+ an array containing an rgb tuple (range 0-1) for each class.
360
+ log_batch_frequency: (integer) if None, callback will log every epoch.
361
+ If set to integer, callback will log training metrics every `log_batch_frequency`
362
+ batches.
363
+ log_best_prefix: (string) if None, no extra summary metrics will be saved.
364
+ If set to a string, the monitored metric and epoch will be prepended with this value
365
+ and stored as summary metrics.
366
+ validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate
367
+ with each validation example. If log_evaluation is True and `validation_indexes` is provided,
368
+ then a Table of validation data will not be created and instead each prediction will
369
+ be associated with the row represented by the `TableLinkMixin`. The most common way to obtain
370
+ such keys are is use `Table.get_index()` which will return a list of row keys.
371
+ validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data.
372
+ The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input,
373
+ then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the
374
+ input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else,
375
+ it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray,
376
+ but you wish to visualize the data as an Image, then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}`
377
+ as the processor. Ignored if log_evaluation is False or `validation_indexes` are present.
378
+ output_row_processor: (Callable) same as `validation_row_processor`, but applied to the model's output. `row["output"]` will contain
379
+ the results of the model output.
380
+ infer_missing_processors: (bool) Determines if `validation_row_processor` and `output_row_processor`
381
+ should be inferred if missing. Defaults to True. If `labels` are provided, we will attempt to infer classification-type
382
+ processors where appropriate.
383
+ log_evaluation_frequency: (int) Determines the frequency which evaluation results will be logged. Default 0 (only at the end of training).
384
+ Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False.
385
+ compute_flops: (bool) Compute the FLOPs of your Keras Sequential or Functional model in GigaFLOPs unit.
386
+ """
387
+
388
def __init__(
    self,
    monitor="val_loss",
    verbose=0,
    mode="auto",
    save_weights_only=False,
    log_weights=False,
    log_gradients=False,
    save_model=True,
    training_data=None,
    validation_data=None,
    labels=None,
    predictions=36,
    generator=None,
    input_type=None,
    output_type=None,
    log_evaluation=False,
    validation_steps=None,
    class_colors=None,
    log_batch_frequency=None,
    log_best_prefix="best_",
    save_graph=True,
    validation_indexes=None,
    validation_row_processor=None,
    prediction_row_processor=None,
    infer_missing_processors=True,
    log_evaluation_frequency=0,
    compute_flops=False,
    **kwargs,
):
    """Validate arguments and set up the callback's state for the active run.

    Emits deprecation notices (for the callback itself, for ``save_model``
    when it is enabled, and for the legacy ``data_type`` keyword), registers
    ``model-best.h5`` via ``wandb.save``, and seeds ``self.best`` from the run
    summary so that resumed runs keep their previous best value.

    Raises:
        wandb.Error: if ``wandb.init()`` has not been called yet.
        Exception: if ``log_gradients`` is requested on TensorFlow < 2.
        ValueError: if ``log_gradients`` is requested without usable
            ``training_data``.
    """
    if wandb.run is None:
        raise wandb.Error("You must call wandb.init() before WandbCallback()")

    deprecate(
        field_name=Deprecated.keras_callback,
        warning_message=(
            "WandbCallback is deprecated and will be removed in a future release. "
            "Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
            "callbacks instead. "
            "See https://docs.wandb.ai/guides/integrations/keras for more information."
        ),
    )

    # Flag use of the Keras integration in the run's telemetry.
    with wandb.wandb_lib.telemetry.context(run=wandb.run) as tel:
        tel.feature.keras = True

    # Legacy behaviour: a generator-like `validation_data` is treated as the
    # `generator` argument; anything else is kept as in-memory data.
    self.validation_data = None
    if validation_data is not None and is_generator_like(validation_data):
        generator = validation_data
    elif validation_data is not None:
        self.validation_data = validation_data

    self.labels = labels if labels is not None else []
    # Cap the number of example predictions at 100.
    self.predictions = min(predictions, 100)

    self.monitor = monitor
    self.verbose = verbose
    self.save_weights_only = save_weights_only
    self.save_graph = save_graph

    # The HDF5 checkpoint is written into the run directory.
    wandb.save("model-best.h5")
    self.filepath = os.path.join(wandb.run.dir, "model-best.h5")
    self.save_model = save_model
    if save_model:
        deprecate(
            field_name=Deprecated.keras_callback__save_model,
            warning_message=(
                "The save_model argument by default saves the model in the HDF5 format that cannot save "
                "custom objects like subclassed models and custom layers. This behavior will be deprecated "
                "in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
                "as W&B files and the SavedModel as W&B Artifacts."
            ),
        )

    self.save_model_as_artifact = True
    self.log_weights = log_weights
    self.log_gradients = log_gradients
    self.training_data = training_data
    self.generator = generator
    self._graph_rendered = False

    # `data_type` is the deprecated spelling of `input_type`.
    legacy_data_type = kwargs.get("data_type")
    if legacy_data_type is not None:
        deprecate(
            field_name=Deprecated.keras_callback__data_type,
            warning_message=(
                "The data_type argument of wandb.keras.WandbCallback is deprecated "
                "and will be removed in a future release. Please use input_type instead.\n"
                "Setting input_type = data_type."
            ),
        )
        input_type = legacy_data_type

    self.input_type = input_type
    self.output_type = output_type
    self.log_evaluation = log_evaluation
    self.validation_steps = validation_steps
    self.class_colors = None if class_colors is None else np.array(class_colors)
    self.log_batch_frequency = log_batch_frequency
    self.log_best_prefix = log_best_prefix
    self.compute_flops = compute_flops

    self._prediction_batch_size = None

    if self.log_gradients:
        tf_major_version = int(tf.__version__.split(".")[0])
        if tf_major_version < 2:
            raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
        if self.training_data is None:
            raise ValueError(
                "training_data argument is required for gradient logging."
            )
        if not isinstance(self.training_data, (list, tuple)):
            # Generators, tf.data.Dataset, etc. are handed to fit() as-is,
            # with no separate targets.
            self._training_data_x = self.training_data
            self._training_data_y = None
        elif len(self.training_data) == 2:
            self._training_data_x, self._training_data_y = self.training_data
        else:
            raise ValueError("training data must be a tuple of length two")

    # Mode handling adapted from Keras.
    if mode not in ("auto", "min", "max"):
        print(f"WandbCallback mode {mode} is unknown, fallback to auto mode.")
        mode = "auto"

    if mode == "auto":
        # Accuracy / F-measure style metrics improve upwards; others downwards.
        higher_is_better = "acc" in self.monitor or self.monitor.startswith(
            "fmeasure"
        )
    else:
        higher_is_better = mode == "max"

    if higher_is_better:
        self.monitor_op, self.best = operator.gt, float("-inf")
    else:
        self.monitor_op, self.best = operator.lt, float("inf")

    # Resumed runs: start from the best value already stored in the summary.
    previous_best = wandb.run.summary.get(f"{self.log_best_prefix}{self.monitor}")
    if previous_best is not None:
        self.best = previous_best

    self._validation_data_logger = None
    self._validation_indexes = validation_indexes
    self._validation_row_processor = validation_row_processor
    self._prediction_row_processor = prediction_row_processor
    self._infer_missing_processors = infer_missing_processors
    self._log_evaluation_frequency = log_evaluation_frequency
    self._model_trained_since_last_eval = False
540
+
541
def _build_grad_accumulator_model(self):
    """Create the helper model and callback used to accumulate gradients.

    The helper wraps ``self.model`` (called on its own inputs) and is compiled
    with the model's loss and a ``_CustomOptimizer``. It is stored on
    ``self._grad_accumulator_model``, with a fresh
    ``_GradAccumulatorCallback`` on ``self._grad_accumulator_callback``.
    """
    model = self.model
    symbolic_inputs = model.inputs
    accumulator = tf.keras.models.Model(symbolic_inputs, model(symbolic_inputs))
    accumulator.compile(loss=model.loss, optimizer=_CustomOptimizer())

    # Mark the helper so wandb's magic integration does not treat it as a user model.
    accumulator._wandb_internal_model = True

    self._grad_accumulator_model = accumulator
    self._grad_accumulator_callback = _GradAccumulatorCallback()
552
+
553
def _implements_train_batch_hooks(self):
    """Return True when per-batch logging is configured (``log_batch_frequency`` set)."""
    frequency = self.log_batch_frequency
    return frequency is not None
555
+
556
def _implements_test_batch_hooks(self):
    """Return True when per-batch logging is configured (``log_batch_frequency`` set)."""
    frequency = self.log_batch_frequency
    return frequency is not None
558
+
559
def _implements_predict_batch_hooks(self):
    """Return True when per-batch logging is configured (``log_batch_frequency`` set)."""
    frequency = self.log_batch_frequency
    return frequency is not None
561
+
562
def set_params(self, params):
    """Store the training parameters handed over by Keras on ``self.params``."""
    setattr(self, "params", params)
564
+
565
def set_model(self, model):
    """Attach ``model`` and fill in data types that can be guessed from its shapes.

    ``input_type == "auto"`` with a single model input is replaced by a
    (risky) guess from that input's shape. If an input type is set, no output
    type was given, and the model has one output, the output type is guessed
    from the output's shape. The gradient-accumulator model is built when
    gradient logging is enabled.
    """
    super().set_model(model)

    if self.input_type == "auto" and len(model.inputs) == 1:
        first_input_shape = model.inputs[0].shape
        self.input_type = wandb.util.guess_data_type(first_input_shape, risky=True)

    should_guess_output = (
        self.input_type and self.output_type is None and len(model.outputs) == 1
    )
    if should_guess_output:
        self.output_type = wandb.util.guess_data_type(model.outputs[0].shape)

    if self.log_gradients:
        self._build_grad_accumulator_model()
575
+
576
def _attempt_evaluation_log(self, commit=True):
    """Log validation predictions through the ValidationDataLogger, best-effort.

    Does nothing unless ``log_evaluation`` is on and a validation data logger
    exists. On success, resets ``_model_trained_since_last_eval``. Any error
    is reported with ``wandb.termwarn`` instead of being raised.

    Args:
        commit: forwarded to ``log_predictions``.
    """
    if not (self.log_evaluation and self._validation_data_logger):
        return
    try:
        if not self.model:
            wandb.termwarn("WandbCallback unable to read model from trainer")
            return
        data_logger = self._validation_data_logger
        data_logger.log_predictions(
            predictions=data_logger.make_predictions(self.model.predict),
            commit=commit,
        )
        self._model_trained_since_last_eval = False
    except Exception as e:
        wandb.termwarn("Error during prediction logging for epoch: " + str(e))
591
+
592
def on_epoch_end(self, epoch, logs=None):
    """Log the epoch's metrics and extras, and track the best monitored value.

    Optionally logs weights, gradients, example images and evaluation
    predictions (all uncommitted), then commits ``epoch`` together with
    ``logs``. When ``logs[self.monitor]`` improves on ``self.best`` according
    to ``self.monitor_op``:

    * the best value and epoch go into the run summary (if
      ``log_best_prefix`` is set),
    * the model is saved (HDF5 and, if enabled, as an artifact) when
      ``save_model`` is on,
    * ``self.best`` is updated.

    Args:
        epoch: epoch index reported by Keras.
        logs: metrics dict for the epoch; treated as empty when None.
    """
    if logs is None:
        logs = {}
    if self.log_weights:
        wandb.log(self._log_weights(), commit=False)

    if self.log_gradients:
        wandb.log(self._log_gradients(), commit=False)

    image_types = ("image", "images", "segmentation_mask")
    if self.input_type in image_types or self.output_type in image_types:
        if self.generator:
            self.validation_data = next(self.generator)
        if self.validation_data is None:
            wandb.termwarn(
                "No validation_data set, pass a generator to the callback."
            )
        elif self.validation_data and len(self.validation_data) > 0:
            wandb.log(
                {"examples": self._log_images(num_images=self.predictions)},
                commit=False,
            )

    if (
        self._log_evaluation_frequency > 0
        and epoch % self._log_evaluation_frequency == 0
    ):
        self._attempt_evaluation_log(commit=False)

    wandb.log({"epoch": epoch}, commit=False)
    wandb.log(logs, commit=True)

    self.current = logs.get(self.monitor)
    # Compare against None explicitly: a real metric value of 0 (e.g. a loss of
    # exactly 0.0) is falsy and would otherwise never count as an improvement.
    if self.current is not None and self.monitor_op(self.current, self.best):
        if self.log_best_prefix:
            wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
                self.current
            )
            wandb.run.summary[f"{self.log_best_prefix}epoch"] = epoch
            if self.verbose and not self.save_model:
                print(
                    "Epoch %05d: %s improved from %0.5f to %0.5f"
                    % (epoch, self.monitor, self.best, self.current)
                )
        if self.save_model:
            self._save_model(epoch)

        if self.save_model and self.save_model_as_artifact:
            self._save_model_as_artifact(epoch)

        self.best = self.current
646
+
647
+ # This is what keras used pre tensorflow.keras
648
def on_batch_begin(self, batch, logs=None):
    """No-op batch-start hook from the pre-``tf.keras`` callback API."""
    return None
650
+
651
+ # This is what keras used pre tensorflow.keras
652
def on_batch_end(self, batch, logs=None):
    """Batch-end hook from the pre-``tf.keras`` callback API.

    Stores the model graph in the run summary once (when ``save_graph`` is
    on), then commits ``logs`` every ``log_batch_frequency`` batches.
    """
    needs_graph = self.save_graph and not self._graph_rendered
    if needs_graph:
        # Not done in train_begin because the Keras model may not be built yet.
        wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
        self._graph_rendered = True

    frequency = self.log_batch_frequency
    if frequency and batch % frequency == 0:
        wandb.log(logs, commit=True)
660
+
661
def on_train_batch_begin(self, batch, logs=None):
    """Record that training has run since the last evaluation log."""
    setattr(self, "_model_trained_since_last_eval", True)
663
+
664
def on_train_batch_end(self, batch, logs=None):
    """Store the model graph once, then commit ``logs`` every ``log_batch_frequency`` batches."""
    needs_graph = self.save_graph and not self._graph_rendered
    if needs_graph:
        # Not done in train_begin because the Keras model may not be built yet.
        wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
        self._graph_rendered = True

    frequency = self.log_batch_frequency
    if frequency and batch % frequency == 0:
        wandb.log(logs, commit=True)
672
+
673
def on_test_begin(self, logs=None):
    """Evaluation-start hook; intentionally does nothing."""
675
+
676
def on_test_end(self, logs=None):
    """Evaluation-end hook; intentionally does nothing."""
678
+
679
def on_test_batch_begin(self, batch, logs=None):
    """Evaluation batch-start hook; intentionally does nothing."""
681
+
682
def on_test_batch_end(self, batch, logs=None):
    """Evaluation batch-end hook; intentionally does nothing."""
684
+
685
def on_train_begin(self, logs=None):
    """Prepare validation-prediction logging and optionally record model GFLOPs.

    With ``log_evaluation`` on, validation inputs/targets come from
    ``self.validation_data`` or, failing that, from ``validation_steps``
    batches of ``self.generator``. They are wrapped in a
    ``ValidationDataLogger``. Any error during this setup becomes a warning
    and evaluation logging is skipped. When ``compute_flops`` is on and
    supported, ``get_flops()`` is stored in the summary as ``GFLOPs``.
    """

    def gather_from_generator():
        # Concatenate `validation_steps` (inputs, targets) batches along axis 0.
        inputs = targets = None
        for _ in range(self.validation_steps):
            batch_inputs, batch_targets = next(self.generator)
            if inputs is None:
                inputs, targets = batch_inputs, batch_targets
            else:
                inputs, targets = (
                    np.append(inputs, batch_inputs, axis=0),
                    np.append(targets, batch_targets, axis=0),
                )
        return (inputs, targets)

    if self.log_evaluation:
        try:
            validation_data = None
            if self.validation_data:
                validation_data = self.validation_data
            elif self.generator:
                if self.validation_steps:
                    validation_data = gather_from_generator()
                else:
                    wandb.termwarn(
                        "WandbCallback is unable to log validation data. "
                        "When using a generator for validation_data, you must pass validation_steps"
                    )
            else:
                wandb.termwarn(
                    "WandbCallback is unable to read validation_data from trainer "
                    "and therefore cannot log validation data. Ensure Keras is properly "
                    "patched by calling `from wandb.keras import WandbCallback` at the top of your script."
                )

            if validation_data:
                self._validation_data_logger = ValidationDataLogger(
                    inputs=validation_data[0],
                    targets=validation_data[1],
                    indexes=self._validation_indexes,
                    validation_row_processor=self._validation_row_processor,
                    prediction_row_processor=self._prediction_row_processor,
                    class_labels=self.labels,
                    infer_missing_processors=self._infer_missing_processors,
                )
        except Exception as e:
            wandb.termwarn(
                "Error initializing ValidationDataLogger in WandbCallback. "
                f"Skipping logging validation data. Error: {str(e)}"
            )

    if self.compute_flops and _can_compute_flops():
        try:
            wandb.summary["GFLOPs"] = self.get_flops()
        except Exception as e:
            wandb.termwarn("Unable to compute FLOPs for this model.")
            logger.exception(e)
738
+
739
def on_train_end(self, logs=None):
    """Run a final evaluation log if the model trained since the previous one."""
    if not self._model_trained_since_last_eval:
        return
    self._attempt_evaluation_log()
742
+
743
def on_predict_begin(self, logs=None):
    """Prediction-start hook; intentionally does nothing."""
745
+
746
def on_predict_end(self, logs=None):
    """Prediction-end hook; intentionally does nothing."""
748
+
749
def on_predict_batch_begin(self, batch, logs=None):
    """Prediction batch-start hook; intentionally does nothing."""
751
+
752
def on_predict_batch_end(self, batch, logs=None):
    """Prediction batch-end hook; intentionally does nothing."""
754
+
755
def _logits_to_captions(self, logits):
    """Turn per-example model outputs into captions.

    Args:
        logits: sequence of per-example output arrays; the size of the last
            axis of the first element decides between the scalar and the
            vector case.

    Returns:
        * Scalar output with two ``self.labels``: each example's own score is
          thresholded at 0.5 to pick ``labels[1]`` (above) or ``labels[0]``.
        * Scalar output otherwise: the raw scores (a warning is emitted when
          ``self.labels`` is non-empty but not of length two).
        * Vector output: the argmax class names from ``self.labels``, falling
          back to the raw class index when it has no name; the raw argmax
          array when ``self.labels`` is empty.
    """
    if logits[0].shape[-1] == 1:
        # Scalar output from the model.
        # TODO: handle validation_y
        if len(self.labels) == 2:
            # User has named the false/true classes. Threshold each example's
            # own score (previously every caption used the first example's).
            return [
                self.labels[1] if logit[0] > 0.5 else self.labels[0]
                for logit in logits
            ]
        if len(self.labels) != 0:
            wandb.termwarn(
                "keras model is producing a single output, "
                'so labels should be a length two array: ["False label", "True label"].'
            )
        return [logit[0] for logit in logits]

    # Vector output from the model: caption with the argmax class.
    # TODO: handle validation_y
    class_ids = np.argmax(np.stack(logits), axis=1)
    if len(self.labels) == 0:
        return class_ids

    captions = []
    for class_id in class_ids:
        try:
            captions.append(self.labels[class_id])
        except IndexError:
            # Predicted class has no name in self.labels; keep the raw index.
            captions.append(class_id)
    return captions
788
+
789
def _masks_to_pixels(self, masks):
    """Colourise multi-class masks by taking the argmax over the last axis.

    Binary masks (2-D per example, or a single trailing channel) are returned
    unchanged so they render as grayscale. Otherwise each pixel's argmax class
    indexes into ``self.class_colors``, or into a ``wandb.util.class_colors``
    palette sized from ``masks[0].shape[2]`` when no colours were configured.
    """
    first_shape = masks[0].shape
    if len(first_shape) == 2 or first_shape[-1] == 1:
        return masks

    if self.class_colors is None:
        palette = np.array(wandb.util.class_colors(first_shape[2]))
    else:
        palette = self.class_colors
    return palette[np.argmax(masks, axis=-1)]
800
+
801
def _log_images(self, num_images=36):
    """Build ``wandb.Image`` lists comparing inputs, predictions and targets.

    Samples up to ``num_images`` examples (without replacement) from
    ``self.validation_data``, predicts on them, and returns images depending
    on ``input_type`` / ``output_type``:

    * label in, image out: interleaved (prediction, reference) pairs,
      captioned from the input labels.
    * image in, label out: input images captioned with the predicted label.
    * image in, image out: interleaved (input, prediction, reference) triples.
    * image in, other out: input images only.
    * other in, image out: interleaved (prediction, reference) pairs.

    Segmentation-mask data (inputs and outputs alike) is colourised with
    ``_masks_to_pixels`` before logging. Any other combination returns None.
    """
    image_types = ("image", "images", "segmentation_mask")

    def as_pixels(data, data_type):
        # Colourise segmentation masks; other image data is logged unchanged.
        if data_type == "segmentation_mask":
            return self._masks_to_pixels(data)
        return data

    validation_x = self.validation_data[0]
    validation_y = self.validation_data[1]
    validation_length = len(validation_x)

    if validation_length > num_images:
        # Pick a random subset of the validation examples.
        indices = np.random.choice(validation_length, num_images, replace=False)
    else:
        indices = range(validation_length)

    test_data = [validation_x[i] for i in indices]
    test_output = [validation_y[i] for i in indices]
    stacked_inputs = np.stack(test_data)

    if self.model.stateful:
        predictions = self.model.predict(stacked_inputs, batch_size=1)
        self.model.reset_states()
    else:
        predictions = self.model.predict(
            stacked_inputs, batch_size=self._prediction_batch_size
        )
        if len(predictions) != len(test_data):
            # Prediction count did not match the number of examples: retry one
            # example per batch and keep that batch size for later calls.
            self._prediction_batch_size = 1
            predictions = self.model.predict(
                stacked_inputs, batch_size=self._prediction_batch_size
            )

    if self.input_type == "label":
        if self.output_type in image_types:
            captions = self._logits_to_captions(test_data)
            output_image_data = as_pixels(predictions, self.output_type)
            reference_image_data = as_pixels(test_output, self.output_type)
            output_images = [
                wandb.Image(data, caption=captions[i], grouping=2)
                for i, data in enumerate(output_image_data)
            ]
            reference_images = [
                wandb.Image(data, caption=captions[i])
                for i, data in enumerate(reference_image_data)
            ]
            return list(chain.from_iterable(zip(output_images, reference_images)))
    elif self.input_type in image_types:
        input_image_data = as_pixels(test_data, self.input_type)
        if self.output_type == "label":
            # The predicted label is used as the caption for now.
            captions = self._logits_to_captions(predictions)
            return [
                wandb.Image(data, caption=captions[i])
                for i, data in enumerate(input_image_data)
            ]
        if self.output_type in image_types:
            output_image_data = as_pixels(predictions, self.output_type)
            reference_image_data = as_pixels(test_output, self.output_type)
            input_images = [wandb.Image(data, grouping=3) for data in input_image_data]
            output_images = [wandb.Image(data) for data in output_image_data]
            reference_images = [wandb.Image(data) for data in reference_image_data]
            return list(
                chain.from_iterable(
                    zip(input_images, output_images, reference_images)
                )
            )
        # Unknown output type: log only the input images.
        return [wandb.Image(img) for img in input_image_data]
    elif self.output_type in image_types:
        # Unknown input type: log predicted and reference outputs without captions.
        output_image_data = as_pixels(predictions, self.output_type)
        reference_image_data = as_pixels(test_output, self.output_type)
        output_images = [wandb.Image(data, grouping=2) for data in output_image_data]
        reference_images = [wandb.Image(data) for data in reference_image_data]
        return list(chain.from_iterable(zip(output_images, reference_images)))

    return None
917
+
918
def _log_weights(self):
    """Collect per-layer parameter values under ``parameters/<layer>.*`` keys.

    Only layers with exactly one or two weight arrays are considered: the
    first array goes under ``.weights`` and a second one, when present, under
    ``.bias``. Entries are added through ``_update_if_numeric``; other layers
    are skipped.
    """
    metrics = {}
    for layer in self.model.layers:
        layer_weights = layer.get_weights()
        if len(layer_weights) not in (1, 2):
            continue
        key_prefix = "parameters/" + layer.name
        _update_if_numeric(metrics, key_prefix + ".weights", layer_weights[0])
        if len(layer_weights) == 2:
            _update_if_numeric(metrics, key_prefix + ".bias", layer_weights[1])
    return metrics
934
+
935
def _log_gradients(self):
    """Build gradient histograms by fitting the gradient-accumulator model.

    Fits ``self._grad_accumulator_model`` on the training data with
    ``self._grad_accumulator_callback`` attached. The callback's ``grads`` are
    then paired with ``self.model.trainable_weights`` to produce
    ``gradients/<weight name>.gradient`` -> ``wandb.Histogram`` entries.
    """
    # Raise tf_logger to ERROR to suppress callback warnings from the helper
    # fit. Restore the previous level even if fitting raises, so a failure does
    # not leave TensorFlow logging silenced for the rest of the process.
    og_level = tf_logger.level
    tf_logger.setLevel("ERROR")
    try:
        self._grad_accumulator_model.fit(
            self._training_data_x,
            self._training_data_y,
            verbose=0,
            callbacks=[self._grad_accumulator_callback],
        )
    finally:
        tf_logger.setLevel(og_level)

    weights = self.model.trainable_weights
    grads = self._grad_accumulator_callback.grads
    return {
        "gradients/" + weight.name.split(":")[0] + ".gradient": wandb.Histogram(grad)
        for weight, grad in zip(weights, grads)
    }
955
+
956
def _log_dataframe(self):
    """Build a W&B dataframe of validation inputs, targets and predictions.

    Data comes from ``self.validation_data`` (predicted in one call) or from
    ``validation_steps`` batches drawn from ``self.generator``. Image inputs
    with ``output_type == "label"`` produce an image-categorizer dataframe;
    image inputs with ``output_type == "segmentation_mask"`` produce an
    image-segmentation dataframe.

    Returns:
        The dataframe, or None after a warning when no validation data is
        available or the input/output type combination is unsupported.
    """
    x, y_true, y_pred = None, None, None

    if self.validation_data:
        x, y_true = self.validation_data[0], self.validation_data[1]
        y_pred = self.model.predict(x)
    elif self.generator:
        if not self.validation_steps:
            wandb.termwarn(
                "when using a generator for validation data with dataframes, "
                "you must pass validation_steps. skipping"
            )
            return None

        for _ in range(self.validation_steps):
            bx, by_true = next(self.generator)
            by_pred = self.model.predict(bx)
            if x is None:
                x, y_true, y_pred = bx, by_true, by_pred
            else:
                x, y_true, y_pred = (
                    np.append(x, bx, axis=0),
                    np.append(y_true, by_true, axis=0),
                    np.append(y_pred, by_pred, axis=0),
                )

    if x is None:
        # Nothing was collected (no validation data, no generator, or no
        # batches drawn); don't hand None to the dataframe builders.
        wandb.termwarn("no validation data available for dataframes. skipping")
        return None

    if self.input_type in ("image", "images") and self.output_type == "label":
        return wandb.image_categorizer_dataframe(
            x=x, y_true=y_true, y_pred=y_pred, labels=self.labels
        )
    elif (
        self.input_type in ("image", "images")
        and self.output_type == "segmentation_mask"
    ):
        return wandb.image_segmentation_dataframe(
            x=x,
            y_true=y_true,
            y_pred=y_pred,
            labels=self.labels,
            class_colors=self.class_colors,
        )
    else:
        wandb.termwarn(
            f"unknown dataframe type for input_type={self.input_type} and output_type={self.output_type}"
        )
        return None
1002
+
1003
def _save_model(self, epoch):
    """Save the model (or only its weights) to the HDF5 path ``self.filepath``.

    Skipped for disabled runs. Prints a progress line when ``verbose > 0``.
    Known save failures (ImportError, RuntimeError, TypeError, AttributeError)
    are reported with ``wandb.termerror`` and logged rather than raised.

    Args:
        epoch: epoch index, used only in the progress message.
    """
    if wandb.run.disabled:
        return
    if self.verbose > 0:
        print(
            "Epoch %05d: %s improved from %0.5f to %0.5f,"
            " saving model to %s"
            % (epoch, self.monitor, self.best, self.current, self.filepath)
        )

    try:
        if self.save_weights_only:
            self.model.save_weights(self.filepath, overwrite=True)
        else:
            self.model.save(self.filepath, overwrite=True)
    # Was getting `RuntimeError: Unable to create link` in TF 1.13.1
    # also saw `TypeError: can't pickle _thread.RLock objects`
    except (ImportError, RuntimeError, TypeError, AttributeError) as e:
        wandb.termerror(
            "Can't save model in the h5py format. The model will be saved as "
            "a W&B Artifact in the 'tf' format."
        )
        logger.exception(e)
1026
+
1027
def _save_model_as_artifact(self, epoch):
    """Export the model in the "tf" (SavedModel) format and log it as a ``model`` artifact.

    The export directory is ``self.filepath`` with its last three characters
    (the ``.h5`` suffix) removed. It is added to an artifact named
    ``model-<run name>`` (made name-safe), logged with the aliases ``latest``
    and ``epoch_<epoch>``, and then deleted from disk. Skipped for disabled runs.
    """
    if wandb.run.disabled:
        return

    # TODO: Replace this manual artifact creation with the `log_model` method
    # after `log_model` is released from beta.
    export_dir = self.filepath[:-3]
    self.model.save(export_dir, overwrite=True, save_format="tf")

    artifact = wandb.Artifact(
        wandb.util.make_artifact_name_safe(f"model-{wandb.run.name}"),
        type="model",
    )
    artifact.add_dir(export_dir)
    wandb.run.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])

    # Drop the local SavedModel copy from the run directory once it has been
    # added to the artifact.
    shutil.rmtree(export_dir)
1044
+
1045
def get_flops(self) -> float:
    """Estimate the model's inference cost in GFLOPs for a single sample.

    The model is traced with ``tf.function`` against a batch-size-1 input
    signature, frozen into a constant graph, and profiled with
    ``tf.compat.v1.profiler``.

    Raises:
        wandb.Error: if ``self.model`` has not been set.
        ValueError: if the model is not a ``tf.keras`` ``Model``/``Sequential``.
    """
    if not hasattr(self, "model"):
        raise wandb.Error("self.model must be set before using this method.")

    supported_types = (tf.keras.models.Sequential, tf.keras.models.Model)
    if not isinstance(self.model, supported_types):
        raise ValueError(
            "Calculating FLOPS is only supported for "
            "`tf.keras.Model` and `tf.keras.Sequential` instances."
        )

    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2_as_graph,
    )

    # One-sample signature: replace each input's batch dimension with 1.
    single_sample_specs = [
        tf.TensorSpec([1] + inp.shape[1:], inp.dtype) for inp in self.model.inputs
    ]

    # Freeze the traced model so the profiler sees only inference-time ops.
    concrete_fn = tf.function(self.model).get_concrete_function(single_sample_specs)
    frozen_fn, _ = convert_variables_to_constants_v2_as_graph(concrete_fn)

    option_builder = tf.compat.v1.profiler.ProfileOptionBuilder
    profile_options = (
        option_builder(option_builder().float_operation())
        .with_empty_output()
        .build()
    )
    profile = tf.compat.v1.profiler.profile(
        graph=frozen_fn.graph,
        run_meta=tf.compat.v1.RunMetadata(),
        cmd="scope",
        options=profile_options,
    )

    # Convert to GFLOPs. NOTE(review): the extra division by 2 presumably
    # counts a multiply-add as a single operation; confirm before relying on it.
    return profile.total_float_ops / 1e9 / 2