wandb-0.21.2-py3-none-macosx_12_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (904)
  1. package_readme.md +97 -0
  2. wandb/__init__.py +248 -0
  3. wandb/__init__.pyi +1230 -0
  4. wandb/__main__.py +3 -0
  5. wandb/_iterutils.py +65 -0
  6. wandb/_pydantic/__init__.py +30 -0
  7. wandb/_pydantic/base.py +128 -0
  8. wandb/_pydantic/utils.py +80 -0
  9. wandb/_pydantic/v1_compat.py +284 -0
  10. wandb/agents/__init__.py +0 -0
  11. wandb/agents/pyagent.py +386 -0
  12. wandb/analytics/__init__.py +3 -0
  13. wandb/analytics/sentry.py +267 -0
  14. wandb/apis/__init__.py +48 -0
  15. wandb/apis/attrs.py +50 -0
  16. wandb/apis/importers/__init__.py +1 -0
  17. wandb/apis/importers/internals/internal.py +382 -0
  18. wandb/apis/importers/internals/protocols.py +103 -0
  19. wandb/apis/importers/internals/util.py +78 -0
  20. wandb/apis/importers/mlflow.py +254 -0
  21. wandb/apis/importers/validation.py +108 -0
  22. wandb/apis/importers/wandb.py +1608 -0
  23. wandb/apis/internal.py +239 -0
  24. wandb/apis/normalize.py +81 -0
  25. wandb/apis/paginator.py +138 -0
  26. wandb/apis/public/__init__.py +35 -0
  27. wandb/apis/public/api.py +2449 -0
  28. wandb/apis/public/artifacts.py +1046 -0
  29. wandb/apis/public/automations.py +85 -0
  30. wandb/apis/public/const.py +4 -0
  31. wandb/apis/public/files.py +402 -0
  32. wandb/apis/public/history.py +201 -0
  33. wandb/apis/public/integrations.py +203 -0
  34. wandb/apis/public/jobs.py +742 -0
  35. wandb/apis/public/projects.py +276 -0
  36. wandb/apis/public/query_generator.py +176 -0
  37. wandb/apis/public/registries/__init__.py +0 -0
  38. wandb/apis/public/registries/_freezable_list.py +179 -0
  39. wandb/apis/public/registries/_utils.py +138 -0
  40. wandb/apis/public/registries/registries_search.py +347 -0
  41. wandb/apis/public/registries/registry.py +358 -0
  42. wandb/apis/public/reports.py +595 -0
  43. wandb/apis/public/runs.py +1216 -0
  44. wandb/apis/public/sweeps.py +440 -0
  45. wandb/apis/public/teams.py +235 -0
  46. wandb/apis/public/users.py +177 -0
  47. wandb/apis/public/utils.py +210 -0
  48. wandb/apis/reports/__init__.py +1 -0
  49. wandb/apis/reports/v1/__init__.py +8 -0
  50. wandb/apis/reports/v2/__init__.py +8 -0
  51. wandb/apis/workspaces/__init__.py +8 -0
  52. wandb/automations/__init__.py +73 -0
  53. wandb/automations/_filters/__init__.py +40 -0
  54. wandb/automations/_filters/expressions.py +181 -0
  55. wandb/automations/_filters/operators.py +258 -0
  56. wandb/automations/_filters/run_metrics.py +330 -0
  57. wandb/automations/_generated/__init__.py +177 -0
  58. wandb/automations/_generated/create_automation.py +17 -0
  59. wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
  60. wandb/automations/_generated/delete_automation.py +15 -0
  61. wandb/automations/_generated/enums.py +35 -0
  62. wandb/automations/_generated/fragments.py +358 -0
  63. wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
  64. wandb/automations/_generated/get_automations.py +24 -0
  65. wandb/automations/_generated/get_automations_by_entity.py +26 -0
  66. wandb/automations/_generated/input_types.py +104 -0
  67. wandb/automations/_generated/integrations_by_entity.py +22 -0
  68. wandb/automations/_generated/operations.py +647 -0
  69. wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
  70. wandb/automations/_generated/update_automation.py +17 -0
  71. wandb/automations/_utils.py +235 -0
  72. wandb/automations/_validators.py +165 -0
  73. wandb/automations/actions.py +218 -0
  74. wandb/automations/automations.py +85 -0
  75. wandb/automations/events.py +285 -0
  76. wandb/automations/integrations.py +45 -0
  77. wandb/automations/scopes.py +78 -0
  78. wandb/beta/workflows.py +324 -0
  79. wandb/bin/gpu_stats +0 -0
  80. wandb/bin/wandb-core +0 -0
  81. wandb/cli/__init__.py +0 -0
  82. wandb/cli/beta.py +175 -0
  83. wandb/cli/cli.py +2883 -0
  84. wandb/data_types.py +66 -0
  85. wandb/docker/__init__.py +290 -0
  86. wandb/docker/names.py +40 -0
  87. wandb/docker/wandb-entrypoint.sh +33 -0
  88. wandb/env.py +535 -0
  89. wandb/errors/__init__.py +17 -0
  90. wandb/errors/errors.py +40 -0
  91. wandb/errors/links.py +73 -0
  92. wandb/errors/term.py +415 -0
  93. wandb/errors/util.py +57 -0
  94. wandb/errors/warnings.py +2 -0
  95. wandb/filesync/__init__.py +0 -0
  96. wandb/filesync/dir_watcher.py +404 -0
  97. wandb/filesync/stats.py +100 -0
  98. wandb/filesync/step_checksum.py +142 -0
  99. wandb/filesync/step_prepare.py +179 -0
  100. wandb/filesync/step_upload.py +287 -0
  101. wandb/filesync/upload_job.py +142 -0
  102. wandb/integration/__init__.py +0 -0
  103. wandb/integration/catboost/__init__.py +5 -0
  104. wandb/integration/catboost/catboost.py +182 -0
  105. wandb/integration/cohere/__init__.py +3 -0
  106. wandb/integration/cohere/cohere.py +21 -0
  107. wandb/integration/cohere/resolver.py +347 -0
  108. wandb/integration/diffusers/__init__.py +3 -0
  109. wandb/integration/diffusers/autologger.py +76 -0
  110. wandb/integration/diffusers/pipeline_resolver.py +50 -0
  111. wandb/integration/diffusers/resolvers/__init__.py +9 -0
  112. wandb/integration/diffusers/resolvers/multimodal.py +881 -0
  113. wandb/integration/diffusers/resolvers/utils.py +102 -0
  114. wandb/integration/fastai/__init__.py +243 -0
  115. wandb/integration/gym/__init__.py +98 -0
  116. wandb/integration/huggingface/__init__.py +3 -0
  117. wandb/integration/huggingface/huggingface.py +18 -0
  118. wandb/integration/huggingface/resolver.py +213 -0
  119. wandb/integration/keras/__init__.py +11 -0
  120. wandb/integration/keras/callbacks/__init__.py +5 -0
  121. wandb/integration/keras/callbacks/metrics_logger.py +129 -0
  122. wandb/integration/keras/callbacks/model_checkpoint.py +188 -0
  123. wandb/integration/keras/callbacks/tables_builder.py +228 -0
  124. wandb/integration/keras/keras.py +1086 -0
  125. wandb/integration/kfp/__init__.py +6 -0
  126. wandb/integration/kfp/helpers.py +28 -0
  127. wandb/integration/kfp/kfp_patch.py +335 -0
  128. wandb/integration/kfp/wandb_logging.py +182 -0
  129. wandb/integration/langchain/__init__.py +3 -0
  130. wandb/integration/langchain/wandb_tracer.py +49 -0
  131. wandb/integration/lightgbm/__init__.py +239 -0
  132. wandb/integration/lightning/__init__.py +0 -0
  133. wandb/integration/lightning/fabric/__init__.py +3 -0
  134. wandb/integration/lightning/fabric/logger.py +763 -0
  135. wandb/integration/metaflow/__init__.py +9 -0
  136. wandb/integration/metaflow/data_pandas.py +74 -0
  137. wandb/integration/metaflow/data_pytorch.py +75 -0
  138. wandb/integration/metaflow/data_sklearn.py +76 -0
  139. wandb/integration/metaflow/errors.py +13 -0
  140. wandb/integration/metaflow/metaflow.py +327 -0
  141. wandb/integration/openai/__init__.py +3 -0
  142. wandb/integration/openai/fine_tuning.py +480 -0
  143. wandb/integration/openai/openai.py +22 -0
  144. wandb/integration/openai/resolver.py +240 -0
  145. wandb/integration/prodigy/__init__.py +3 -0
  146. wandb/integration/prodigy/prodigy.py +291 -0
  147. wandb/integration/sacred/__init__.py +117 -0
  148. wandb/integration/sagemaker/__init__.py +14 -0
  149. wandb/integration/sagemaker/auth.py +29 -0
  150. wandb/integration/sagemaker/config.py +58 -0
  151. wandb/integration/sagemaker/files.py +2 -0
  152. wandb/integration/sagemaker/resources.py +63 -0
  153. wandb/integration/sb3/__init__.py +3 -0
  154. wandb/integration/sb3/sb3.py +147 -0
  155. wandb/integration/sklearn/__init__.py +37 -0
  156. wandb/integration/sklearn/calculate/__init__.py +32 -0
  157. wandb/integration/sklearn/calculate/calibration_curves.py +125 -0
  158. wandb/integration/sklearn/calculate/class_proportions.py +68 -0
  159. wandb/integration/sklearn/calculate/confusion_matrix.py +93 -0
  160. wandb/integration/sklearn/calculate/decision_boundaries.py +40 -0
  161. wandb/integration/sklearn/calculate/elbow_curve.py +55 -0
  162. wandb/integration/sklearn/calculate/feature_importances.py +67 -0
  163. wandb/integration/sklearn/calculate/learning_curve.py +64 -0
  164. wandb/integration/sklearn/calculate/outlier_candidates.py +69 -0
  165. wandb/integration/sklearn/calculate/residuals.py +86 -0
  166. wandb/integration/sklearn/calculate/silhouette.py +118 -0
  167. wandb/integration/sklearn/calculate/summary_metrics.py +62 -0
  168. wandb/integration/sklearn/plot/__init__.py +35 -0
  169. wandb/integration/sklearn/plot/classifier.py +329 -0
  170. wandb/integration/sklearn/plot/clusterer.py +146 -0
  171. wandb/integration/sklearn/plot/regressor.py +121 -0
  172. wandb/integration/sklearn/plot/shared.py +91 -0
  173. wandb/integration/sklearn/utils.py +184 -0
  174. wandb/integration/tensorboard/__init__.py +10 -0
  175. wandb/integration/tensorboard/log.py +351 -0
  176. wandb/integration/tensorboard/monkeypatch.py +186 -0
  177. wandb/integration/tensorflow/__init__.py +5 -0
  178. wandb/integration/tensorflow/estimator_hook.py +54 -0
  179. wandb/integration/torch/__init__.py +0 -0
  180. wandb/integration/torch/wandb_torch.py +554 -0
  181. wandb/integration/ultralytics/__init__.py +11 -0
  182. wandb/integration/ultralytics/bbox_utils.py +215 -0
  183. wandb/integration/ultralytics/callback.py +528 -0
  184. wandb/integration/ultralytics/classification_utils.py +83 -0
  185. wandb/integration/ultralytics/mask_utils.py +202 -0
  186. wandb/integration/ultralytics/pose_utils.py +103 -0
  187. wandb/integration/weave/__init__.py +6 -0
  188. wandb/integration/weave/interface.py +49 -0
  189. wandb/integration/weave/weave.py +63 -0
  190. wandb/integration/xgboost/__init__.py +11 -0
  191. wandb/integration/xgboost/xgboost.py +189 -0
  192. wandb/integration/yolov8/__init__.py +0 -0
  193. wandb/integration/yolov8/yolov8.py +284 -0
  194. wandb/jupyter.py +538 -0
  195. wandb/mpmain/__init__.py +0 -0
  196. wandb/mpmain/__main__.py +1 -0
  197. wandb/old/__init__.py +0 -0
  198. wandb/old/core.py +53 -0
  199. wandb/old/settings.py +176 -0
  200. wandb/old/summary.py +438 -0
  201. wandb/plot/__init__.py +30 -0
  202. wandb/plot/bar.py +71 -0
  203. wandb/plot/confusion_matrix.py +185 -0
  204. wandb/plot/custom_chart.py +147 -0
  205. wandb/plot/histogram.py +66 -0
  206. wandb/plot/line.py +75 -0
  207. wandb/plot/line_series.py +173 -0
  208. wandb/plot/pr_curve.py +186 -0
  209. wandb/plot/roc_curve.py +163 -0
  210. wandb/plot/scatter.py +66 -0
  211. wandb/plot/utils.py +184 -0
  212. wandb/plot/viz.py +41 -0
  213. wandb/proto/__init__.py +0 -0
  214. wandb/proto/v3/__init__.py +0 -0
  215. wandb/proto/v3/wandb_base_pb2.py +55 -0
  216. wandb/proto/v3/wandb_internal_pb2.py +1728 -0
  217. wandb/proto/v3/wandb_server_pb2.py +228 -0
  218. wandb/proto/v3/wandb_settings_pb2.py +122 -0
  219. wandb/proto/v3/wandb_telemetry_pb2.py +106 -0
  220. wandb/proto/v4/__init__.py +0 -0
  221. wandb/proto/v4/wandb_base_pb2.py +30 -0
  222. wandb/proto/v4/wandb_internal_pb2.py +382 -0
  223. wandb/proto/v4/wandb_server_pb2.py +67 -0
  224. wandb/proto/v4/wandb_settings_pb2.py +47 -0
  225. wandb/proto/v4/wandb_telemetry_pb2.py +41 -0
  226. wandb/proto/v5/wandb_base_pb2.py +31 -0
  227. wandb/proto/v5/wandb_internal_pb2.py +383 -0
  228. wandb/proto/v5/wandb_server_pb2.py +68 -0
  229. wandb/proto/v5/wandb_settings_pb2.py +48 -0
  230. wandb/proto/v5/wandb_telemetry_pb2.py +42 -0
  231. wandb/proto/v6/wandb_base_pb2.py +41 -0
  232. wandb/proto/v6/wandb_internal_pb2.py +393 -0
  233. wandb/proto/v6/wandb_server_pb2.py +78 -0
  234. wandb/proto/v6/wandb_settings_pb2.py +58 -0
  235. wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
  236. wandb/proto/wandb_base_pb2.py +12 -0
  237. wandb/proto/wandb_deprecated.py +59 -0
  238. wandb/proto/wandb_generate_deprecated.py +30 -0
  239. wandb/proto/wandb_generate_proto.py +49 -0
  240. wandb/proto/wandb_internal_pb2.py +18 -0
  241. wandb/proto/wandb_server_pb2.py +12 -0
  242. wandb/proto/wandb_settings_pb2.py +12 -0
  243. wandb/proto/wandb_telemetry_pb2.py +12 -0
  244. wandb/py.typed +0 -0
  245. wandb/sdk/__init__.py +37 -0
  246. wandb/sdk/artifacts/__init__.py +0 -0
  247. wandb/sdk/artifacts/_factories.py +17 -0
  248. wandb/sdk/artifacts/_generated/__init__.py +508 -0
  249. wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
  250. wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
  251. wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
  252. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
  253. wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
  254. wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
  255. wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
  256. wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
  257. wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
  258. wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
  259. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
  260. wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
  261. wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
  262. wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
  263. wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
  264. wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
  265. wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
  266. wandb/sdk/artifacts/_generated/enums.py +22 -0
  267. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
  268. wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
  269. wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
  270. wandb/sdk/artifacts/_generated/fragments.py +459 -0
  271. wandb/sdk/artifacts/_generated/input_types.py +46 -0
  272. wandb/sdk/artifacts/_generated/link_artifact.py +27 -0
  273. wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
  274. wandb/sdk/artifacts/_generated/operations.py +1223 -0
  275. wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
  276. wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
  277. wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
  278. wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
  279. wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
  280. wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
  281. wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
  282. wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
  283. wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
  284. wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
  285. wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
  286. wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
  287. wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
  288. wandb/sdk/artifacts/_graphql_fragments.py +19 -0
  289. wandb/sdk/artifacts/_internal_artifact.py +54 -0
  290. wandb/sdk/artifacts/_validators.py +309 -0
  291. wandb/sdk/artifacts/artifact.py +2702 -0
  292. wandb/sdk/artifacts/artifact_download_logger.py +45 -0
  293. wandb/sdk/artifacts/artifact_file_cache.py +251 -0
  294. wandb/sdk/artifacts/artifact_instance_cache.py +17 -0
  295. wandb/sdk/artifacts/artifact_manifest.py +76 -0
  296. wandb/sdk/artifacts/artifact_manifest_entry.py +258 -0
  297. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  298. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +94 -0
  299. wandb/sdk/artifacts/artifact_saver.py +277 -0
  300. wandb/sdk/artifacts/artifact_state.py +13 -0
  301. wandb/sdk/artifacts/artifact_ttl.py +9 -0
  302. wandb/sdk/artifacts/exceptions.py +71 -0
  303. wandb/sdk/artifacts/staging.py +27 -0
  304. wandb/sdk/artifacts/storage_handler.py +62 -0
  305. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  306. wandb/sdk/artifacts/storage_handlers/azure_handler.py +214 -0
  307. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  308. wandb/sdk/artifacts/storage_handlers/http_handler.py +114 -0
  309. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +142 -0
  310. wandb/sdk/artifacts/storage_handlers/multi_handler.py +56 -0
  311. wandb/sdk/artifacts/storage_handlers/s3_handler.py +339 -0
  312. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +68 -0
  313. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +131 -0
  314. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +74 -0
  315. wandb/sdk/artifacts/storage_layout.py +8 -0
  316. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  317. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  318. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +580 -0
  319. wandb/sdk/artifacts/storage_policy.py +75 -0
  320. wandb/sdk/backend/__init__.py +0 -0
  321. wandb/sdk/backend/backend.py +57 -0
  322. wandb/sdk/data_types/__init__.py +0 -0
  323. wandb/sdk/data_types/_dtypes.py +914 -0
  324. wandb/sdk/data_types/_private.py +10 -0
  325. wandb/sdk/data_types/audio.py +208 -0
  326. wandb/sdk/data_types/base_types/__init__.py +0 -0
  327. wandb/sdk/data_types/base_types/json_metadata.py +55 -0
  328. wandb/sdk/data_types/base_types/media.py +339 -0
  329. wandb/sdk/data_types/base_types/wb_value.py +295 -0
  330. wandb/sdk/data_types/bokeh.py +87 -0
  331. wandb/sdk/data_types/graph.py +439 -0
  332. wandb/sdk/data_types/helper_types/__init__.py +0 -0
  333. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +327 -0
  334. wandb/sdk/data_types/helper_types/classes.py +159 -0
  335. wandb/sdk/data_types/helper_types/image_mask.py +251 -0
  336. wandb/sdk/data_types/histogram.py +107 -0
  337. wandb/sdk/data_types/html.py +165 -0
  338. wandb/sdk/data_types/image.py +974 -0
  339. wandb/sdk/data_types/molecule.py +250 -0
  340. wandb/sdk/data_types/object_3d.py +495 -0
  341. wandb/sdk/data_types/plotly.py +95 -0
  342. wandb/sdk/data_types/saved_model.py +435 -0
  343. wandb/sdk/data_types/table.py +1468 -0
  344. wandb/sdk/data_types/table_decorators.py +108 -0
  345. wandb/sdk/data_types/trace_tree.py +440 -0
  346. wandb/sdk/data_types/utils.py +260 -0
  347. wandb/sdk/data_types/video.py +303 -0
  348. wandb/sdk/integration_utils/__init__.py +0 -0
  349. wandb/sdk/integration_utils/auto_logging.py +232 -0
  350. wandb/sdk/integration_utils/data_logging.py +475 -0
  351. wandb/sdk/interface/__init__.py +0 -0
  352. wandb/sdk/interface/constants.py +4 -0
  353. wandb/sdk/interface/interface.py +1056 -0
  354. wandb/sdk/interface/interface_queue.py +40 -0
  355. wandb/sdk/interface/interface_shared.py +471 -0
  356. wandb/sdk/interface/interface_sock.py +49 -0
  357. wandb/sdk/interface/summary_record.py +67 -0
  358. wandb/sdk/internal/__init__.py +0 -0
  359. wandb/sdk/internal/_generated/__init__.py +15 -0
  360. wandb/sdk/internal/_generated/enums.py +4 -0
  361. wandb/sdk/internal/_generated/input_types.py +4 -0
  362. wandb/sdk/internal/_generated/operations.py +15 -0
  363. wandb/sdk/internal/_generated/server_features_query.py +27 -0
  364. wandb/sdk/internal/context.py +89 -0
  365. wandb/sdk/internal/datastore.py +293 -0
  366. wandb/sdk/internal/file_pusher.py +177 -0
  367. wandb/sdk/internal/file_stream.py +686 -0
  368. wandb/sdk/internal/handler.py +854 -0
  369. wandb/sdk/internal/incremental_table_util.py +53 -0
  370. wandb/sdk/internal/internal_api.py +4723 -0
  371. wandb/sdk/internal/job_builder.py +639 -0
  372. wandb/sdk/internal/profiler.py +79 -0
  373. wandb/sdk/internal/progress.py +77 -0
  374. wandb/sdk/internal/run.py +27 -0
  375. wandb/sdk/internal/sample.py +70 -0
  376. wandb/sdk/internal/sender.py +1692 -0
  377. wandb/sdk/internal/sender_config.py +203 -0
  378. wandb/sdk/internal/settings_static.py +120 -0
  379. wandb/sdk/internal/tb_watcher.py +519 -0
  380. wandb/sdk/internal/thread_local_settings.py +18 -0
  381. wandb/sdk/launch/__init__.py +15 -0
  382. wandb/sdk/launch/_launch.py +331 -0
  383. wandb/sdk/launch/_launch_add.py +255 -0
  384. wandb/sdk/launch/_project_spec.py +565 -0
  385. wandb/sdk/launch/agent/__init__.py +5 -0
  386. wandb/sdk/launch/agent/agent.py +931 -0
  387. wandb/sdk/launch/agent/config.py +296 -0
  388. wandb/sdk/launch/agent/job_status_tracker.py +55 -0
  389. wandb/sdk/launch/agent/run_queue_item_file_saver.py +39 -0
  390. wandb/sdk/launch/builder/__init__.py +0 -0
  391. wandb/sdk/launch/builder/abstract.py +156 -0
  392. wandb/sdk/launch/builder/build.py +296 -0
  393. wandb/sdk/launch/builder/context_manager.py +235 -0
  394. wandb/sdk/launch/builder/docker_builder.py +177 -0
  395. wandb/sdk/launch/builder/kaniko_builder.py +595 -0
  396. wandb/sdk/launch/builder/noop.py +58 -0
  397. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +188 -0
  398. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  399. wandb/sdk/launch/create_job.py +541 -0
  400. wandb/sdk/launch/environment/abstract.py +29 -0
  401. wandb/sdk/launch/environment/aws_environment.py +322 -0
  402. wandb/sdk/launch/environment/azure_environment.py +105 -0
  403. wandb/sdk/launch/environment/gcp_environment.py +334 -0
  404. wandb/sdk/launch/environment/local_environment.py +65 -0
  405. wandb/sdk/launch/errors.py +13 -0
  406. wandb/sdk/launch/git_reference.py +109 -0
  407. wandb/sdk/launch/inputs/files.py +148 -0
  408. wandb/sdk/launch/inputs/internal.py +314 -0
  409. wandb/sdk/launch/inputs/manage.py +113 -0
  410. wandb/sdk/launch/inputs/schema.py +40 -0
  411. wandb/sdk/launch/loader.py +249 -0
  412. wandb/sdk/launch/registry/abstract.py +48 -0
  413. wandb/sdk/launch/registry/anon.py +29 -0
  414. wandb/sdk/launch/registry/azure_container_registry.py +124 -0
  415. wandb/sdk/launch/registry/elastic_container_registry.py +192 -0
  416. wandb/sdk/launch/registry/google_artifact_registry.py +219 -0
  417. wandb/sdk/launch/registry/local_registry.py +65 -0
  418. wandb/sdk/launch/runner/__init__.py +0 -0
  419. wandb/sdk/launch/runner/abstract.py +185 -0
  420. wandb/sdk/launch/runner/kubernetes_monitor.py +473 -0
  421. wandb/sdk/launch/runner/kubernetes_runner.py +1285 -0
  422. wandb/sdk/launch/runner/local_container.py +301 -0
  423. wandb/sdk/launch/runner/local_process.py +78 -0
  424. wandb/sdk/launch/runner/sagemaker_runner.py +424 -0
  425. wandb/sdk/launch/runner/vertex_runner.py +225 -0
  426. wandb/sdk/launch/sweeps/__init__.py +37 -0
  427. wandb/sdk/launch/sweeps/scheduler.py +739 -0
  428. wandb/sdk/launch/sweeps/scheduler_sweep.py +90 -0
  429. wandb/sdk/launch/sweeps/utils.py +324 -0
  430. wandb/sdk/launch/utils.py +746 -0
  431. wandb/sdk/launch/wandb_reference.py +138 -0
  432. wandb/sdk/lib/__init__.py +5 -0
  433. wandb/sdk/lib/apikey.py +334 -0
  434. wandb/sdk/lib/asyncio_compat.py +213 -0
  435. wandb/sdk/lib/asyncio_manager.py +252 -0
  436. wandb/sdk/lib/capped_dict.py +26 -0
  437. wandb/sdk/lib/config_util.py +101 -0
  438. wandb/sdk/lib/console_capture.py +219 -0
  439. wandb/sdk/lib/credentials.py +141 -0
  440. wandb/sdk/lib/deprecate.py +27 -0
  441. wandb/sdk/lib/disabled.py +30 -0
  442. wandb/sdk/lib/exit_hooks.py +54 -0
  443. wandb/sdk/lib/file_stream_utils.py +118 -0
  444. wandb/sdk/lib/filenames.py +64 -0
  445. wandb/sdk/lib/filesystem.py +372 -0
  446. wandb/sdk/lib/fsm.py +165 -0
  447. wandb/sdk/lib/gitlib.py +240 -0
  448. wandb/sdk/lib/gql_request.py +65 -0
  449. wandb/sdk/lib/handler_util.py +21 -0
  450. wandb/sdk/lib/hashutil.py +106 -0
  451. wandb/sdk/lib/import_hooks.py +275 -0
  452. wandb/sdk/lib/interrupt.py +37 -0
  453. wandb/sdk/lib/ipython.py +126 -0
  454. wandb/sdk/lib/json_util.py +75 -0
  455. wandb/sdk/lib/lazyloader.py +63 -0
  456. wandb/sdk/lib/module.py +72 -0
  457. wandb/sdk/lib/paths.py +106 -0
  458. wandb/sdk/lib/preinit.py +42 -0
  459. wandb/sdk/lib/printer.py +571 -0
  460. wandb/sdk/lib/printer_asyncio.py +48 -0
  461. wandb/sdk/lib/progress.py +320 -0
  462. wandb/sdk/lib/proto_util.py +90 -0
  463. wandb/sdk/lib/redirect.py +876 -0
  464. wandb/sdk/lib/retry.py +395 -0
  465. wandb/sdk/lib/run_moment.py +82 -0
  466. wandb/sdk/lib/runid.py +12 -0
  467. wandb/sdk/lib/server.py +58 -0
  468. wandb/sdk/lib/service/ipc_support.py +13 -0
  469. wandb/sdk/lib/service/service_client.py +106 -0
  470. wandb/sdk/lib/service/service_connection.py +192 -0
  471. wandb/sdk/lib/service/service_port_file.py +105 -0
  472. wandb/sdk/lib/service/service_process.py +111 -0
  473. wandb/sdk/lib/service/service_token.py +181 -0
  474. wandb/sdk/lib/sparkline.py +44 -0
  475. wandb/sdk/lib/telemetry.py +100 -0
  476. wandb/sdk/lib/timed_input.py +133 -0
  477. wandb/sdk/lib/timer.py +19 -0
  478. wandb/sdk/lib/wb_logging.py +161 -0
  479. wandb/sdk/mailbox/__init__.py +23 -0
  480. wandb/sdk/mailbox/mailbox.py +143 -0
  481. wandb/sdk/mailbox/mailbox_handle.py +132 -0
  482. wandb/sdk/mailbox/response_handle.py +99 -0
  483. wandb/sdk/mailbox/wait_with_progress.py +100 -0
  484. wandb/sdk/projects/_generated/__init__.py +47 -0
  485. wandb/sdk/projects/_generated/delete_project.py +22 -0
  486. wandb/sdk/projects/_generated/enums.py +4 -0
  487. wandb/sdk/projects/_generated/fetch_registry.py +22 -0
  488. wandb/sdk/projects/_generated/fragments.py +41 -0
  489. wandb/sdk/projects/_generated/input_types.py +13 -0
  490. wandb/sdk/projects/_generated/operations.py +88 -0
  491. wandb/sdk/projects/_generated/rename_project.py +27 -0
  492. wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
  493. wandb/sdk/verify/__init__.py +0 -0
  494. wandb/sdk/verify/verify.py +555 -0
  495. wandb/sdk/wandb_alerts.py +12 -0
  496. wandb/sdk/wandb_config.py +323 -0
  497. wandb/sdk/wandb_helper.py +54 -0
  498. wandb/sdk/wandb_init.py +1581 -0
  499. wandb/sdk/wandb_login.py +332 -0
  500. wandb/sdk/wandb_metric.py +112 -0
  501. wandb/sdk/wandb_require.py +88 -0
  502. wandb/sdk/wandb_require_helpers.py +44 -0
  503. wandb/sdk/wandb_run.py +4088 -0
  504. wandb/sdk/wandb_settings.py +2105 -0
  505. wandb/sdk/wandb_setup.py +560 -0
  506. wandb/sdk/wandb_summary.py +150 -0
  507. wandb/sdk/wandb_sweep.py +120 -0
  508. wandb/sdk/wandb_sync.py +71 -0
  509. wandb/sdk/wandb_watch.py +146 -0
  510. wandb/sklearn.py +35 -0
  511. wandb/sync/__init__.py +3 -0
  512. wandb/sync/sync.py +452 -0
  513. wandb/trigger.py +29 -0
  514. wandb/util.py +2040 -0
  515. wandb/vendor/__init__.py +0 -0
  516. wandb/vendor/gql-0.2.0/setup.py +40 -0
  517. wandb/vendor/gql-0.2.0/tests/__init__.py +0 -0
  518. wandb/vendor/gql-0.2.0/tests/starwars/__init__.py +0 -0
  519. wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py +96 -0
  520. wandb/vendor/gql-0.2.0/tests/starwars/schema.py +146 -0
  521. wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py +293 -0
  522. wandb/vendor/gql-0.2.0/tests/starwars/test_query.py +355 -0
  523. wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py +171 -0
  524. wandb/vendor/gql-0.2.0/tests/test_client.py +31 -0
  525. wandb/vendor/gql-0.2.0/tests/test_transport.py +89 -0
  526. wandb/vendor/gql-0.2.0/wandb_gql/__init__.py +4 -0
  527. wandb/vendor/gql-0.2.0/wandb_gql/client.py +75 -0
  528. wandb/vendor/gql-0.2.0/wandb_gql/dsl.py +152 -0
  529. wandb/vendor/gql-0.2.0/wandb_gql/gql.py +10 -0
  530. wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py +0 -0
  531. wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py +6 -0
  532. wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py +15 -0
  533. wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py +46 -0
  534. wandb/vendor/gql-0.2.0/wandb_gql/utils.py +21 -0
  535. wandb/vendor/graphql-core-1.1/setup.py +86 -0
  536. wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py +287 -0
  537. wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py +6 -0
  538. wandb/vendor/graphql-core-1.1/wandb_graphql/error/base.py +42 -0
  539. wandb/vendor/graphql-core-1.1/wandb_graphql/error/format_error.py +11 -0
  540. wandb/vendor/graphql-core-1.1/wandb_graphql/error/located_error.py +29 -0
  541. wandb/vendor/graphql-core-1.1/wandb_graphql/error/syntax_error.py +36 -0
  542. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__init__.py +26 -0
  543. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py +311 -0
  544. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executor.py +398 -0
  545. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py +0 -0
  546. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py +53 -0
  547. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py +22 -0
  548. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py +32 -0
  549. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py +7 -0
  550. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py +35 -0
  551. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py +6 -0
  552. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py +0 -0
  553. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py +66 -0
  554. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py +252 -0
  555. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py +151 -0
  556. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py +7 -0
  557. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py +57 -0
  558. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py +145 -0
  559. wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py +60 -0
  560. wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py +0 -0
  561. wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py +1349 -0
  562. wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py +19 -0
  563. wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py +435 -0
  564. wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py +30 -0
  565. wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py +779 -0
  566. wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py +193 -0
  567. wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py +18 -0
  568. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py +222 -0
  569. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py +82 -0
  570. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__init__.py +0 -0
  571. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py +17 -0
  572. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/contain_subset.py +28 -0
  573. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/default_ordered_dict.py +40 -0
  574. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/ordereddict.py +8 -0
  575. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/pair_set.py +43 -0
  576. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/version.py +78 -0
  577. wandb/vendor/graphql-core-1.1/wandb_graphql/type/__init__.py +67 -0
  578. wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py +619 -0
  579. wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py +132 -0
  580. wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py +440 -0
  581. wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py +131 -0
  582. wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py +100 -0
  583. wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py +145 -0
  584. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py +0 -0
  585. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py +9 -0
  586. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py +65 -0
  587. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py +49 -0
  588. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py +24 -0
  589. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py +75 -0
  590. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py +291 -0
  591. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_client_schema.py +250 -0
  592. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/concat_ast.py +9 -0
  593. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/extend_schema.py +357 -0
  594. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_field_def.py +27 -0
  595. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_operation_ast.py +21 -0
  596. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/introspection_query.py +90 -0
  597. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_literal_value.py +67 -0
  598. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_value.py +66 -0
  599. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/quoted_or_list.py +21 -0
  600. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/schema_printer.py +168 -0
  601. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/suggestion_list.py +56 -0
  602. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_comparators.py +69 -0
  603. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_from_ast.py +21 -0
  604. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_info.py +149 -0
  605. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/value_from_ast.py +69 -0
  606. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py +4 -0
  607. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py +79 -0
  608. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py +24 -0
  609. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py +8 -0
  610. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py +44 -0
  611. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py +113 -0
  612. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py +33 -0
  613. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py +70 -0
  614. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py +97 -0
  615. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py +19 -0
  616. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py +43 -0
  617. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py +23 -0
  618. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py +59 -0
  619. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py +36 -0
  620. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py +38 -0
  621. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py +37 -0
  622. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py +529 -0
  623. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py +44 -0
  624. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py +46 -0
  625. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py +33 -0
  626. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py +32 -0
  627. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py +28 -0
  628. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py +33 -0
  629. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py +31 -0
  630. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py +27 -0
  631. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py +21 -0
  632. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py +53 -0
  633. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py +158 -0
  634. wandb/vendor/promise-2.3.0/conftest.py +30 -0
  635. wandb/vendor/promise-2.3.0/setup.py +64 -0
  636. wandb/vendor/promise-2.3.0/tests/__init__.py +0 -0
  637. wandb/vendor/promise-2.3.0/tests/conftest.py +8 -0
  638. wandb/vendor/promise-2.3.0/tests/test_awaitable.py +32 -0
  639. wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py +47 -0
  640. wandb/vendor/promise-2.3.0/tests/test_benchmark.py +116 -0
  641. wandb/vendor/promise-2.3.0/tests/test_complex_threads.py +23 -0
  642. wandb/vendor/promise-2.3.0/tests/test_dataloader.py +452 -0
  643. wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py +99 -0
  644. wandb/vendor/promise-2.3.0/tests/test_dataloader_extra.py +65 -0
  645. wandb/vendor/promise-2.3.0/tests/test_extra.py +670 -0
  646. wandb/vendor/promise-2.3.0/tests/test_issues.py +132 -0
  647. wandb/vendor/promise-2.3.0/tests/test_promise_list.py +70 -0
  648. wandb/vendor/promise-2.3.0/tests/test_spec.py +584 -0
  649. wandb/vendor/promise-2.3.0/tests/test_thread_safety.py +115 -0
  650. wandb/vendor/promise-2.3.0/tests/utils.py +3 -0
  651. wandb/vendor/promise-2.3.0/wandb_promise/__init__.py +38 -0
  652. wandb/vendor/promise-2.3.0/wandb_promise/async_.py +135 -0
  653. wandb/vendor/promise-2.3.0/wandb_promise/compat.py +32 -0
  654. wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py +326 -0
  655. wandb/vendor/promise-2.3.0/wandb_promise/iterate_promise.py +12 -0
  656. wandb/vendor/promise-2.3.0/wandb_promise/promise.py +848 -0
  657. wandb/vendor/promise-2.3.0/wandb_promise/promise_list.py +151 -0
  658. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/__init__.py +0 -0
  659. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/version.py +83 -0
  660. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/__init__.py +0 -0
  661. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/asyncio.py +22 -0
  662. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/gevent.py +21 -0
  663. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/immediate.py +27 -0
  664. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/thread.py +18 -0
  665. wandb/vendor/promise-2.3.0/wandb_promise/utils.py +56 -0
  666. wandb/vendor/pygments/__init__.py +90 -0
  667. wandb/vendor/pygments/cmdline.py +568 -0
  668. wandb/vendor/pygments/console.py +74 -0
  669. wandb/vendor/pygments/filter.py +74 -0
  670. wandb/vendor/pygments/filters/__init__.py +350 -0
  671. wandb/vendor/pygments/formatter.py +95 -0
  672. wandb/vendor/pygments/formatters/__init__.py +153 -0
  673. wandb/vendor/pygments/formatters/_mapping.py +85 -0
  674. wandb/vendor/pygments/formatters/bbcode.py +109 -0
  675. wandb/vendor/pygments/formatters/html.py +851 -0
  676. wandb/vendor/pygments/formatters/img.py +600 -0
  677. wandb/vendor/pygments/formatters/irc.py +182 -0
  678. wandb/vendor/pygments/formatters/latex.py +482 -0
  679. wandb/vendor/pygments/formatters/other.py +160 -0
  680. wandb/vendor/pygments/formatters/rtf.py +147 -0
  681. wandb/vendor/pygments/formatters/svg.py +153 -0
  682. wandb/vendor/pygments/formatters/terminal.py +136 -0
  683. wandb/vendor/pygments/formatters/terminal256.py +309 -0
  684. wandb/vendor/pygments/lexer.py +871 -0
  685. wandb/vendor/pygments/lexers/__init__.py +329 -0
  686. wandb/vendor/pygments/lexers/_asy_builtins.py +1645 -0
  687. wandb/vendor/pygments/lexers/_cl_builtins.py +232 -0
  688. wandb/vendor/pygments/lexers/_cocoa_builtins.py +72 -0
  689. wandb/vendor/pygments/lexers/_csound_builtins.py +1346 -0
  690. wandb/vendor/pygments/lexers/_lasso_builtins.py +5327 -0
  691. wandb/vendor/pygments/lexers/_lua_builtins.py +295 -0
  692. wandb/vendor/pygments/lexers/_mapping.py +500 -0
  693. wandb/vendor/pygments/lexers/_mql_builtins.py +1172 -0
  694. wandb/vendor/pygments/lexers/_openedge_builtins.py +2547 -0
  695. wandb/vendor/pygments/lexers/_php_builtins.py +4756 -0
  696. wandb/vendor/pygments/lexers/_postgres_builtins.py +621 -0
  697. wandb/vendor/pygments/lexers/_scilab_builtins.py +3094 -0
  698. wandb/vendor/pygments/lexers/_sourcemod_builtins.py +1163 -0
  699. wandb/vendor/pygments/lexers/_stan_builtins.py +532 -0
  700. wandb/vendor/pygments/lexers/_stata_builtins.py +419 -0
  701. wandb/vendor/pygments/lexers/_tsql_builtins.py +1004 -0
  702. wandb/vendor/pygments/lexers/_vim_builtins.py +1939 -0
  703. wandb/vendor/pygments/lexers/actionscript.py +240 -0
  704. wandb/vendor/pygments/lexers/agile.py +24 -0
  705. wandb/vendor/pygments/lexers/algebra.py +221 -0
  706. wandb/vendor/pygments/lexers/ambient.py +76 -0
  707. wandb/vendor/pygments/lexers/ampl.py +87 -0
  708. wandb/vendor/pygments/lexers/apl.py +101 -0
  709. wandb/vendor/pygments/lexers/archetype.py +318 -0
  710. wandb/vendor/pygments/lexers/asm.py +641 -0
  711. wandb/vendor/pygments/lexers/automation.py +374 -0
  712. wandb/vendor/pygments/lexers/basic.py +500 -0
  713. wandb/vendor/pygments/lexers/bibtex.py +160 -0
  714. wandb/vendor/pygments/lexers/business.py +612 -0
  715. wandb/vendor/pygments/lexers/c_cpp.py +252 -0
  716. wandb/vendor/pygments/lexers/c_like.py +541 -0
  717. wandb/vendor/pygments/lexers/capnproto.py +78 -0
  718. wandb/vendor/pygments/lexers/chapel.py +102 -0
  719. wandb/vendor/pygments/lexers/clean.py +288 -0
  720. wandb/vendor/pygments/lexers/compiled.py +34 -0
  721. wandb/vendor/pygments/lexers/configs.py +833 -0
  722. wandb/vendor/pygments/lexers/console.py +114 -0
  723. wandb/vendor/pygments/lexers/crystal.py +393 -0
  724. wandb/vendor/pygments/lexers/csound.py +366 -0
  725. wandb/vendor/pygments/lexers/css.py +689 -0
  726. wandb/vendor/pygments/lexers/d.py +251 -0
  727. wandb/vendor/pygments/lexers/dalvik.py +125 -0
  728. wandb/vendor/pygments/lexers/data.py +555 -0
  729. wandb/vendor/pygments/lexers/diff.py +165 -0
  730. wandb/vendor/pygments/lexers/dotnet.py +691 -0
  731. wandb/vendor/pygments/lexers/dsls.py +878 -0
  732. wandb/vendor/pygments/lexers/dylan.py +289 -0
  733. wandb/vendor/pygments/lexers/ecl.py +125 -0
  734. wandb/vendor/pygments/lexers/eiffel.py +65 -0
  735. wandb/vendor/pygments/lexers/elm.py +121 -0
  736. wandb/vendor/pygments/lexers/erlang.py +533 -0
  737. wandb/vendor/pygments/lexers/esoteric.py +277 -0
  738. wandb/vendor/pygments/lexers/ezhil.py +69 -0
  739. wandb/vendor/pygments/lexers/factor.py +344 -0
  740. wandb/vendor/pygments/lexers/fantom.py +250 -0
  741. wandb/vendor/pygments/lexers/felix.py +273 -0
  742. wandb/vendor/pygments/lexers/forth.py +177 -0
  743. wandb/vendor/pygments/lexers/fortran.py +205 -0
  744. wandb/vendor/pygments/lexers/foxpro.py +428 -0
  745. wandb/vendor/pygments/lexers/functional.py +21 -0
  746. wandb/vendor/pygments/lexers/go.py +101 -0
  747. wandb/vendor/pygments/lexers/grammar_notation.py +213 -0
  748. wandb/vendor/pygments/lexers/graph.py +80 -0
  749. wandb/vendor/pygments/lexers/graphics.py +553 -0
  750. wandb/vendor/pygments/lexers/haskell.py +843 -0
  751. wandb/vendor/pygments/lexers/haxe.py +936 -0
  752. wandb/vendor/pygments/lexers/hdl.py +382 -0
  753. wandb/vendor/pygments/lexers/hexdump.py +103 -0
  754. wandb/vendor/pygments/lexers/html.py +602 -0
  755. wandb/vendor/pygments/lexers/idl.py +270 -0
  756. wandb/vendor/pygments/lexers/igor.py +288 -0
  757. wandb/vendor/pygments/lexers/inferno.py +96 -0
  758. wandb/vendor/pygments/lexers/installers.py +322 -0
  759. wandb/vendor/pygments/lexers/int_fiction.py +1343 -0
  760. wandb/vendor/pygments/lexers/iolang.py +63 -0
  761. wandb/vendor/pygments/lexers/j.py +146 -0
  762. wandb/vendor/pygments/lexers/javascript.py +1525 -0
  763. wandb/vendor/pygments/lexers/julia.py +333 -0
  764. wandb/vendor/pygments/lexers/jvm.py +1573 -0
  765. wandb/vendor/pygments/lexers/lisp.py +2621 -0
  766. wandb/vendor/pygments/lexers/make.py +202 -0
  767. wandb/vendor/pygments/lexers/markup.py +595 -0
  768. wandb/vendor/pygments/lexers/math.py +21 -0
  769. wandb/vendor/pygments/lexers/matlab.py +663 -0
  770. wandb/vendor/pygments/lexers/ml.py +769 -0
  771. wandb/vendor/pygments/lexers/modeling.py +358 -0
  772. wandb/vendor/pygments/lexers/modula2.py +1561 -0
  773. wandb/vendor/pygments/lexers/monte.py +204 -0
  774. wandb/vendor/pygments/lexers/ncl.py +894 -0
  775. wandb/vendor/pygments/lexers/nimrod.py +159 -0
  776. wandb/vendor/pygments/lexers/nit.py +64 -0
  777. wandb/vendor/pygments/lexers/nix.py +136 -0
  778. wandb/vendor/pygments/lexers/oberon.py +105 -0
  779. wandb/vendor/pygments/lexers/objective.py +504 -0
  780. wandb/vendor/pygments/lexers/ooc.py +85 -0
  781. wandb/vendor/pygments/lexers/other.py +41 -0
  782. wandb/vendor/pygments/lexers/parasail.py +79 -0
  783. wandb/vendor/pygments/lexers/parsers.py +835 -0
  784. wandb/vendor/pygments/lexers/pascal.py +644 -0
  785. wandb/vendor/pygments/lexers/pawn.py +199 -0
  786. wandb/vendor/pygments/lexers/perl.py +620 -0
  787. wandb/vendor/pygments/lexers/php.py +267 -0
  788. wandb/vendor/pygments/lexers/praat.py +294 -0
  789. wandb/vendor/pygments/lexers/prolog.py +306 -0
  790. wandb/vendor/pygments/lexers/python.py +939 -0
  791. wandb/vendor/pygments/lexers/qvt.py +152 -0
  792. wandb/vendor/pygments/lexers/r.py +453 -0
  793. wandb/vendor/pygments/lexers/rdf.py +270 -0
  794. wandb/vendor/pygments/lexers/rebol.py +431 -0
  795. wandb/vendor/pygments/lexers/resource.py +85 -0
  796. wandb/vendor/pygments/lexers/rnc.py +67 -0
  797. wandb/vendor/pygments/lexers/roboconf.py +82 -0
  798. wandb/vendor/pygments/lexers/robotframework.py +560 -0
  799. wandb/vendor/pygments/lexers/ruby.py +519 -0
  800. wandb/vendor/pygments/lexers/rust.py +220 -0
  801. wandb/vendor/pygments/lexers/sas.py +228 -0
  802. wandb/vendor/pygments/lexers/scripting.py +1222 -0
  803. wandb/vendor/pygments/lexers/shell.py +794 -0
  804. wandb/vendor/pygments/lexers/smalltalk.py +195 -0
  805. wandb/vendor/pygments/lexers/smv.py +79 -0
  806. wandb/vendor/pygments/lexers/snobol.py +83 -0
  807. wandb/vendor/pygments/lexers/special.py +103 -0
  808. wandb/vendor/pygments/lexers/sql.py +681 -0
  809. wandb/vendor/pygments/lexers/stata.py +108 -0
  810. wandb/vendor/pygments/lexers/supercollider.py +90 -0
  811. wandb/vendor/pygments/lexers/tcl.py +145 -0
  812. wandb/vendor/pygments/lexers/templates.py +2283 -0
  813. wandb/vendor/pygments/lexers/testing.py +207 -0
  814. wandb/vendor/pygments/lexers/text.py +25 -0
  815. wandb/vendor/pygments/lexers/textedit.py +169 -0
  816. wandb/vendor/pygments/lexers/textfmts.py +297 -0
  817. wandb/vendor/pygments/lexers/theorem.py +458 -0
  818. wandb/vendor/pygments/lexers/trafficscript.py +54 -0
  819. wandb/vendor/pygments/lexers/typoscript.py +226 -0
  820. wandb/vendor/pygments/lexers/urbi.py +133 -0
  821. wandb/vendor/pygments/lexers/varnish.py +190 -0
  822. wandb/vendor/pygments/lexers/verification.py +111 -0
  823. wandb/vendor/pygments/lexers/web.py +24 -0
  824. wandb/vendor/pygments/lexers/webmisc.py +988 -0
  825. wandb/vendor/pygments/lexers/whiley.py +116 -0
  826. wandb/vendor/pygments/lexers/x10.py +69 -0
  827. wandb/vendor/pygments/modeline.py +44 -0
  828. wandb/vendor/pygments/plugin.py +68 -0
  829. wandb/vendor/pygments/regexopt.py +92 -0
  830. wandb/vendor/pygments/scanner.py +105 -0
  831. wandb/vendor/pygments/sphinxext.py +158 -0
  832. wandb/vendor/pygments/style.py +155 -0
  833. wandb/vendor/pygments/styles/__init__.py +80 -0
  834. wandb/vendor/pygments/styles/abap.py +29 -0
  835. wandb/vendor/pygments/styles/algol.py +63 -0
  836. wandb/vendor/pygments/styles/algol_nu.py +63 -0
  837. wandb/vendor/pygments/styles/arduino.py +98 -0
  838. wandb/vendor/pygments/styles/autumn.py +65 -0
  839. wandb/vendor/pygments/styles/borland.py +51 -0
  840. wandb/vendor/pygments/styles/bw.py +49 -0
  841. wandb/vendor/pygments/styles/colorful.py +81 -0
  842. wandb/vendor/pygments/styles/default.py +73 -0
  843. wandb/vendor/pygments/styles/emacs.py +72 -0
  844. wandb/vendor/pygments/styles/friendly.py +72 -0
  845. wandb/vendor/pygments/styles/fruity.py +42 -0
  846. wandb/vendor/pygments/styles/igor.py +29 -0
  847. wandb/vendor/pygments/styles/lovelace.py +97 -0
  848. wandb/vendor/pygments/styles/manni.py +75 -0
  849. wandb/vendor/pygments/styles/monokai.py +106 -0
  850. wandb/vendor/pygments/styles/murphy.py +80 -0
  851. wandb/vendor/pygments/styles/native.py +65 -0
  852. wandb/vendor/pygments/styles/paraiso_dark.py +125 -0
  853. wandb/vendor/pygments/styles/paraiso_light.py +125 -0
  854. wandb/vendor/pygments/styles/pastie.py +75 -0
  855. wandb/vendor/pygments/styles/perldoc.py +69 -0
  856. wandb/vendor/pygments/styles/rainbow_dash.py +89 -0
  857. wandb/vendor/pygments/styles/rrt.py +33 -0
  858. wandb/vendor/pygments/styles/sas.py +44 -0
  859. wandb/vendor/pygments/styles/stata.py +40 -0
  860. wandb/vendor/pygments/styles/tango.py +141 -0
  861. wandb/vendor/pygments/styles/trac.py +63 -0
  862. wandb/vendor/pygments/styles/vim.py +63 -0
  863. wandb/vendor/pygments/styles/vs.py +38 -0
  864. wandb/vendor/pygments/styles/xcode.py +51 -0
  865. wandb/vendor/pygments/token.py +213 -0
  866. wandb/vendor/pygments/unistring.py +217 -0
  867. wandb/vendor/pygments/util.py +388 -0
  868. wandb/vendor/watchdog_0_9_0/wandb_watchdog/__init__.py +17 -0
  869. wandb/vendor/watchdog_0_9_0/wandb_watchdog/events.py +615 -0
  870. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/__init__.py +98 -0
  871. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/api.py +369 -0
  872. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents.py +172 -0
  873. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents2.py +239 -0
  874. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify.py +218 -0
  875. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_buffer.py +81 -0
  876. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_c.py +575 -0
  877. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/kqueue.py +730 -0
  878. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/polling.py +145 -0
  879. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/read_directory_changes.py +133 -0
  880. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/winapi.py +348 -0
  881. wandb/vendor/watchdog_0_9_0/wandb_watchdog/patterns.py +265 -0
  882. wandb/vendor/watchdog_0_9_0/wandb_watchdog/tricks/__init__.py +174 -0
  883. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/__init__.py +151 -0
  884. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/bricks.py +249 -0
  885. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/compat.py +29 -0
  886. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/decorators.py +198 -0
  887. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/delayed_queue.py +88 -0
  888. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/dirsnapshot.py +293 -0
  889. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/echo.py +157 -0
  890. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/event_backport.py +41 -0
  891. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/importlib2.py +40 -0
  892. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/platform.py +57 -0
  893. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/unicode_paths.py +64 -0
  894. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/win32stat.py +123 -0
  895. wandb/vendor/watchdog_0_9_0/wandb_watchdog/version.py +28 -0
  896. wandb/vendor/watchdog_0_9_0/wandb_watchdog/watchmedo.py +577 -0
  897. wandb/wandb_agent.py +580 -0
  898. wandb/wandb_controller.py +719 -0
  899. wandb/wandb_run.py +8 -0
  900. wandb-0.21.2.dist-info/METADATA +223 -0
  901. wandb-0.21.2.dist-info/RECORD +904 -0
  902. wandb-0.21.2.dist-info/WHEEL +4 -0
  903. wandb-0.21.2.dist-info/entry_points.txt +3 -0
  904. wandb-0.21.2.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1086 @@
1
+ """keras init."""
2
+
3
+ import logging
4
+ import operator
5
+ import os
6
+ import shutil
7
+ import sys
8
+ from itertools import chain
9
+
10
+ import numpy as np
11
+ import tensorflow as tf
12
+ import tensorflow.keras.backend as K # noqa: N812
13
+
14
+ import wandb
15
+ from wandb.proto.wandb_deprecated import Deprecated
16
+ from wandb.sdk.integration_utils.data_logging import ValidationDataLogger
17
+ from wandb.sdk.lib.deprecate import deprecate
18
+ from wandb.util import add_import_hook
19
+
20
+
21
def _check_keras_version():
    """Warn if the imported Keras is older than 2.4.0, the minimum supported release."""
    from keras import __version__ as keras_version
    from packaging.version import parse

    minimum_supported = parse("2.4.0")
    if parse(keras_version) >= minimum_supported:
        return
    wandb.termwarn(
        f"Keras version {keras_version} is not fully supported. Required keras >= 2.4.0"
    )
29
+
30
+
31
def _can_compute_flops() -> bool:
    """Return True when the installed TensorFlow is at least 2.0.0.

    FLOPS computation relies on ``tf.compat.v1``, so it is only enabled for TF 2.x+.
    """
    from packaging.version import parse

    return parse(tf.__version__) >= parse("2.0.0")
39
+
40
+
41
# If keras has not been imported yet, register the version check as an import
# hook for "keras"; otherwise run the check right away.
if "keras" not in sys.modules:
    add_import_hook("keras", _check_keras_version)
else:
    _check_keras_version()


logger = logging.getLogger(__name__)
48
+
49
+
50
def is_dataset(data):
    """Return True if ``data`` is a ``tf.data`` dataset instance.

    Checks against ``DatasetV2`` and, when the installed TensorFlow still
    defines it, ``DatasetV1``. Returns False if the dataset_ops module cannot
    be loaded or has no ``DatasetV2``.
    """
    dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
    if not dataset_ops or not hasattr(dataset_ops, "DatasetV2"):
        return False

    candidate_types = [dataset_ops.DatasetV2]
    if hasattr(dataset_ops, "DatasetV1"):
        candidate_types.append(dataset_ops.DatasetV1)
    return isinstance(data, tuple(candidate_types))
59
+
60
+
61
def is_generator_like(data):
    """Return True if ``data`` behaves like a generator, Keras Sequence, or TF iterator.

    Any object exposing ``next`` or ``__next__`` qualifies. Otherwise ``data``
    must be an instance of ``tf.keras.utils.Sequence`` or of one of the TF data
    iterator classes available in the installed TensorFlow.
    """
    candidate_types = [tf.keras.utils.Sequence]
    iterator_ops = wandb.util.get_module("tensorflow.python.data.ops.iterator_ops")
    if iterator_ops:
        candidate_types.append(iterator_ops.Iterator)
        # EagerIterator only existed in tensorflow < 2; take it if present,
        # otherwise fall back to IteratorV2 (at most one of the two is added).
        for attr_name in ("EagerIterator", "IteratorV2"):
            if hasattr(iterator_ops, attr_name):
                candidate_types.append(getattr(iterator_ops, attr_name))
                break

    if hasattr(data, "next") or hasattr(data, "__next__"):
        return True
    return isinstance(data, tuple(candidate_types))
74
+
75
+
76
def patch_tf_keras():  # noqa: C901
    """Monkey-patch Keras training entry points to hand validation data to WandbCallback.

    Wraps ``fit_loop`` (training_arrays), ``fit_generator`` (training_generator)
    and, when available, ``training_v2.Loop.fit`` or ``training.Model.fit``.
    Each wrapper passes its validation data to every ``WandbCallback`` in
    ``callbacks`` before delegating to the original function. Each patch
    target is recorded in ``wandb.patched["keras"]``.

    If the required engine modules cannot be imported, it logs an error and
    returns without patching.
    """
    from packaging.version import parse
    from tensorflow.python.eager import context

    # TF 2.6 <= version < 2.13: the engine modules come from the standalone
    # ``keras`` package; otherwise from ``tensorflow.python.keras``.
    if parse("2.6.0") <= parse(tf.__version__) < parse("2.13.0"):
        keras_engine = "keras.engine"
        try:
            from keras.engine import training
            from keras.engine import training_arrays_v1 as training_arrays
            from keras.engine import training_generator_v1 as training_generator
        except (ImportError, AttributeError):
            wandb.termerror("Unable to patch Tensorflow/Keras")
            logger.exception("exception while trying to patch_tf_keras")
            return
    else:
        keras_engine = "tensorflow.python.keras.engine"

        # NOTE(review): unlike the imports below, this import is not inside a
        # try/except, so an ImportError here propagates to the caller.
        from tensorflow.python.keras.engine import training

        try:
            from tensorflow.python.keras.engine import (
                training_arrays_v1 as training_arrays,
            )
            from tensorflow.python.keras.engine import (
                training_generator_v1 as training_generator,
            )
        except (ImportError, AttributeError):
            # Fall back to the non-"_v1" module names.
            try:
                from tensorflow.python.keras.engine import (
                    training_arrays,
                    training_generator,
                )
            except (ImportError, AttributeError):
                wandb.termerror("Unable to patch Tensorflow/Keras")
                logger.exception("exception while trying to patch_tf_keras")
                return

    # Tensorflow 2.1
    training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")
    # Tensorflow 2.2
    training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")

    # old_v2 is only bound when one of the two modules exists; new_v2 below is
    # only installed under the same conditions, so it never sees it unbound.
    if training_v2_1:
        old_v2 = training_v2_1.Loop.fit
    elif training_v2_2:
        old_v2 = training.Model.fit

    old_arrays = training_arrays.fit_loop
    old_generator = training_generator.fit_generator

    def set_wandb_attrs(cbk, val_data):
        """Attach ``val_data`` to ``cbk`` if it is a WandbCallback.

        Generator-like data and (eager-mode) datasets go to ``cbk.generator``.
        A tuple whose first item is a ``tf.Tensor`` is wrapped in a generator
        that evaluates it in the Keras session. Anything else goes to
        ``cbk.validation_data``.
        """
        if isinstance(cbk, WandbCallback):
            if is_generator_like(val_data):
                cbk.generator = val_data
            elif is_dataset(val_data):
                if context.executing_eagerly():
                    cbk.generator = iter(val_data)
                else:
                    wandb.termwarn(
                        "Found a validation dataset in graph mode, can't patch Keras."
                    )
            elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
                # Graph mode dataset generator
                def gen():
                    # Re-run the validation tensors in the session on every
                    # iteration; this generator never terminates on its own.
                    while True:
                        yield K.get_session().run(val_data)

                cbk.generator = gen()
            else:
                cbk.validation_data = val_data

    def new_arrays(*args, **kwargs):
        """Replacement for ``fit_loop``: forwards (val_inputs[0], val_targets[0])."""
        cbks = kwargs.get("callbacks", [])
        val_inputs = kwargs.get("val_inputs")
        val_targets = kwargs.get("val_targets")
        # TODO: these could be generators, why index 0?
        if val_inputs and val_targets:
            for cbk in cbks:
                set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
        return old_arrays(*args, **kwargs)

    def new_generator(*args, **kwargs):
        """Replacement for ``fit_generator``: forwards ``validation_data``."""
        cbks = kwargs.get("callbacks", [])
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_generator(*args, **kwargs)

    def new_v2(*args, **kwargs):
        """Replacement for ``Loop.fit`` / ``Model.fit``: forwards ``validation_data``."""
        cbks = kwargs.get("callbacks", [])
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_v2(*args, **kwargs)

    # Keep the originals reachable on the modules, install the wrappers, and
    # record the patch targets in wandb.patched.
    training_arrays.orig_fit_loop = old_arrays
    training_arrays.fit_loop = new_arrays
    training_generator.orig_fit_generator = old_generator
    training_generator.fit_generator = new_generator
    wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
    wandb.patched["keras"].append(
        [f"{keras_engine}.training_generator", "fit_generator"]
    )

    if training_v2_1:
        training_v2_1.Loop.fit = new_v2
        wandb.patched["keras"].append(
            ["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
        )
    elif training_v2_2:
        training.Model.fit = new_v2
        wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])
190
+
191
+
192
+ def _array_has_dtype(array):
193
+ return hasattr(array, "dtype")
194
+
195
+
196
def _update_if_numeric(metrics, key, values):
    """Store ``wandb.Histogram(values)`` under ``metrics[key]`` when values are numeric.

    If ``values`` has no ``dtype``, or its dtype is not numeric, a warning is
    emitted for ``key`` and ``metrics`` is left unchanged.
    """
    if _array_has_dtype(values):
        if is_numeric_array(values):
            metrics[key] = wandb.Histogram(values)
        else:
            _warn_not_logging_non_numeric(key)
    else:
        _warn_not_logging(key)
206
+
207
+
208
def is_numeric_array(array):
    """Return True when ``array.dtype`` is a NumPy numeric type (integer, float, complex).

    Boolean, string and object dtypes are reported as non-numeric.
    """
    element_type = array.dtype
    return np.issubdtype(element_type, np.number)
210
+
211
+
212
def _warn_not_logging_non_numeric(name):
    """Warn (with ``repeat=False``) that layer ``name`` is skipped for non-numeric values."""
    message = f"Non-numeric values found in layer: {name}, not logging this layer"
    wandb.termwarn(message, repeat=False)
217
+
218
+
219
def _warn_not_logging(name):
    """Warn (with ``repeat=False``) that layer ``name`` is skipped: its dtype is unknown."""
    message = f"Layer {name} has undetermined datatype not logging this layer"
    wandb.termwarn(message, repeat=False)
224
+
225
+
226
# Module-level handle to TensorFlow's logger (from tf.get_logger()).
tf_logger = tf.get_logger()

# Install the Keras monkey-patches as a side effect of importing this module.
patch_tf_keras()
229
+
230
+
231
+ ### For gradient logging ###
232
+
233
+
234
def _get_custom_optimizer_parent_class():
    """Return the optimizer base class for the gradient-capturing optimizer.

    On TensorFlow older than 2.9.0 this is ``tf.keras.optimizers.Optimizer``;
    from 2.9.0 on it is ``tf.keras.optimizers.legacy.Optimizer``.
    """
    from packaging.version import parse

    if parse(tf.__version__) < parse("2.9.0"):
        return tf.keras.optimizers.Optimizer
    return tf.keras.optimizers.legacy.Optimizer
243
+
244
+
245
# Resolved once at import time; used as the base class of _CustomOptimizer.
_custom_optimizer_parent_class = _get_custom_optimizer_parent_class()
246
+
247
+
248
class _CustomOptimizer(_custom_optimizer_parent_class):
    """Optimizer that writes each dense gradient into its variable instead of stepping.

    Intended to be used with ``_GradAccumulatorCallback``, which reads the
    trainable weights after each batch and then restores the saved weights.
    Sparse updates are a no-op, so variables updated sparsely keep their values.
    """

    def __init__(self):
        super().__init__(name="CustomOptimizer")
        # Wrap the apply methods in tf.function after the base-class __init__ ran.
        # NOTE(review): the base optimizer may inspect these methods' signatures
        # during __init__; confirm before changing their parameter lists.
        self._resource_apply_dense = tf.function(self._resource_apply_dense)
        self._resource_apply_sparse = tf.function(self._resource_apply_sparse)

    def _resource_apply_dense(self, grad, var):
        # Overwrite the variable with its gradient so it can be read back later.
        var.assign(grad)

    # this needs to be implemented to prevent a NotImplementedError when
    # using Lookup layers.
    def _resource_apply_sparse(self, grad, var, indices):
        # Intentionally does nothing: sparse gradients are not captured.
        pass

    def get_config(self):
        # No extra configuration beyond the base optimizer's.
        return super().get_config()
264
+
265
+
266
class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
    """Accumulates gradients during a fit() call when used in conjunction with the CustomOptimizer above.

    After every batch the current values of the model's trainable weights
    (which the paired optimizer has overwritten with gradients) are added to
    running totals. All weights are then reset to the snapshot taken in
    ``set_model``.
    """

    def set_model(self, model):
        super().set_model(model)
        # One snapshot, restored after every batch so gradient writes never persist.
        self.og_weights = model.get_weights()
        # One float accumulator per trainable weight, matching its shape.
        self.grads = [np.zeros(tuple(var.shape)) for var in model.trainable_weights]

    def on_batch_end(self, batch, logs=None):
        for running_total, var in zip(self.grads, self.model.trainable_weights):
            np.add(running_total, var.numpy(), out=running_total)
        self.model.set_weights(self.og_weights)

    def get_grads(self):
        # Hand out copies so callers cannot mutate the running totals.
        return list(map(np.copy, self.grads))
281
+
282
+
283
+ ###
284
+
285
+
286
+ class WandbCallback(tf.keras.callbacks.Callback):
287
+ """`WandbCallback` automatically integrates keras with wandb.
288
+
289
+ Example:
290
+ ```python
291
+ model.fit(
292
+ X_train,
293
+ y_train,
294
+ validation_data=(X_test, y_test),
295
+ callbacks=[WandbCallback()],
296
+ )
297
+ ```
298
+
299
+ `WandbCallback` will automatically log history data from any
300
+ metrics collected by keras: loss and anything passed into `keras_model.compile()`.
301
+
302
+ `WandbCallback` will set summary metrics for the run associated with the "best" training
303
+ step, where "best" is defined by the `monitor` and `mode` attributes. This defaults
304
+ to the epoch with the minimum `val_loss`. `WandbCallback` will by default save the model
305
+ associated with the best `epoch`.
306
+
307
+ `WandbCallback` can optionally log gradient and parameter histograms.
308
+
309
+ `WandbCallback` can optionally save training and validation data for wandb to visualize.
310
+
311
+ Args:
312
+ monitor: (str) name of metric to monitor. Defaults to `val_loss`.
313
+ mode: (str) one of {`auto`, `min`, `max`}.
314
+ `min` - save model when monitor is minimized
315
+ `max` - save model when monitor is maximized
316
+ `auto` - try to guess when to save the model (default).
317
+ save_model:
318
+ True - save a model when monitor beats all previous epochs
319
+ False - don't save models
320
+ save_graph: (boolean) if True save model graph to wandb (default to True).
321
+ save_weights_only: (boolean) if True, then only the model's weights will be
322
+ saved (`model.save_weights(filepath)`), else the full model
323
+ is saved (`model.save(filepath)`).
324
+ log_weights: (boolean) if True save histograms of the model's layer's weights.
325
+ log_gradients: (boolean) if True log histograms of the training gradients
326
+ training_data: (tuple) Same format `(X,y)` as passed to `model.fit`. This is needed
327
+ for calculating gradients - this is mandatory if `log_gradients` is `True`.
328
+ validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`. A set of data
329
+ for wandb to visualize. If this is set, every epoch, wandb will
330
+ make a small number of predictions and save the results for later visualization. In case
331
+ you are working with image data, please also set `input_type` and `output_type` in order
332
+ to log correctly.
333
+ generator: (generator) a generator that returns validation data for wandb to visualize. This
334
+ generator should return tuples `(X,y)`. Either `validate_data` or generator should
335
+ be set for wandb to visualize specific data examples. In case you are working with image data,
336
+ please also set `input_type` and `output_type` in order to log correctly.
337
+ validation_steps: (int) if `validation_data` is a generator, how many
338
+ steps to run the generator for the full validation set.
339
+ labels: (list) If you are visualizing your data with wandb this list of labels
340
+ will convert numeric output to understandable string if you are building a
341
+ multiclass classifier. If you are making a binary classifier you can pass in
342
+ a list of two labels ["label for false", "label for true"]. If `validation_data`
343
+ and generator are both false, this won't do anything.
344
+ predictions: (int) the number of predictions to make for visualization each epoch, max
345
+ is 100.
346
+ input_type: (string) type of the model input to help visualization. can be one of:
347
+ (`image`, `images`, `segmentation_mask`, `auto`).
348
+ output_type: (string) type of the model output to help visualization. can be one of:
349
+ (`image`, `images`, `segmentation_mask`, `label`).
350
+ log_evaluation: (boolean) if True, save a Table containing validation data and the
351
+ model's predictions at each epoch. See `validation_indexes`,
352
+ `validation_row_processor`, and `prediction_row_processor` for additional details.
353
+ class_colors: ([float, float, float]) if the input or output is a segmentation mask,
354
+ an array containing an rgb tuple (range 0-1) for each class.
355
+ log_batch_frequency: (integer) if None, callback will log every epoch.
356
+ If set to integer, callback will log training metrics every `log_batch_frequency`
357
+ batches.
358
+ log_best_prefix: (string) if None, no extra summary metrics will be saved.
359
+ If set to a string, the monitored metric and epoch will be prepended with this value
360
+ and stored as summary metrics.
361
+ validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate
362
+ with each validation example. If log_evaluation is True and `validation_indexes` is provided,
363
+ then a Table of validation data will not be created and instead each prediction will
364
+ be associated with the row represented by the `TableLinkMixin`. The most common way to obtain
365
+ such keys is to use `Table.get_index()`, which returns a list of row keys.
366
+ validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data.
367
+ The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input,
368
+ then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the
369
+ input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else,
370
+ it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray,
371
+ but you wish to visualize the data as an Image, then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}`
372
+ as the processor. Ignored if log_evaluation is False or `validation_indexes` are present.
373
+ prediction_row_processor: (Callable) same as `validation_row_processor`, but applied to the model's output. `row["output"]` will contain
374
+ the results of the model output.
375
+ infer_missing_processors: (bool) Determines if `validation_row_processor` and `prediction_row_processor`
376
+ should be inferred if missing. Defaults to True. If `labels` are provided, we will attempt to infer classification-type
377
+ processors where appropriate.
378
+ log_evaluation_frequency: (int) Determines the frequency at which evaluation results will be logged. Default 0 (only at the end of training).
379
+ Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False.
380
+ compute_flops: (bool) Compute the FLOPs of your Keras Sequential or Functional model in GigaFLOPs unit.
381
+ """
382
+
383
def __init__(
    self,
    monitor="val_loss",
    verbose=0,
    mode="auto",
    save_weights_only=False,
    log_weights=False,
    log_gradients=False,
    save_model=True,
    training_data=None,
    validation_data=None,
    labels=None,
    predictions=36,
    generator=None,
    input_type=None,
    output_type=None,
    log_evaluation=False,
    validation_steps=None,
    class_colors=None,
    log_batch_frequency=None,
    log_best_prefix="best_",
    save_graph=True,
    validation_indexes=None,
    validation_row_processor=None,
    prediction_row_processor=None,
    infer_missing_processors=True,
    log_evaluation_frequency=0,
    compute_flops=False,
    **kwargs,
):
    """Set up the callback; see the class docstring for argument details.

    Raises:
        wandb.Error: if ``wandb.init()`` has not been called yet.
        Exception: if ``log_gradients`` is requested on TensorFlow < 2.
        ValueError: if ``log_gradients`` is requested without ``training_data``,
            or with list/tuple ``training_data`` that is not of length two.
    """
    if wandb.run is None:
        raise wandb.Error("You must call wandb.init() before WandbCallback()")

    deprecate(
        field_name=Deprecated.keras_callback,
        warning_message=(
            "WandbCallback is deprecated and will be removed in a future release. "
            "Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
            "callbacks instead. "
            "See https://docs.wandb.ai/guides/integrations/keras for more information."
        ),
    )

    with wandb.wandb_lib.telemetry.context(run=wandb.run) as tel:
        tel.feature.keras = True
    self.validation_data = None
    # This is kept around for legacy reasons: a generator passed as
    # validation_data is treated as the `generator` argument.
    if validation_data is not None:
        if is_generator_like(validation_data):
            generator = validation_data
        else:
            self.validation_data = validation_data
    if labels is None:
        labels = []
    self.labels = labels
    # At most 100 predictions are visualized per epoch.
    self.predictions = min(predictions, 100)

    self.monitor = monitor
    self.verbose = verbose
    self.save_weights_only = save_weights_only
    self.save_graph = save_graph

    wandb.save("model-best.h5")
    self.filepath = os.path.join(wandb.run.dir, "model-best.h5")
    self.save_model = save_model
    if save_model:
        deprecate(
            field_name=Deprecated.keras_callback__save_model,
            warning_message=(
                "The save_model argument by default saves the model in the HDF5 format that cannot save "
                "custom objects like subclassed models and custom layers. This behavior will be deprecated "
                "in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
                "as W&B files and the SavedModel as W&B Artifacts."
            ),
        )

    self.save_model_as_artifact = True
    self.log_weights = log_weights
    self.log_gradients = log_gradients
    self.training_data = training_data
    self.generator = generator
    self._graph_rendered = False

    data_type = kwargs.get("data_type", None)
    if data_type is not None:
        deprecate(
            field_name=Deprecated.keras_callback__data_type,
            warning_message=(
                "The data_type argument of wandb.keras.WandbCallback is deprecated "
                "and will be removed in a future release. Please use input_type instead.\n"
                "Setting input_type = data_type."
            ),
        )
        input_type = data_type
    self.input_type = input_type
    self.output_type = output_type
    self.log_evaluation = log_evaluation
    self.validation_steps = validation_steps
    self.class_colors = np.array(class_colors) if class_colors is not None else None
    self.log_batch_frequency = log_batch_frequency
    self.log_best_prefix = log_best_prefix
    self.compute_flops = compute_flops

    self._prediction_batch_size = None

    if self.log_gradients:
        if int(tf.__version__.split(".")[0]) < 2:
            raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
        if self.training_data is None:
            raise ValueError(
                "training_data argument is required for gradient logging."
            )
        if isinstance(self.training_data, (list, tuple)):
            if len(self.training_data) != 2:
                raise ValueError("training data must be a tuple of length two")
            self._training_data_x, self._training_data_y = self.training_data
        else:
            # generator, tf.data.Dataset etc
            self._training_data_x = self.training_data
            self._training_data_y = None

    # From Keras
    if mode not in ["auto", "min", "max"]:
        wandb.termwarn(
            f"WandbCallback mode {mode} is unknown, fallback to auto mode."
        )
        mode = "auto"

    if mode == "min":
        self.monitor_op = operator.lt
        self.best = float("inf")
    elif mode == "max":
        self.monitor_op = operator.gt
        self.best = float("-inf")
    else:
        # "auto": accuracy-like metrics are maximized, everything else minimized.
        if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
            self.monitor_op = operator.gt
            self.best = float("-inf")
        else:
            self.monitor_op = operator.lt
            self.best = float("inf")

    # Get the previous best metric for resumed runs. The best value is only
    # ever written under a non-empty prefix (see the `if self.log_best_prefix`
    # gate in on_epoch_end), so without one there is nothing to restore; looking
    # up f"{None}{monitor}" or the bare monitor key (which holds the *latest*
    # value) would be wrong.
    if self.log_best_prefix:
        previous_best = wandb.run.summary.get(f"{self.log_best_prefix}{self.monitor}")
        if previous_best is not None:
            self.best = previous_best

    self._validation_data_logger = None
    self._validation_indexes = validation_indexes
    self._validation_row_processor = validation_row_processor
    self._prediction_row_processor = prediction_row_processor
    self._infer_missing_processors = infer_missing_processors
    self._log_evaluation_frequency = log_evaluation_frequency
    self._model_trained_since_last_eval = False
537
+
538
def _build_grad_accumulator_model(self):
    """Build and compile the helper model used to compute gradients.

    The helper is a functional ``tf.keras`` model whose outputs come from calling
    ``self.model`` on its own inputs. It is compiled with the user model's loss
    and a ``_CustomOptimizer``, flagged as wandb-internal, and paired with a
    fresh ``_GradAccumulatorCallback``.
    """
    model_inputs = self.model.inputs
    helper = tf.keras.models.Model(model_inputs, self.model(model_inputs))
    helper.compile(loss=self.model.loss, optimizer=_CustomOptimizer())

    # Flag the helper so wandb's magic integration does not treat it as a user model.
    helper._wandb_internal_model = True

    self._grad_accumulator_model = helper
    self._grad_accumulator_callback = _GradAccumulatorCallback()
549
+
550
+ def _implements_train_batch_hooks(self):
551
+ return self.log_batch_frequency is not None
552
+
553
+ def _implements_test_batch_hooks(self):
554
+ return self.log_batch_frequency is not None
555
+
556
+ def _implements_predict_batch_hooks(self):
557
+ return self.log_batch_frequency is not None
558
+
559
def set_params(self, params):
    """Remember the training parameters Keras passes to the callback as ``self.params``."""
    setattr(self, "params", params)
561
+
562
def set_model(self, model):
    """Attach ``model`` and infer missing input/output data types where possible.

    An ``input_type`` of ``"auto"`` is resolved from the shape of the model's only
    input. ``output_type`` is guessed from the model's only output, but only when
    an input type is known and no output type was given. When gradient logging
    is enabled, the gradient accumulator model is built as well.
    """
    super().set_model(model)

    # `len(model.inputs)` is deliberately evaluated only for "auto".
    if self.input_type == "auto" and len(model.inputs) == 1:
        input_shape = model.inputs[0].shape
        self.input_type = wandb.util.guess_data_type(input_shape, risky=True)

    should_guess_output = (
        self.input_type and self.output_type is None and len(model.outputs) == 1
    )
    if should_guess_output:
        output_shape = model.outputs[0].shape
        self.output_type = wandb.util.guess_data_type(output_shape)

    if self.log_gradients:
        self._build_grad_accumulator_model()
572
+
573
+ def _attempt_evaluation_log(self, commit=True):
574
+ if self.log_evaluation and self._validation_data_logger:
575
+ try:
576
+ if not self.model:
577
+ wandb.termwarn("WandbCallback unable to read model from trainer")
578
+ else:
579
+ self._validation_data_logger.log_predictions(
580
+ predictions=self._validation_data_logger.make_predictions(
581
+ self.model.predict
582
+ ),
583
+ commit=commit,
584
+ )
585
+ self._model_trained_since_last_eval = False
586
+ except Exception as e:
587
+ wandb.termwarn("Error during prediction logging for epoch: " + str(e))
588
+
589
def on_epoch_end(self, epoch, logs=None):
    """Log epoch metrics and media, and track/save the best monitored value.

    Optionally logs weight/gradient histograms and example images. Every
    ``_log_evaluation_frequency`` epochs it also logs evaluation predictions.
    It then logs ``epoch`` and ``logs`` (committing the step). When
    ``logs[self.monitor]`` improves on ``self.best``, it updates the best-* summary
    entries and saves the model (file and artifact) if ``save_model`` is set.
    """
    if logs is None:
        logs = {}
    if self.log_weights:
        wandb.log(self._log_weights(), commit=False)

    if self.log_gradients:
        wandb.log(self._log_gradients(), commit=False)

    image_types = ("image", "images", "segmentation_mask")
    if self.input_type in image_types or self.output_type in image_types:
        if self.generator:
            self.validation_data = next(self.generator)
        if self.validation_data is None:
            wandb.termwarn(
                "No validation_data set, pass a generator to the callback."
            )
        elif self.validation_data and len(self.validation_data) > 0:
            wandb.log(
                {"examples": self._log_images(num_images=self.predictions)},
                commit=False,
            )

    if (
        self._log_evaluation_frequency > 0
        and epoch % self._log_evaluation_frequency == 0
    ):
        self._attempt_evaluation_log(commit=False)

    wandb.log({"epoch": epoch}, commit=False)
    wandb.log(logs, commit=True)

    self.current = logs.get(self.monitor)
    # Compare against None explicitly: a metric value of 0.0 (e.g. a perfect
    # loss, or a first-epoch accuracy of 0 in "max" mode) is a real value that
    # may be an improvement, but it is falsy.
    if self.current is not None and self.monitor_op(self.current, self.best):
        if self.log_best_prefix:
            wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
                self.current
            )
            wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
            if self.verbose and not self.save_model:
                wandb.termlog(
                    f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
                )
        if self.save_model:
            self._save_model(epoch)

        if self.save_model and self.save_model_as_artifact:
            self._save_model_as_artifact(epoch)

        self.best = self.current
642
+
643
# Hook name used by standalone Keras before tf.keras.
def on_batch_begin(self, batch, logs=None):
    """No-op: nothing is recorded at the start of a legacy Keras batch."""
646
+
647
# This is what keras used pre tensorflow.keras
def on_batch_end(self, batch, logs=None):
    """Legacy (pre-tf.keras) batch-end hook; behaves like ``on_train_batch_end``."""
    self._render_graph_once()
    self._log_batch_if_due(batch, logs)

def on_train_batch_begin(self, batch, logs=None):
    """Record that the model has been trained since the last evaluation log."""
    self._model_trained_since_last_eval = True

def on_train_batch_end(self, batch, logs=None):
    """Render the model graph once and log batch metrics at the configured frequency."""
    self._render_graph_once()
    self._log_batch_if_due(batch, logs)

def _render_graph_once(self):
    """Store the model graph in the run summary the first time this is called.

    This cannot be done in ``on_train_begin`` because Keras may not have built
    the model yet at that point.
    """
    if self.save_graph and not self._graph_rendered:
        wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
        self._graph_rendered = True

def _log_batch_if_due(self, batch, logs):
    """Log ``logs`` (committing a step) when ``batch`` is a multiple of the frequency.

    Nothing is logged when batch logging is off (frequency unset or 0) or when
    no ``logs`` were passed; ``wandb.log`` would reject ``None``.
    """
    if logs is None:
        return
    if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
        wandb.log(logs, commit=True)
668
+
669
def on_test_begin(self, logs=None):
    """No-op: this callback records nothing when evaluation starts."""

def on_test_end(self, logs=None):
    """No-op: this callback records nothing when evaluation ends."""

def on_test_batch_begin(self, batch, logs=None):
    """No-op: evaluation batches are not tracked."""

def on_test_batch_end(self, batch, logs=None):
    """No-op: evaluation batches are not tracked."""
680
+
681
def on_train_begin(self, logs=None):
    """Prepare evaluation logging and optionally record the model's GFLOPs.

    With ``log_evaluation`` on, a ``ValidationDataLogger`` is created from the
    available validation data. Any error there is reported and evaluation
    logging is skipped. With ``compute_flops`` on (and supported), ``GFLOPs`` is
    written to the run summary. Failures are logged and warned about, never
    raised.
    """
    if self.log_evaluation:
        try:
            validation_data = self._gather_validation_data()
            if validation_data:
                self._validation_data_logger = ValidationDataLogger(
                    inputs=validation_data[0],
                    targets=validation_data[1],
                    indexes=self._validation_indexes,
                    validation_row_processor=self._validation_row_processor,
                    prediction_row_processor=self._prediction_row_processor,
                    class_labels=self.labels,
                    infer_missing_processors=self._infer_missing_processors,
                )
        except Exception as e:
            wandb.termwarn(
                "Error initializing ValidationDataLogger in WandbCallback. "
                f"Skipping logging validation data. Error: {str(e)}"
            )

    if not (self.compute_flops and _can_compute_flops()):
        return
    try:
        wandb.summary["GFLOPs"] = self.get_flops()
    except Exception:
        logger.exception("Error computing FLOPs")
        wandb.termwarn("Unable to compute FLOPs for this model.")

def _gather_validation_data(self):
    """Return ``(inputs, targets)`` for evaluation logging, or None if unavailable.

    Explicit ``validation_data`` wins. Otherwise ``validation_steps`` batches are
    drawn from ``self.generator`` and concatenated along axis 0. A warning is
    emitted when neither source is usable.
    """
    if self.validation_data:
        return self.validation_data
    if not self.generator:
        wandb.termwarn(
            "WandbCallback is unable to read validation_data from trainer "
            "and therefore cannot log validation data. Ensure Keras is properly "
            "patched by calling `from wandb.keras import WandbCallback` at the top of your script."
        )
        return None
    if not self.validation_steps:
        wandb.termwarn(
            "WandbCallback is unable to log validation data. "
            "When using a generator for validation_data, you must pass validation_steps"
        )
        return None

    inputs = targets = None
    for _ in range(self.validation_steps):
        batch_inputs, batch_targets = next(self.generator)
        if inputs is None:
            inputs, targets = batch_inputs, batch_targets
        else:
            inputs = np.append(inputs, batch_inputs, axis=0)
            targets = np.append(targets, batch_targets, axis=0)
    return inputs, targets
734
+
735
def on_train_end(self, logs=None):
    """Log one last evaluation if the model was trained since the previous one."""
    if not self._model_trained_since_last_eval:
        return
    self._attempt_evaluation_log()
738
+
739
def on_predict_begin(self, logs=None):
    """No-op: this callback records nothing when prediction starts."""

def on_predict_end(self, logs=None):
    """No-op: this callback records nothing when prediction ends."""

def on_predict_batch_begin(self, batch, logs=None):
    """No-op: prediction batches are not tracked."""

def on_predict_batch_end(self, batch, logs=None):
    """No-op: prediction batches are not tracked."""
750
+
751
+ def _logits_to_captions(self, logits):
752
+ if logits[0].shape[-1] == 1:
753
+ # Scalar output from the model
754
+ # TODO: handle validation_y
755
+ if len(self.labels) == 2:
756
+ # User has named true and false
757
+ captions = [
758
+ self.labels[1] if logits[0] > 0.5 else self.labels[0]
759
+ for logit in logits
760
+ ]
761
+ else:
762
+ if len(self.labels) != 0:
763
+ wandb.termwarn(
764
+ "keras model is producing a single output, "
765
+ 'so labels should be a length two array: ["False label", "True label"].'
766
+ )
767
+ captions = [logit[0] for logit in logits]
768
+ else:
769
+ # Vector output from the model
770
+ # TODO: handle validation_y
771
+ labels = np.argmax(np.stack(logits), axis=1)
772
+
773
+ if len(self.labels) > 0:
774
+ # User has named the categories in self.labels
775
+ captions = []
776
+ for label in labels:
777
+ try:
778
+ captions.append(self.labels[label])
779
+ except IndexError:
780
+ captions.append(label)
781
+ else:
782
+ captions = labels
783
+ return captions
784
+
785
+ def _masks_to_pixels(self, masks):
786
+ # if its a binary mask, just return it as grayscale instead of picking the argmax
787
+ if len(masks[0].shape) == 2 or masks[0].shape[-1] == 1:
788
+ return masks
789
+ class_colors = (
790
+ self.class_colors
791
+ if self.class_colors is not None
792
+ else np.array(wandb.util.class_colors(masks[0].shape[2]))
793
+ )
794
+ imgs = class_colors[np.argmax(masks, axis=-1)]
795
+ return imgs
796
+
797
def _log_images(self, num_images=36):
    """Sample validation examples, predict on them, and wrap them as ``wandb.Image``.

    Up to ``num_images`` examples are taken from ``self.validation_data``,
    chosen at random without replacement when more are available. The result
    depends on ``self.input_type`` / ``self.output_type``:

    * label input, image output: interleaved (prediction, reference) images,
      captioned from the input labels;
    * image input, label output: input images captioned with the predicted label;
    * image input, image output: interleaved (input, prediction, reference);
    * image input, any other output: the input images only;
    * other input, image output: interleaved (prediction, reference).

    Segmentation masks are colorized with ``_masks_to_pixels`` wherever they
    appear. ``None`` is returned for any other type combination.
    """
    image_types = ("image", "images", "segmentation_mask")

    def as_pixels(data, data_type):
        # Segmentation masks are colorized; other image data is used as-is.
        if data_type == "segmentation_mask":
            return self._masks_to_pixels(data)
        return data

    validation_X = self.validation_data[0]  # noqa: N806
    validation_y = self.validation_data[1]

    validation_length = len(validation_X)
    if validation_length > num_images:
        # pick some data at random
        indices = np.random.choice(validation_length, num_images, replace=False)
    else:
        indices = range(validation_length)

    test_data = [validation_X[i] for i in indices]
    test_output = [validation_y[i] for i in indices]

    if self.model.stateful:
        predictions = self.model.predict(np.stack(test_data), batch_size=1)
        self.model.reset_states()
    else:
        predictions = self.model.predict(
            np.stack(test_data), batch_size=self._prediction_batch_size
        )
        if len(predictions) != len(test_data):
            # Output count did not match the inputs at the current batch size;
            # fall back to (and remember) batch size 1.
            self._prediction_batch_size = 1
            predictions = self.model.predict(
                np.stack(test_data), batch_size=self._prediction_batch_size
            )

    output_is_image = self.output_type in image_types

    if self.input_type == "label":
        if not output_is_image:
            return None
        captions = self._logits_to_captions(test_data)
        output_images = [
            wandb.Image(data, caption=captions[i], grouping=2)
            for i, data in enumerate(as_pixels(predictions, self.output_type))
        ]
        reference_images = [
            wandb.Image(data, caption=captions[i])
            for i, data in enumerate(as_pixels(test_output, self.output_type))
        ]
        return list(chain.from_iterable(zip(output_images, reference_images)))

    if self.input_type in image_types:
        # Use the displayable (colorized, for segmentation masks) inputs in every
        # branch below; previously two branches fell back to raw `test_data`.
        input_image_data = as_pixels(test_data, self.input_type)
        if self.output_type == "label":
            # we just use the predicted label as the caption for now
            captions = self._logits_to_captions(predictions)
            return [
                wandb.Image(data, caption=captions[i])
                for i, data in enumerate(input_image_data)
            ]
        if output_is_image:
            input_images = [wandb.Image(data, grouping=3) for data in input_image_data]
            output_images = [
                wandb.Image(data)
                for data in as_pixels(predictions, self.output_type)
            ]
            reference_images = [
                wandb.Image(data)
                for data in as_pixels(test_output, self.output_type)
            ]
            return list(
                chain.from_iterable(
                    zip(input_images, output_images, reference_images)
                )
            )
        # unknown output, just log the input images
        return [wandb.Image(img) for img in input_image_data]

    if output_is_image:
        # unknown input, just log the predicted and reference outputs without captions
        output_images = [
            wandb.Image(data, grouping=2)
            for data in as_pixels(predictions, self.output_type)
        ]
        reference_images = [
            wandb.Image(data) for data in as_pixels(test_output, self.output_type)
        ]
        return list(chain.from_iterable(zip(output_images, reference_images)))

    return None
913
+
914
def _log_weights(self):
    """Collect per-layer weight (and bias) values under ``parameters/`` keys.

    Only layers with exactly one or two weight arrays are included. The first
    array is recorded as ``parameters/<layer>.weights`` and the second, when
    present, as ``parameters/<layer>.bias``. Entries are added through
    ``_update_if_numeric``.

    Returns:
        The dict of collected metrics.
    """
    metrics = {}
    for layer in self.model.layers:
        layer_weights = layer.get_weights()
        if len(layer_weights) not in (1, 2):
            continue
        key_prefix = "parameters/" + layer.name
        for suffix, value in zip((".weights", ".bias"), layer_weights):
            _update_if_numeric(metrics, key_prefix + suffix, value)
    return metrics
930
+
931
def _log_gradients(self):
    """Fit the gradient-accumulator model once and histogram its gradients.

    TensorFlow's logger is temporarily raised to ERROR to silence callback
    warnings from the helper fit. The previous level is restored even if
    fitting fails.

    Returns:
        dict mapping ``gradients/<weight name without ":N">.gradient`` to a
        ``wandb.Histogram`` of the accumulated gradient. Weights are paired
        positionally with ``self._grad_accumulator_callback.grads``.
    """
    # Suppress callback warnings grad accumulator
    og_level = tf_logger.level
    tf_logger.setLevel("ERROR")
    try:
        self._grad_accumulator_model.fit(
            self._training_data_x,
            self._training_data_y,
            verbose=0,
            callbacks=[self._grad_accumulator_callback],
        )
    finally:
        # Without this, an exception in fit() would leave TensorFlow
        # warnings silenced for the rest of the process.
        tf_logger.setLevel(og_level)

    weights = self.model.trainable_weights
    grads = self._grad_accumulator_callback.grads
    return {
        "gradients/" + weight.name.split(":")[0] + ".gradient": wandb.Histogram(grad)
        for weight, grad in zip(weights, grads)
    }
951
+
952
def _log_dataframe(self):
    """Build a wandb dataframe from validation inputs, targets and predictions.

    Data comes from ``self.validation_data`` or, failing that, from
    ``self.validation_steps`` batches of ``self.generator`` (concatenated along
    axis 0). The generator path returns None with a warning when
    ``validation_steps`` is unset.

    Returns:
        ``wandb.image_categorizer_dataframe(...)`` for image inputs with label
        outputs, ``wandb.image_segmentation_dataframe(...)`` for image inputs
        with segmentation-mask outputs, otherwise None (with a warning).
    """
    x = y_true = y_pred = None

    if self.validation_data:
        x, y_true = self.validation_data[0], self.validation_data[1]
        y_pred = self.model.predict(x)
    elif self.generator:
        if not self.validation_steps:
            wandb.termwarn(
                "when using a generator for validation data with dataframes, "
                "you must pass validation_steps. skipping"
            )
            return None

        for _ in range(self.validation_steps):
            batch_x, batch_true = next(self.generator)
            batch_pred = self.model.predict(batch_x)
            if x is None:
                x, y_true, y_pred = batch_x, batch_true, batch_pred
            else:
                x = np.append(x, batch_x, axis=0)
                y_true = np.append(y_true, batch_true, axis=0)
                y_pred = np.append(y_pred, batch_pred, axis=0)

    has_image_input = self.input_type in ("image", "images")
    if has_image_input and self.output_type == "label":
        return wandb.image_categorizer_dataframe(
            x=x, y_true=y_true, y_pred=y_pred, labels=self.labels
        )
    if has_image_input and self.output_type == "segmentation_mask":
        return wandb.image_segmentation_dataframe(
            x=x,
            y_true=y_true,
            y_pred=y_pred,
            labels=self.labels,
            class_colors=self.class_colors,
        )

    wandb.termwarn(
        f"unknown dataframe type for input_type={self.input_type} and output_type={self.output_type}"
    )
    return None
998
+
999
def _save_model(self, epoch):
    """Save the model (or only its weights) to ``self.filepath``.

    Skipped for disabled runs. Known serialization failures (ImportError,
    RuntimeError, TypeError, AttributeError) are logged and reported with
    ``wandb.termerror`` instead of being raised.

    Args:
        epoch: current epoch, used only in the verbose progress message.
    """
    if wandb.run.disabled:
        return
    if self.verbose > 0:
        wandb.termlog(
            f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
            f"saving model to {self.filepath}"
        )

    try:
        if self.save_weights_only:
            self.model.save_weights(self.filepath, overwrite=True)
        else:
            self.model.save(self.filepath, overwrite=True)
    # Was getting `RuntimeError: Unable to create link` in TF 1.13.1
    # also saw `TypeError: can't pickle _thread.RLock objects`
    except (ImportError, RuntimeError, TypeError, AttributeError):
        logger.exception("Error saving model in the h5py format")
        wandb.termerror(
            "Can't save model in the h5py format. The model will be saved "
            "as a W&B Artifact in the 'tf' format."
        )
1021
+
1022
def _save_model_as_artifact(self, epoch):
    """Save the model in SavedModel format and log it as a ``model`` artifact.

    The artifact is named ``model-<run name>`` (made artifact-safe) and tagged
    ``latest`` and ``epoch_<epoch>``. The temporary SavedModel directory is
    removed afterwards, including when saving or logging fails, so it does not
    linger in the run directory.
    """
    if wandb.run.disabled:
        return

    # `self.filepath` ends in ".h5"; drop the suffix to get the SavedModel dir.
    saved_model_dir = self.filepath[:-3]
    try:
        # Save the model in the SavedModel format.
        # TODO: Replace this manual artifact creation with the `log_model` method
        # after `log_model` is released from beta.
        self.model.save(saved_model_dir, overwrite=True, save_format="tf")

        # Log the model as artifact.
        name = wandb.util.make_artifact_name_safe(f"model-{wandb.run.name}")
        model_artifact = wandb.Artifact(name, type="model")
        model_artifact.add_dir(saved_model_dir)
        wandb.run.log_artifact(model_artifact, aliases=["latest", f"epoch_{epoch}"])
    finally:
        # Remove the SavedModel from wandb dir as we don't want to log it to save
        # memory. Cleanup is best-effort: the directory may not exist if saving
        # failed early, and a cleanup error must not mask the original one.
        shutil.rmtree(saved_model_dir, ignore_errors=True)
1039
+
1040
def get_flops(self) -> float:
    """Calculate FLOPS [GFLOPs] for a tf.keras.Model or tf.keras.Sequential model in inference mode.

    It uses tf.compat.v1.profiler under the hood. The model is traced for a
    single sample (batch size 1), frozen into a constant graph, and profiled
    for float operations.

    Raises:
        wandb.Error: if no model is attached to the callback.
        TypeError: if the attached model is not a ``tf.keras`` Model/Sequential.
    """
    # Treat a missing attribute and a `None` model alike. With `hasattr` alone,
    # a `None` model fell through to the misleading TypeError below.
    if getattr(self, "model", None) is None:
        raise wandb.Error("self.model must be set before using this method.")

    if not isinstance(
        self.model, (tf.keras.models.Sequential, tf.keras.models.Model)
    ):
        raise TypeError(
            "Calculating FLOPS is only supported for "
            "`tf.keras.Model` and `tf.keras.Sequential` instances."
        )

    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2_as_graph,
    )

    # Compute FLOPs for one sample
    batch_size = 1
    inputs = [
        tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype)
        for inp in self.model.inputs
    ]

    # convert tf.keras model into frozen graph to count FLOPs about operations used at inference
    real_model = tf.function(self.model).get_concrete_function(inputs)
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)

    # Calculate FLOPs with tf.profiler
    run_meta = tf.compat.v1.RunMetadata()
    opts = (
        tf.compat.v1.profiler.ProfileOptionBuilder(
            tf.compat.v1.profiler.ProfileOptionBuilder().float_operation()
        )
        .with_empty_output()
        .build()
    )

    flops = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
    )

    # convert to GFLOPs. The profiler total is additionally halved here.
    # NOTE(review): presumably this reports multiply-accumulates; confirm.
    return (flops.total_float_ops / 1e9) / 2