wandb 0.19.1__py3-none-musllinux_1_2_aarch64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (822) hide show
  1. package_readme.md +97 -0
  2. wandb/__init__.py +246 -0
  3. wandb/__init__.pyi +1197 -0
  4. wandb/__main__.py +3 -0
  5. wandb/_globals.py +19 -0
  6. wandb/agents/__init__.py +0 -0
  7. wandb/agents/pyagent.py +363 -0
  8. wandb/analytics/__init__.py +3 -0
  9. wandb/analytics/sentry.py +263 -0
  10. wandb/apis/__init__.py +48 -0
  11. wandb/apis/attrs.py +51 -0
  12. wandb/apis/importers/__init__.py +1 -0
  13. wandb/apis/importers/internals/internal.py +385 -0
  14. wandb/apis/importers/internals/protocols.py +103 -0
  15. wandb/apis/importers/internals/util.py +78 -0
  16. wandb/apis/importers/mlflow.py +254 -0
  17. wandb/apis/importers/validation.py +108 -0
  18. wandb/apis/importers/wandb.py +1603 -0
  19. wandb/apis/internal.py +232 -0
  20. wandb/apis/normalize.py +73 -0
  21. wandb/apis/paginator.py +81 -0
  22. wandb/apis/public/__init__.py +34 -0
  23. wandb/apis/public/api.py +1387 -0
  24. wandb/apis/public/artifacts.py +1095 -0
  25. wandb/apis/public/const.py +4 -0
  26. wandb/apis/public/files.py +263 -0
  27. wandb/apis/public/history.py +149 -0
  28. wandb/apis/public/jobs.py +653 -0
  29. wandb/apis/public/projects.py +154 -0
  30. wandb/apis/public/query_generator.py +166 -0
  31. wandb/apis/public/reports.py +458 -0
  32. wandb/apis/public/runs.py +1012 -0
  33. wandb/apis/public/sweeps.py +240 -0
  34. wandb/apis/public/teams.py +198 -0
  35. wandb/apis/public/users.py +136 -0
  36. wandb/apis/public/utils.py +68 -0
  37. wandb/apis/reports/__init__.py +1 -0
  38. wandb/apis/reports/v1/__init__.py +8 -0
  39. wandb/apis/reports/v2/__init__.py +8 -0
  40. wandb/apis/workspaces/__init__.py +8 -0
  41. wandb/beta/workflows.py +288 -0
  42. wandb/bin/gpu_stats +0 -0
  43. wandb/bin/wandb-core +0 -0
  44. wandb/cli/__init__.py +0 -0
  45. wandb/cli/beta.py +178 -0
  46. wandb/cli/cli.py +2812 -0
  47. wandb/data_types.py +66 -0
  48. wandb/docker/__init__.py +343 -0
  49. wandb/docker/auth.py +435 -0
  50. wandb/docker/wandb-entrypoint.sh +33 -0
  51. wandb/docker/www_authenticate.py +94 -0
  52. wandb/env.py +513 -0
  53. wandb/errors/__init__.py +17 -0
  54. wandb/errors/errors.py +37 -0
  55. wandb/errors/links.py +73 -0
  56. wandb/errors/term.py +415 -0
  57. wandb/errors/util.py +57 -0
  58. wandb/errors/warnings.py +2 -0
  59. wandb/filesync/__init__.py +0 -0
  60. wandb/filesync/dir_watcher.py +403 -0
  61. wandb/filesync/stats.py +100 -0
  62. wandb/filesync/step_checksum.py +142 -0
  63. wandb/filesync/step_prepare.py +179 -0
  64. wandb/filesync/step_upload.py +287 -0
  65. wandb/filesync/upload_job.py +142 -0
  66. wandb/integration/__init__.py +0 -0
  67. wandb/integration/catboost/__init__.py +5 -0
  68. wandb/integration/catboost/catboost.py +178 -0
  69. wandb/integration/cohere/__init__.py +3 -0
  70. wandb/integration/cohere/cohere.py +21 -0
  71. wandb/integration/cohere/resolver.py +347 -0
  72. wandb/integration/diffusers/__init__.py +3 -0
  73. wandb/integration/diffusers/autologger.py +76 -0
  74. wandb/integration/diffusers/pipeline_resolver.py +50 -0
  75. wandb/integration/diffusers/resolvers/__init__.py +9 -0
  76. wandb/integration/diffusers/resolvers/multimodal.py +882 -0
  77. wandb/integration/diffusers/resolvers/utils.py +102 -0
  78. wandb/integration/fastai/__init__.py +245 -0
  79. wandb/integration/gym/__init__.py +99 -0
  80. wandb/integration/huggingface/__init__.py +3 -0
  81. wandb/integration/huggingface/huggingface.py +18 -0
  82. wandb/integration/huggingface/resolver.py +213 -0
  83. wandb/integration/keras/__init__.py +11 -0
  84. wandb/integration/keras/callbacks/__init__.py +5 -0
  85. wandb/integration/keras/callbacks/metrics_logger.py +129 -0
  86. wandb/integration/keras/callbacks/model_checkpoint.py +188 -0
  87. wandb/integration/keras/callbacks/tables_builder.py +228 -0
  88. wandb/integration/keras/keras.py +1089 -0
  89. wandb/integration/kfp/__init__.py +6 -0
  90. wandb/integration/kfp/helpers.py +28 -0
  91. wandb/integration/kfp/kfp_patch.py +334 -0
  92. wandb/integration/kfp/wandb_logging.py +182 -0
  93. wandb/integration/langchain/__init__.py +3 -0
  94. wandb/integration/langchain/wandb_tracer.py +48 -0
  95. wandb/integration/lightgbm/__init__.py +239 -0
  96. wandb/integration/lightning/__init__.py +0 -0
  97. wandb/integration/lightning/fabric/__init__.py +3 -0
  98. wandb/integration/lightning/fabric/logger.py +764 -0
  99. wandb/integration/metaflow/__init__.py +3 -0
  100. wandb/integration/metaflow/metaflow.py +383 -0
  101. wandb/integration/openai/__init__.py +3 -0
  102. wandb/integration/openai/fine_tuning.py +480 -0
  103. wandb/integration/openai/openai.py +22 -0
  104. wandb/integration/openai/resolver.py +240 -0
  105. wandb/integration/prodigy/__init__.py +3 -0
  106. wandb/integration/prodigy/prodigy.py +299 -0
  107. wandb/integration/sacred/__init__.py +117 -0
  108. wandb/integration/sagemaker/__init__.py +12 -0
  109. wandb/integration/sagemaker/auth.py +28 -0
  110. wandb/integration/sagemaker/config.py +49 -0
  111. wandb/integration/sagemaker/files.py +3 -0
  112. wandb/integration/sagemaker/resources.py +34 -0
  113. wandb/integration/sb3/__init__.py +3 -0
  114. wandb/integration/sb3/sb3.py +147 -0
  115. wandb/integration/sklearn/__init__.py +37 -0
  116. wandb/integration/sklearn/calculate/__init__.py +32 -0
  117. wandb/integration/sklearn/calculate/calibration_curves.py +125 -0
  118. wandb/integration/sklearn/calculate/class_proportions.py +68 -0
  119. wandb/integration/sklearn/calculate/confusion_matrix.py +93 -0
  120. wandb/integration/sklearn/calculate/decision_boundaries.py +40 -0
  121. wandb/integration/sklearn/calculate/elbow_curve.py +55 -0
  122. wandb/integration/sklearn/calculate/feature_importances.py +67 -0
  123. wandb/integration/sklearn/calculate/learning_curve.py +64 -0
  124. wandb/integration/sklearn/calculate/outlier_candidates.py +69 -0
  125. wandb/integration/sklearn/calculate/residuals.py +86 -0
  126. wandb/integration/sklearn/calculate/silhouette.py +118 -0
  127. wandb/integration/sklearn/calculate/summary_metrics.py +62 -0
  128. wandb/integration/sklearn/plot/__init__.py +35 -0
  129. wandb/integration/sklearn/plot/classifier.py +329 -0
  130. wandb/integration/sklearn/plot/clusterer.py +146 -0
  131. wandb/integration/sklearn/plot/regressor.py +121 -0
  132. wandb/integration/sklearn/plot/shared.py +91 -0
  133. wandb/integration/sklearn/utils.py +183 -0
  134. wandb/integration/tensorboard/__init__.py +10 -0
  135. wandb/integration/tensorboard/log.py +354 -0
  136. wandb/integration/tensorboard/monkeypatch.py +186 -0
  137. wandb/integration/tensorflow/__init__.py +5 -0
  138. wandb/integration/tensorflow/estimator_hook.py +54 -0
  139. wandb/integration/torch/__init__.py +0 -0
  140. wandb/integration/torch/wandb_torch.py +554 -0
  141. wandb/integration/ultralytics/__init__.py +11 -0
  142. wandb/integration/ultralytics/bbox_utils.py +215 -0
  143. wandb/integration/ultralytics/callback.py +524 -0
  144. wandb/integration/ultralytics/classification_utils.py +83 -0
  145. wandb/integration/ultralytics/mask_utils.py +202 -0
  146. wandb/integration/ultralytics/pose_utils.py +103 -0
  147. wandb/integration/xgboost/__init__.py +11 -0
  148. wandb/integration/xgboost/xgboost.py +189 -0
  149. wandb/integration/yolov8/__init__.py +0 -0
  150. wandb/integration/yolov8/yolov8.py +284 -0
  151. wandb/jupyter.py +513 -0
  152. wandb/mpmain/__init__.py +0 -0
  153. wandb/mpmain/__main__.py +1 -0
  154. wandb/old/__init__.py +0 -0
  155. wandb/old/core.py +53 -0
  156. wandb/old/settings.py +173 -0
  157. wandb/old/summary.py +440 -0
  158. wandb/plot/__init__.py +28 -0
  159. wandb/plot/bar.py +70 -0
  160. wandb/plot/confusion_matrix.py +181 -0
  161. wandb/plot/custom_chart.py +124 -0
  162. wandb/plot/histogram.py +65 -0
  163. wandb/plot/line.py +74 -0
  164. wandb/plot/line_series.py +176 -0
  165. wandb/plot/pr_curve.py +185 -0
  166. wandb/plot/roc_curve.py +163 -0
  167. wandb/plot/scatter.py +66 -0
  168. wandb/plot/utils.py +183 -0
  169. wandb/plot/viz.py +41 -0
  170. wandb/proto/__init__.py +0 -0
  171. wandb/proto/v3/__init__.py +0 -0
  172. wandb/proto/v3/wandb_base_pb2.py +55 -0
  173. wandb/proto/v3/wandb_internal_pb2.py +1658 -0
  174. wandb/proto/v3/wandb_server_pb2.py +228 -0
  175. wandb/proto/v3/wandb_settings_pb2.py +122 -0
  176. wandb/proto/v3/wandb_telemetry_pb2.py +106 -0
  177. wandb/proto/v4/__init__.py +0 -0
  178. wandb/proto/v4/wandb_base_pb2.py +30 -0
  179. wandb/proto/v4/wandb_internal_pb2.py +370 -0
  180. wandb/proto/v4/wandb_server_pb2.py +67 -0
  181. wandb/proto/v4/wandb_settings_pb2.py +47 -0
  182. wandb/proto/v4/wandb_telemetry_pb2.py +41 -0
  183. wandb/proto/v5/wandb_base_pb2.py +31 -0
  184. wandb/proto/v5/wandb_internal_pb2.py +371 -0
  185. wandb/proto/v5/wandb_server_pb2.py +68 -0
  186. wandb/proto/v5/wandb_settings_pb2.py +48 -0
  187. wandb/proto/v5/wandb_telemetry_pb2.py +42 -0
  188. wandb/proto/wandb_base_pb2.py +10 -0
  189. wandb/proto/wandb_deprecated.py +45 -0
  190. wandb/proto/wandb_generate_deprecated.py +30 -0
  191. wandb/proto/wandb_generate_proto.py +49 -0
  192. wandb/proto/wandb_internal_pb2.py +16 -0
  193. wandb/proto/wandb_server_pb2.py +10 -0
  194. wandb/proto/wandb_settings_pb2.py +10 -0
  195. wandb/proto/wandb_telemetry_pb2.py +10 -0
  196. wandb/py.typed +0 -0
  197. wandb/sdk/__init__.py +37 -0
  198. wandb/sdk/artifacts/__init__.py +0 -0
  199. wandb/sdk/artifacts/_validators.py +121 -0
  200. wandb/sdk/artifacts/artifact.py +2364 -0
  201. wandb/sdk/artifacts/artifact_download_logger.py +43 -0
  202. wandb/sdk/artifacts/artifact_file_cache.py +249 -0
  203. wandb/sdk/artifacts/artifact_instance_cache.py +17 -0
  204. wandb/sdk/artifacts/artifact_manifest.py +75 -0
  205. wandb/sdk/artifacts/artifact_manifest_entry.py +249 -0
  206. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  207. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +92 -0
  208. wandb/sdk/artifacts/artifact_saver.py +265 -0
  209. wandb/sdk/artifacts/artifact_state.py +11 -0
  210. wandb/sdk/artifacts/artifact_ttl.py +7 -0
  211. wandb/sdk/artifacts/exceptions.py +57 -0
  212. wandb/sdk/artifacts/staging.py +25 -0
  213. wandb/sdk/artifacts/storage_handler.py +62 -0
  214. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  215. wandb/sdk/artifacts/storage_handlers/azure_handler.py +213 -0
  216. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  217. wandb/sdk/artifacts/storage_handlers/http_handler.py +114 -0
  218. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +139 -0
  219. wandb/sdk/artifacts/storage_handlers/multi_handler.py +56 -0
  220. wandb/sdk/artifacts/storage_handlers/s3_handler.py +298 -0
  221. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +72 -0
  222. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +135 -0
  223. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +74 -0
  224. wandb/sdk/artifacts/storage_layout.py +6 -0
  225. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  226. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  227. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +378 -0
  228. wandb/sdk/artifacts/storage_policy.py +72 -0
  229. wandb/sdk/backend/__init__.py +0 -0
  230. wandb/sdk/backend/backend.py +221 -0
  231. wandb/sdk/data_types/__init__.py +0 -0
  232. wandb/sdk/data_types/_dtypes.py +918 -0
  233. wandb/sdk/data_types/_private.py +10 -0
  234. wandb/sdk/data_types/audio.py +165 -0
  235. wandb/sdk/data_types/base_types/__init__.py +0 -0
  236. wandb/sdk/data_types/base_types/json_metadata.py +55 -0
  237. wandb/sdk/data_types/base_types/media.py +376 -0
  238. wandb/sdk/data_types/base_types/wb_value.py +282 -0
  239. wandb/sdk/data_types/bokeh.py +70 -0
  240. wandb/sdk/data_types/graph.py +405 -0
  241. wandb/sdk/data_types/helper_types/__init__.py +0 -0
  242. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +305 -0
  243. wandb/sdk/data_types/helper_types/classes.py +159 -0
  244. wandb/sdk/data_types/helper_types/image_mask.py +241 -0
  245. wandb/sdk/data_types/histogram.py +94 -0
  246. wandb/sdk/data_types/html.py +115 -0
  247. wandb/sdk/data_types/image.py +847 -0
  248. wandb/sdk/data_types/molecule.py +241 -0
  249. wandb/sdk/data_types/object_3d.py +470 -0
  250. wandb/sdk/data_types/plotly.py +82 -0
  251. wandb/sdk/data_types/saved_model.py +445 -0
  252. wandb/sdk/data_types/table.py +1204 -0
  253. wandb/sdk/data_types/trace_tree.py +438 -0
  254. wandb/sdk/data_types/utils.py +228 -0
  255. wandb/sdk/data_types/video.py +268 -0
  256. wandb/sdk/integration_utils/__init__.py +0 -0
  257. wandb/sdk/integration_utils/auto_logging.py +232 -0
  258. wandb/sdk/integration_utils/data_logging.py +475 -0
  259. wandb/sdk/interface/__init__.py +0 -0
  260. wandb/sdk/interface/constants.py +4 -0
  261. wandb/sdk/interface/interface.py +1010 -0
  262. wandb/sdk/interface/interface_queue.py +53 -0
  263. wandb/sdk/interface/interface_relay.py +53 -0
  264. wandb/sdk/interface/interface_shared.py +546 -0
  265. wandb/sdk/interface/interface_sock.py +61 -0
  266. wandb/sdk/interface/message_future.py +27 -0
  267. wandb/sdk/interface/message_future_poll.py +50 -0
  268. wandb/sdk/interface/router.py +115 -0
  269. wandb/sdk/interface/router_queue.py +41 -0
  270. wandb/sdk/interface/router_relay.py +37 -0
  271. wandb/sdk/interface/router_sock.py +36 -0
  272. wandb/sdk/interface/summary_record.py +67 -0
  273. wandb/sdk/internal/__init__.py +0 -0
  274. wandb/sdk/internal/context.py +89 -0
  275. wandb/sdk/internal/datastore.py +297 -0
  276. wandb/sdk/internal/file_pusher.py +181 -0
  277. wandb/sdk/internal/file_stream.py +695 -0
  278. wandb/sdk/internal/flow_control.py +263 -0
  279. wandb/sdk/internal/handler.py +905 -0
  280. wandb/sdk/internal/internal.py +403 -0
  281. wandb/sdk/internal/internal_api.py +4587 -0
  282. wandb/sdk/internal/internal_util.py +97 -0
  283. wandb/sdk/internal/job_builder.py +638 -0
  284. wandb/sdk/internal/profiler.py +78 -0
  285. wandb/sdk/internal/progress.py +79 -0
  286. wandb/sdk/internal/run.py +25 -0
  287. wandb/sdk/internal/sample.py +70 -0
  288. wandb/sdk/internal/sender.py +1696 -0
  289. wandb/sdk/internal/sender_config.py +197 -0
  290. wandb/sdk/internal/settings_static.py +97 -0
  291. wandb/sdk/internal/system/__init__.py +0 -0
  292. wandb/sdk/internal/system/assets/__init__.py +25 -0
  293. wandb/sdk/internal/system/assets/aggregators.py +37 -0
  294. wandb/sdk/internal/system/assets/asset_registry.py +20 -0
  295. wandb/sdk/internal/system/assets/cpu.py +163 -0
  296. wandb/sdk/internal/system/assets/disk.py +210 -0
  297. wandb/sdk/internal/system/assets/gpu.py +416 -0
  298. wandb/sdk/internal/system/assets/gpu_amd.py +233 -0
  299. wandb/sdk/internal/system/assets/interfaces.py +205 -0
  300. wandb/sdk/internal/system/assets/ipu.py +177 -0
  301. wandb/sdk/internal/system/assets/memory.py +166 -0
  302. wandb/sdk/internal/system/assets/network.py +125 -0
  303. wandb/sdk/internal/system/assets/open_metrics.py +293 -0
  304. wandb/sdk/internal/system/assets/tpu.py +154 -0
  305. wandb/sdk/internal/system/assets/trainium.py +393 -0
  306. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  307. wandb/sdk/internal/system/system_info.py +250 -0
  308. wandb/sdk/internal/system/system_monitor.py +222 -0
  309. wandb/sdk/internal/tb_watcher.py +519 -0
  310. wandb/sdk/internal/thread_local_settings.py +18 -0
  311. wandb/sdk/internal/writer.py +204 -0
  312. wandb/sdk/launch/__init__.py +15 -0
  313. wandb/sdk/launch/_launch.py +331 -0
  314. wandb/sdk/launch/_launch_add.py +255 -0
  315. wandb/sdk/launch/_project_spec.py +566 -0
  316. wandb/sdk/launch/agent/__init__.py +5 -0
  317. wandb/sdk/launch/agent/agent.py +924 -0
  318. wandb/sdk/launch/agent/config.py +296 -0
  319. wandb/sdk/launch/agent/job_status_tracker.py +53 -0
  320. wandb/sdk/launch/agent/run_queue_item_file_saver.py +39 -0
  321. wandb/sdk/launch/builder/__init__.py +0 -0
  322. wandb/sdk/launch/builder/abstract.py +156 -0
  323. wandb/sdk/launch/builder/build.py +297 -0
  324. wandb/sdk/launch/builder/context_manager.py +235 -0
  325. wandb/sdk/launch/builder/docker_builder.py +177 -0
  326. wandb/sdk/launch/builder/kaniko_builder.py +594 -0
  327. wandb/sdk/launch/builder/noop.py +58 -0
  328. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +188 -0
  329. wandb/sdk/launch/builder/templates/dockerfile.py +92 -0
  330. wandb/sdk/launch/create_job.py +528 -0
  331. wandb/sdk/launch/environment/abstract.py +29 -0
  332. wandb/sdk/launch/environment/aws_environment.py +322 -0
  333. wandb/sdk/launch/environment/azure_environment.py +105 -0
  334. wandb/sdk/launch/environment/gcp_environment.py +335 -0
  335. wandb/sdk/launch/environment/local_environment.py +65 -0
  336. wandb/sdk/launch/errors.py +13 -0
  337. wandb/sdk/launch/git_reference.py +109 -0
  338. wandb/sdk/launch/inputs/files.py +148 -0
  339. wandb/sdk/launch/inputs/internal.py +315 -0
  340. wandb/sdk/launch/inputs/manage.py +113 -0
  341. wandb/sdk/launch/inputs/schema.py +39 -0
  342. wandb/sdk/launch/loader.py +249 -0
  343. wandb/sdk/launch/registry/abstract.py +48 -0
  344. wandb/sdk/launch/registry/anon.py +29 -0
  345. wandb/sdk/launch/registry/azure_container_registry.py +124 -0
  346. wandb/sdk/launch/registry/elastic_container_registry.py +192 -0
  347. wandb/sdk/launch/registry/google_artifact_registry.py +219 -0
  348. wandb/sdk/launch/registry/local_registry.py +65 -0
  349. wandb/sdk/launch/runner/__init__.py +0 -0
  350. wandb/sdk/launch/runner/abstract.py +185 -0
  351. wandb/sdk/launch/runner/kubernetes_monitor.py +472 -0
  352. wandb/sdk/launch/runner/kubernetes_runner.py +963 -0
  353. wandb/sdk/launch/runner/local_container.py +301 -0
  354. wandb/sdk/launch/runner/local_process.py +78 -0
  355. wandb/sdk/launch/runner/sagemaker_runner.py +426 -0
  356. wandb/sdk/launch/runner/vertex_runner.py +230 -0
  357. wandb/sdk/launch/sweeps/__init__.py +37 -0
  358. wandb/sdk/launch/sweeps/scheduler.py +740 -0
  359. wandb/sdk/launch/sweeps/scheduler_sweep.py +90 -0
  360. wandb/sdk/launch/sweeps/utils.py +316 -0
  361. wandb/sdk/launch/utils.py +747 -0
  362. wandb/sdk/launch/wandb_reference.py +138 -0
  363. wandb/sdk/lib/__init__.py +5 -0
  364. wandb/sdk/lib/apikey.py +269 -0
  365. wandb/sdk/lib/capped_dict.py +26 -0
  366. wandb/sdk/lib/config_util.py +101 -0
  367. wandb/sdk/lib/credentials.py +141 -0
  368. wandb/sdk/lib/deprecate.py +42 -0
  369. wandb/sdk/lib/disabled.py +29 -0
  370. wandb/sdk/lib/exit_hooks.py +54 -0
  371. wandb/sdk/lib/file_stream_utils.py +118 -0
  372. wandb/sdk/lib/filenames.py +64 -0
  373. wandb/sdk/lib/filesystem.py +372 -0
  374. wandb/sdk/lib/fsm.py +180 -0
  375. wandb/sdk/lib/gitlib.py +239 -0
  376. wandb/sdk/lib/gql_request.py +65 -0
  377. wandb/sdk/lib/handler_util.py +21 -0
  378. wandb/sdk/lib/hashutil.py +84 -0
  379. wandb/sdk/lib/import_hooks.py +275 -0
  380. wandb/sdk/lib/ipython.py +126 -0
  381. wandb/sdk/lib/json_util.py +80 -0
  382. wandb/sdk/lib/lazyloader.py +63 -0
  383. wandb/sdk/lib/mailbox.py +456 -0
  384. wandb/sdk/lib/module.py +78 -0
  385. wandb/sdk/lib/paths.py +106 -0
  386. wandb/sdk/lib/preinit.py +42 -0
  387. wandb/sdk/lib/printer.py +548 -0
  388. wandb/sdk/lib/progress.py +279 -0
  389. wandb/sdk/lib/proto_util.py +90 -0
  390. wandb/sdk/lib/redirect.py +845 -0
  391. wandb/sdk/lib/retry.py +289 -0
  392. wandb/sdk/lib/run_moment.py +72 -0
  393. wandb/sdk/lib/runid.py +12 -0
  394. wandb/sdk/lib/server.py +38 -0
  395. wandb/sdk/lib/service_connection.py +216 -0
  396. wandb/sdk/lib/service_token.py +94 -0
  397. wandb/sdk/lib/sock_client.py +290 -0
  398. wandb/sdk/lib/sparkline.py +44 -0
  399. wandb/sdk/lib/telemetry.py +100 -0
  400. wandb/sdk/lib/timed_input.py +133 -0
  401. wandb/sdk/lib/timer.py +19 -0
  402. wandb/sdk/service/__init__.py +0 -0
  403. wandb/sdk/service/_startup_debug.py +22 -0
  404. wandb/sdk/service/port_file.py +53 -0
  405. wandb/sdk/service/server.py +107 -0
  406. wandb/sdk/service/server_sock.py +274 -0
  407. wandb/sdk/service/service.py +242 -0
  408. wandb/sdk/service/streams.py +425 -0
  409. wandb/sdk/verify/__init__.py +0 -0
  410. wandb/sdk/verify/verify.py +501 -0
  411. wandb/sdk/wandb_alerts.py +12 -0
  412. wandb/sdk/wandb_config.py +322 -0
  413. wandb/sdk/wandb_helper.py +54 -0
  414. wandb/sdk/wandb_init.py +1313 -0
  415. wandb/sdk/wandb_login.py +339 -0
  416. wandb/sdk/wandb_metric.py +110 -0
  417. wandb/sdk/wandb_require.py +94 -0
  418. wandb/sdk/wandb_require_helpers.py +44 -0
  419. wandb/sdk/wandb_run.py +4066 -0
  420. wandb/sdk/wandb_settings.py +1309 -0
  421. wandb/sdk/wandb_setup.py +402 -0
  422. wandb/sdk/wandb_summary.py +150 -0
  423. wandb/sdk/wandb_sweep.py +119 -0
  424. wandb/sdk/wandb_sync.py +82 -0
  425. wandb/sdk/wandb_watch.py +150 -0
  426. wandb/sklearn.py +35 -0
  427. wandb/sync/__init__.py +3 -0
  428. wandb/sync/sync.py +442 -0
  429. wandb/trigger.py +29 -0
  430. wandb/util.py +1955 -0
  431. wandb/vendor/__init__.py +0 -0
  432. wandb/vendor/gql-0.2.0/setup.py +40 -0
  433. wandb/vendor/gql-0.2.0/tests/__init__.py +0 -0
  434. wandb/vendor/gql-0.2.0/tests/starwars/__init__.py +0 -0
  435. wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py +96 -0
  436. wandb/vendor/gql-0.2.0/tests/starwars/schema.py +146 -0
  437. wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py +293 -0
  438. wandb/vendor/gql-0.2.0/tests/starwars/test_query.py +355 -0
  439. wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py +171 -0
  440. wandb/vendor/gql-0.2.0/tests/test_client.py +31 -0
  441. wandb/vendor/gql-0.2.0/tests/test_transport.py +89 -0
  442. wandb/vendor/gql-0.2.0/wandb_gql/__init__.py +4 -0
  443. wandb/vendor/gql-0.2.0/wandb_gql/client.py +75 -0
  444. wandb/vendor/gql-0.2.0/wandb_gql/dsl.py +152 -0
  445. wandb/vendor/gql-0.2.0/wandb_gql/gql.py +10 -0
  446. wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py +0 -0
  447. wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py +6 -0
  448. wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py +15 -0
  449. wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py +46 -0
  450. wandb/vendor/gql-0.2.0/wandb_gql/utils.py +21 -0
  451. wandb/vendor/graphql-core-1.1/setup.py +86 -0
  452. wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py +287 -0
  453. wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py +6 -0
  454. wandb/vendor/graphql-core-1.1/wandb_graphql/error/base.py +42 -0
  455. wandb/vendor/graphql-core-1.1/wandb_graphql/error/format_error.py +11 -0
  456. wandb/vendor/graphql-core-1.1/wandb_graphql/error/located_error.py +29 -0
  457. wandb/vendor/graphql-core-1.1/wandb_graphql/error/syntax_error.py +36 -0
  458. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__init__.py +26 -0
  459. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py +311 -0
  460. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executor.py +398 -0
  461. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py +0 -0
  462. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py +53 -0
  463. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py +22 -0
  464. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py +32 -0
  465. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py +7 -0
  466. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py +35 -0
  467. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py +6 -0
  468. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py +0 -0
  469. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py +66 -0
  470. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py +252 -0
  471. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py +151 -0
  472. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py +7 -0
  473. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py +57 -0
  474. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py +145 -0
  475. wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py +60 -0
  476. wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py +0 -0
  477. wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py +1349 -0
  478. wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py +19 -0
  479. wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py +435 -0
  480. wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py +30 -0
  481. wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py +779 -0
  482. wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py +193 -0
  483. wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py +18 -0
  484. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py +222 -0
  485. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py +82 -0
  486. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__init__.py +0 -0
  487. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py +17 -0
  488. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/contain_subset.py +28 -0
  489. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/default_ordered_dict.py +40 -0
  490. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/ordereddict.py +8 -0
  491. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/pair_set.py +43 -0
  492. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/version.py +78 -0
  493. wandb/vendor/graphql-core-1.1/wandb_graphql/type/__init__.py +67 -0
  494. wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py +619 -0
  495. wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py +132 -0
  496. wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py +440 -0
  497. wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py +131 -0
  498. wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py +100 -0
  499. wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py +145 -0
  500. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py +0 -0
  501. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py +9 -0
  502. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py +65 -0
  503. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py +49 -0
  504. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py +24 -0
  505. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py +75 -0
  506. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py +291 -0
  507. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_client_schema.py +250 -0
  508. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/concat_ast.py +9 -0
  509. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/extend_schema.py +357 -0
  510. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_field_def.py +27 -0
  511. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_operation_ast.py +21 -0
  512. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/introspection_query.py +90 -0
  513. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_literal_value.py +67 -0
  514. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_value.py +66 -0
  515. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/quoted_or_list.py +21 -0
  516. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/schema_printer.py +168 -0
  517. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/suggestion_list.py +56 -0
  518. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_comparators.py +69 -0
  519. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_from_ast.py +21 -0
  520. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_info.py +149 -0
  521. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/value_from_ast.py +69 -0
  522. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py +4 -0
  523. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py +79 -0
  524. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py +24 -0
  525. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py +8 -0
  526. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py +44 -0
  527. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py +113 -0
  528. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py +33 -0
  529. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py +70 -0
  530. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py +97 -0
  531. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py +19 -0
  532. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py +43 -0
  533. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py +23 -0
  534. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py +59 -0
  535. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py +36 -0
  536. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py +38 -0
  537. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py +37 -0
  538. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py +529 -0
  539. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py +44 -0
  540. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py +46 -0
  541. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py +33 -0
  542. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py +32 -0
  543. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py +28 -0
  544. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py +33 -0
  545. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py +31 -0
  546. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py +27 -0
  547. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py +21 -0
  548. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py +53 -0
  549. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py +158 -0
  550. wandb/vendor/promise-2.3.0/conftest.py +30 -0
  551. wandb/vendor/promise-2.3.0/setup.py +64 -0
  552. wandb/vendor/promise-2.3.0/tests/__init__.py +0 -0
  553. wandb/vendor/promise-2.3.0/tests/conftest.py +8 -0
  554. wandb/vendor/promise-2.3.0/tests/test_awaitable.py +32 -0
  555. wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py +47 -0
  556. wandb/vendor/promise-2.3.0/tests/test_benchmark.py +116 -0
  557. wandb/vendor/promise-2.3.0/tests/test_complex_threads.py +23 -0
  558. wandb/vendor/promise-2.3.0/tests/test_dataloader.py +452 -0
  559. wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py +99 -0
  560. wandb/vendor/promise-2.3.0/tests/test_dataloader_extra.py +65 -0
  561. wandb/vendor/promise-2.3.0/tests/test_extra.py +670 -0
  562. wandb/vendor/promise-2.3.0/tests/test_issues.py +132 -0
  563. wandb/vendor/promise-2.3.0/tests/test_promise_list.py +70 -0
  564. wandb/vendor/promise-2.3.0/tests/test_spec.py +584 -0
  565. wandb/vendor/promise-2.3.0/tests/test_thread_safety.py +115 -0
  566. wandb/vendor/promise-2.3.0/tests/utils.py +3 -0
  567. wandb/vendor/promise-2.3.0/wandb_promise/__init__.py +38 -0
  568. wandb/vendor/promise-2.3.0/wandb_promise/async_.py +135 -0
  569. wandb/vendor/promise-2.3.0/wandb_promise/compat.py +32 -0
  570. wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py +326 -0
  571. wandb/vendor/promise-2.3.0/wandb_promise/iterate_promise.py +12 -0
  572. wandb/vendor/promise-2.3.0/wandb_promise/promise.py +848 -0
  573. wandb/vendor/promise-2.3.0/wandb_promise/promise_list.py +151 -0
  574. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/__init__.py +0 -0
  575. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/version.py +83 -0
  576. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/__init__.py +0 -0
  577. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/asyncio.py +22 -0
  578. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/gevent.py +21 -0
  579. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/immediate.py +27 -0
  580. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/thread.py +18 -0
  581. wandb/vendor/promise-2.3.0/wandb_promise/utils.py +56 -0
  582. wandb/vendor/pygments/__init__.py +90 -0
  583. wandb/vendor/pygments/cmdline.py +568 -0
  584. wandb/vendor/pygments/console.py +74 -0
  585. wandb/vendor/pygments/filter.py +74 -0
  586. wandb/vendor/pygments/filters/__init__.py +350 -0
  587. wandb/vendor/pygments/formatter.py +95 -0
  588. wandb/vendor/pygments/formatters/__init__.py +153 -0
  589. wandb/vendor/pygments/formatters/_mapping.py +85 -0
  590. wandb/vendor/pygments/formatters/bbcode.py +109 -0
  591. wandb/vendor/pygments/formatters/html.py +851 -0
  592. wandb/vendor/pygments/formatters/img.py +600 -0
  593. wandb/vendor/pygments/formatters/irc.py +182 -0
  594. wandb/vendor/pygments/formatters/latex.py +482 -0
  595. wandb/vendor/pygments/formatters/other.py +160 -0
  596. wandb/vendor/pygments/formatters/rtf.py +147 -0
  597. wandb/vendor/pygments/formatters/svg.py +153 -0
  598. wandb/vendor/pygments/formatters/terminal.py +136 -0
  599. wandb/vendor/pygments/formatters/terminal256.py +309 -0
  600. wandb/vendor/pygments/lexer.py +871 -0
  601. wandb/vendor/pygments/lexers/__init__.py +329 -0
  602. wandb/vendor/pygments/lexers/_asy_builtins.py +1645 -0
  603. wandb/vendor/pygments/lexers/_cl_builtins.py +232 -0
  604. wandb/vendor/pygments/lexers/_cocoa_builtins.py +72 -0
  605. wandb/vendor/pygments/lexers/_csound_builtins.py +1346 -0
  606. wandb/vendor/pygments/lexers/_lasso_builtins.py +5327 -0
  607. wandb/vendor/pygments/lexers/_lua_builtins.py +295 -0
  608. wandb/vendor/pygments/lexers/_mapping.py +500 -0
  609. wandb/vendor/pygments/lexers/_mql_builtins.py +1172 -0
  610. wandb/vendor/pygments/lexers/_openedge_builtins.py +2547 -0
  611. wandb/vendor/pygments/lexers/_php_builtins.py +4756 -0
  612. wandb/vendor/pygments/lexers/_postgres_builtins.py +621 -0
  613. wandb/vendor/pygments/lexers/_scilab_builtins.py +3094 -0
  614. wandb/vendor/pygments/lexers/_sourcemod_builtins.py +1163 -0
  615. wandb/vendor/pygments/lexers/_stan_builtins.py +532 -0
  616. wandb/vendor/pygments/lexers/_stata_builtins.py +419 -0
  617. wandb/vendor/pygments/lexers/_tsql_builtins.py +1004 -0
  618. wandb/vendor/pygments/lexers/_vim_builtins.py +1939 -0
  619. wandb/vendor/pygments/lexers/actionscript.py +240 -0
  620. wandb/vendor/pygments/lexers/agile.py +24 -0
  621. wandb/vendor/pygments/lexers/algebra.py +221 -0
  622. wandb/vendor/pygments/lexers/ambient.py +76 -0
  623. wandb/vendor/pygments/lexers/ampl.py +87 -0
  624. wandb/vendor/pygments/lexers/apl.py +101 -0
  625. wandb/vendor/pygments/lexers/archetype.py +318 -0
  626. wandb/vendor/pygments/lexers/asm.py +641 -0
  627. wandb/vendor/pygments/lexers/automation.py +374 -0
  628. wandb/vendor/pygments/lexers/basic.py +500 -0
  629. wandb/vendor/pygments/lexers/bibtex.py +160 -0
  630. wandb/vendor/pygments/lexers/business.py +612 -0
  631. wandb/vendor/pygments/lexers/c_cpp.py +252 -0
  632. wandb/vendor/pygments/lexers/c_like.py +541 -0
  633. wandb/vendor/pygments/lexers/capnproto.py +78 -0
  634. wandb/vendor/pygments/lexers/chapel.py +102 -0
  635. wandb/vendor/pygments/lexers/clean.py +288 -0
  636. wandb/vendor/pygments/lexers/compiled.py +34 -0
  637. wandb/vendor/pygments/lexers/configs.py +833 -0
  638. wandb/vendor/pygments/lexers/console.py +114 -0
  639. wandb/vendor/pygments/lexers/crystal.py +393 -0
  640. wandb/vendor/pygments/lexers/csound.py +366 -0
  641. wandb/vendor/pygments/lexers/css.py +689 -0
  642. wandb/vendor/pygments/lexers/d.py +251 -0
  643. wandb/vendor/pygments/lexers/dalvik.py +125 -0
  644. wandb/vendor/pygments/lexers/data.py +555 -0
  645. wandb/vendor/pygments/lexers/diff.py +165 -0
  646. wandb/vendor/pygments/lexers/dotnet.py +691 -0
  647. wandb/vendor/pygments/lexers/dsls.py +878 -0
  648. wandb/vendor/pygments/lexers/dylan.py +289 -0
  649. wandb/vendor/pygments/lexers/ecl.py +125 -0
  650. wandb/vendor/pygments/lexers/eiffel.py +65 -0
  651. wandb/vendor/pygments/lexers/elm.py +121 -0
  652. wandb/vendor/pygments/lexers/erlang.py +533 -0
  653. wandb/vendor/pygments/lexers/esoteric.py +277 -0
  654. wandb/vendor/pygments/lexers/ezhil.py +69 -0
  655. wandb/vendor/pygments/lexers/factor.py +344 -0
  656. wandb/vendor/pygments/lexers/fantom.py +250 -0
  657. wandb/vendor/pygments/lexers/felix.py +273 -0
  658. wandb/vendor/pygments/lexers/forth.py +177 -0
  659. wandb/vendor/pygments/lexers/fortran.py +205 -0
  660. wandb/vendor/pygments/lexers/foxpro.py +428 -0
  661. wandb/vendor/pygments/lexers/functional.py +21 -0
  662. wandb/vendor/pygments/lexers/go.py +101 -0
  663. wandb/vendor/pygments/lexers/grammar_notation.py +213 -0
  664. wandb/vendor/pygments/lexers/graph.py +80 -0
  665. wandb/vendor/pygments/lexers/graphics.py +553 -0
  666. wandb/vendor/pygments/lexers/haskell.py +843 -0
  667. wandb/vendor/pygments/lexers/haxe.py +936 -0
  668. wandb/vendor/pygments/lexers/hdl.py +382 -0
  669. wandb/vendor/pygments/lexers/hexdump.py +103 -0
  670. wandb/vendor/pygments/lexers/html.py +602 -0
  671. wandb/vendor/pygments/lexers/idl.py +270 -0
  672. wandb/vendor/pygments/lexers/igor.py +288 -0
  673. wandb/vendor/pygments/lexers/inferno.py +96 -0
  674. wandb/vendor/pygments/lexers/installers.py +322 -0
  675. wandb/vendor/pygments/lexers/int_fiction.py +1343 -0
  676. wandb/vendor/pygments/lexers/iolang.py +63 -0
  677. wandb/vendor/pygments/lexers/j.py +146 -0
  678. wandb/vendor/pygments/lexers/javascript.py +1525 -0
  679. wandb/vendor/pygments/lexers/julia.py +333 -0
  680. wandb/vendor/pygments/lexers/jvm.py +1573 -0
  681. wandb/vendor/pygments/lexers/lisp.py +2621 -0
  682. wandb/vendor/pygments/lexers/make.py +202 -0
  683. wandb/vendor/pygments/lexers/markup.py +595 -0
  684. wandb/vendor/pygments/lexers/math.py +21 -0
  685. wandb/vendor/pygments/lexers/matlab.py +663 -0
  686. wandb/vendor/pygments/lexers/ml.py +769 -0
  687. wandb/vendor/pygments/lexers/modeling.py +358 -0
  688. wandb/vendor/pygments/lexers/modula2.py +1561 -0
  689. wandb/vendor/pygments/lexers/monte.py +204 -0
  690. wandb/vendor/pygments/lexers/ncl.py +894 -0
  691. wandb/vendor/pygments/lexers/nimrod.py +159 -0
  692. wandb/vendor/pygments/lexers/nit.py +64 -0
  693. wandb/vendor/pygments/lexers/nix.py +136 -0
  694. wandb/vendor/pygments/lexers/oberon.py +105 -0
  695. wandb/vendor/pygments/lexers/objective.py +504 -0
  696. wandb/vendor/pygments/lexers/ooc.py +85 -0
  697. wandb/vendor/pygments/lexers/other.py +41 -0
  698. wandb/vendor/pygments/lexers/parasail.py +79 -0
  699. wandb/vendor/pygments/lexers/parsers.py +835 -0
  700. wandb/vendor/pygments/lexers/pascal.py +644 -0
  701. wandb/vendor/pygments/lexers/pawn.py +199 -0
  702. wandb/vendor/pygments/lexers/perl.py +620 -0
  703. wandb/vendor/pygments/lexers/php.py +267 -0
  704. wandb/vendor/pygments/lexers/praat.py +294 -0
  705. wandb/vendor/pygments/lexers/prolog.py +306 -0
  706. wandb/vendor/pygments/lexers/python.py +939 -0
  707. wandb/vendor/pygments/lexers/qvt.py +152 -0
  708. wandb/vendor/pygments/lexers/r.py +453 -0
  709. wandb/vendor/pygments/lexers/rdf.py +270 -0
  710. wandb/vendor/pygments/lexers/rebol.py +431 -0
  711. wandb/vendor/pygments/lexers/resource.py +85 -0
  712. wandb/vendor/pygments/lexers/rnc.py +67 -0
  713. wandb/vendor/pygments/lexers/roboconf.py +82 -0
  714. wandb/vendor/pygments/lexers/robotframework.py +560 -0
  715. wandb/vendor/pygments/lexers/ruby.py +519 -0
  716. wandb/vendor/pygments/lexers/rust.py +220 -0
  717. wandb/vendor/pygments/lexers/sas.py +228 -0
  718. wandb/vendor/pygments/lexers/scripting.py +1222 -0
  719. wandb/vendor/pygments/lexers/shell.py +794 -0
  720. wandb/vendor/pygments/lexers/smalltalk.py +195 -0
  721. wandb/vendor/pygments/lexers/smv.py +79 -0
  722. wandb/vendor/pygments/lexers/snobol.py +83 -0
  723. wandb/vendor/pygments/lexers/special.py +103 -0
  724. wandb/vendor/pygments/lexers/sql.py +681 -0
  725. wandb/vendor/pygments/lexers/stata.py +108 -0
  726. wandb/vendor/pygments/lexers/supercollider.py +90 -0
  727. wandb/vendor/pygments/lexers/tcl.py +145 -0
  728. wandb/vendor/pygments/lexers/templates.py +2283 -0
  729. wandb/vendor/pygments/lexers/testing.py +207 -0
  730. wandb/vendor/pygments/lexers/text.py +25 -0
  731. wandb/vendor/pygments/lexers/textedit.py +169 -0
  732. wandb/vendor/pygments/lexers/textfmts.py +297 -0
  733. wandb/vendor/pygments/lexers/theorem.py +458 -0
  734. wandb/vendor/pygments/lexers/trafficscript.py +54 -0
  735. wandb/vendor/pygments/lexers/typoscript.py +226 -0
  736. wandb/vendor/pygments/lexers/urbi.py +133 -0
  737. wandb/vendor/pygments/lexers/varnish.py +190 -0
  738. wandb/vendor/pygments/lexers/verification.py +111 -0
  739. wandb/vendor/pygments/lexers/web.py +24 -0
  740. wandb/vendor/pygments/lexers/webmisc.py +988 -0
  741. wandb/vendor/pygments/lexers/whiley.py +116 -0
  742. wandb/vendor/pygments/lexers/x10.py +69 -0
  743. wandb/vendor/pygments/modeline.py +44 -0
  744. wandb/vendor/pygments/plugin.py +68 -0
  745. wandb/vendor/pygments/regexopt.py +92 -0
  746. wandb/vendor/pygments/scanner.py +105 -0
  747. wandb/vendor/pygments/sphinxext.py +158 -0
  748. wandb/vendor/pygments/style.py +155 -0
  749. wandb/vendor/pygments/styles/__init__.py +80 -0
  750. wandb/vendor/pygments/styles/abap.py +29 -0
  751. wandb/vendor/pygments/styles/algol.py +63 -0
  752. wandb/vendor/pygments/styles/algol_nu.py +63 -0
  753. wandb/vendor/pygments/styles/arduino.py +98 -0
  754. wandb/vendor/pygments/styles/autumn.py +65 -0
  755. wandb/vendor/pygments/styles/borland.py +51 -0
  756. wandb/vendor/pygments/styles/bw.py +49 -0
  757. wandb/vendor/pygments/styles/colorful.py +81 -0
  758. wandb/vendor/pygments/styles/default.py +73 -0
  759. wandb/vendor/pygments/styles/emacs.py +72 -0
  760. wandb/vendor/pygments/styles/friendly.py +72 -0
  761. wandb/vendor/pygments/styles/fruity.py +42 -0
  762. wandb/vendor/pygments/styles/igor.py +29 -0
  763. wandb/vendor/pygments/styles/lovelace.py +97 -0
  764. wandb/vendor/pygments/styles/manni.py +75 -0
  765. wandb/vendor/pygments/styles/monokai.py +106 -0
  766. wandb/vendor/pygments/styles/murphy.py +80 -0
  767. wandb/vendor/pygments/styles/native.py +65 -0
  768. wandb/vendor/pygments/styles/paraiso_dark.py +125 -0
  769. wandb/vendor/pygments/styles/paraiso_light.py +125 -0
  770. wandb/vendor/pygments/styles/pastie.py +75 -0
  771. wandb/vendor/pygments/styles/perldoc.py +69 -0
  772. wandb/vendor/pygments/styles/rainbow_dash.py +89 -0
  773. wandb/vendor/pygments/styles/rrt.py +33 -0
  774. wandb/vendor/pygments/styles/sas.py +44 -0
  775. wandb/vendor/pygments/styles/stata.py +40 -0
  776. wandb/vendor/pygments/styles/tango.py +141 -0
  777. wandb/vendor/pygments/styles/trac.py +63 -0
  778. wandb/vendor/pygments/styles/vim.py +63 -0
  779. wandb/vendor/pygments/styles/vs.py +38 -0
  780. wandb/vendor/pygments/styles/xcode.py +51 -0
  781. wandb/vendor/pygments/token.py +213 -0
  782. wandb/vendor/pygments/unistring.py +217 -0
  783. wandb/vendor/pygments/util.py +388 -0
  784. wandb/vendor/pynvml/__init__.py +0 -0
  785. wandb/vendor/pynvml/pynvml.py +4779 -0
  786. wandb/vendor/watchdog_0_9_0/wandb_watchdog/__init__.py +17 -0
  787. wandb/vendor/watchdog_0_9_0/wandb_watchdog/events.py +615 -0
  788. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/__init__.py +98 -0
  789. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/api.py +369 -0
  790. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents.py +172 -0
  791. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents2.py +239 -0
  792. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify.py +218 -0
  793. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_buffer.py +81 -0
  794. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_c.py +575 -0
  795. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/kqueue.py +730 -0
  796. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/polling.py +145 -0
  797. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/read_directory_changes.py +133 -0
  798. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/winapi.py +348 -0
  799. wandb/vendor/watchdog_0_9_0/wandb_watchdog/patterns.py +265 -0
  800. wandb/vendor/watchdog_0_9_0/wandb_watchdog/tricks/__init__.py +174 -0
  801. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/__init__.py +151 -0
  802. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/bricks.py +249 -0
  803. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/compat.py +29 -0
  804. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/decorators.py +198 -0
  805. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/delayed_queue.py +88 -0
  806. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/dirsnapshot.py +293 -0
  807. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/echo.py +157 -0
  808. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/event_backport.py +41 -0
  809. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/importlib2.py +40 -0
  810. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/platform.py +57 -0
  811. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/unicode_paths.py +64 -0
  812. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/win32stat.py +123 -0
  813. wandb/vendor/watchdog_0_9_0/wandb_watchdog/version.py +28 -0
  814. wandb/vendor/watchdog_0_9_0/wandb_watchdog/watchmedo.py +577 -0
  815. wandb/wandb_agent.py +588 -0
  816. wandb/wandb_controller.py +719 -0
  817. wandb/wandb_run.py +9 -0
  818. wandb-0.19.1.dist-info/METADATA +223 -0
  819. wandb-0.19.1.dist-info/RECORD +822 -0
  820. wandb-0.19.1.dist-info/WHEEL +5 -0
  821. wandb-0.19.1.dist-info/entry_points.txt +3 -0
  822. wandb-0.19.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,4587 @@
1
+ import ast
2
+ import base64
3
+ import datetime
4
+ import functools
5
+ import http.client
6
+ import json
7
+ import logging
8
+ import os
9
+ import re
10
+ import socket
11
+ import sys
12
+ import threading
13
+ from copy import deepcopy
14
+ from pathlib import Path
15
+ from typing import (
16
+ IO,
17
+ TYPE_CHECKING,
18
+ Any,
19
+ Callable,
20
+ Dict,
21
+ Iterable,
22
+ List,
23
+ Literal,
24
+ Mapping,
25
+ MutableMapping,
26
+ NamedTuple,
27
+ Optional,
28
+ Sequence,
29
+ TextIO,
30
+ Tuple,
31
+ Union,
32
+ )
33
+
34
+ import click
35
+ import requests
36
+ import yaml
37
+ from wandb_gql import Client, gql
38
+ from wandb_gql.client import RetryError
39
+
40
+ import wandb
41
+ from wandb import env, util
42
+ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
43
+ from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
44
+ from wandb.integration.sagemaker import parse_sm_secrets
45
+ from wandb.old.settings import Settings
46
+ from wandb.sdk.artifacts._validators import is_artifact_registry_project
47
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
48
+ from wandb.sdk.lib.gql_request import GraphQLSession
49
+ from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
50
+
51
+ from ..lib import credentials, retry
52
+ from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
53
+ from ..lib.gitlib import GitRepo
54
+ from . import context
55
+ from .progress import Progress
56
+
57
# Logger named after this module.
logger = logging.getLogger(__name__)

# NOTE(review): presumably the fallback project name for launch — confirm at
# the call sites, which are not visible in this part of the file.
LAUNCH_DEFAULT_PROJECT = "model-registry"

# Everything in this branch exists only for static type checkers; none of these
# names are defined at runtime, so other code must refer to them via string
# annotations.
if TYPE_CHECKING:
    from typing import Literal, TypedDict

    from .progress import ProgressFn

    class CreateArtifactFileSpecInput(TypedDict, total=False):
        """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""

        artifactID: str  # noqa: N815
        name: str
        md5: str
        mimetype: Optional[str]
        artifactManifestID: Optional[str]  # noqa: N815
        uploadPartsInput: Optional[List[Dict[str, object]]]  # noqa: N815

    class CreateArtifactFilesResponseFile(TypedDict):
        """One file entry as typed for a create-artifact-files response."""

        id: str
        name: str
        displayName: str  # noqa: N815
        uploadUrl: Optional[str]  # noqa: N815
        uploadHeaders: Sequence[str]  # noqa: N815
        uploadMultipartUrls: "UploadPartsResponse"  # noqa: N815
        storagePath: str  # noqa: N815
        artifact: "CreateArtifactFilesResponseFileNode"

    class CreateArtifactFilesResponseFileNode(TypedDict):
        """Artifact reference nested inside `CreateArtifactFilesResponseFile`."""

        id: str

    class UploadPartsResponse(TypedDict):
        """Multipart upload description: the per-part URLs plus the upload id."""

        uploadUrlParts: List["UploadUrlParts"]  # noqa: N815
        uploadID: str  # noqa: N815

    class UploadUrlParts(TypedDict):
        """Upload URL for a single numbered part."""

        partNumber: int  # noqa: N815
        uploadUrl: str  # noqa: N815

    class CompleteMultipartUploadArtifactInput(TypedDict):
        """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""

        completeMultipartAction: str  # noqa: N815
        completedParts: Dict[int, str]  # noqa: N815
        artifactID: str  # noqa: N815
        storagePath: str  # noqa: N815
        uploadID: str  # noqa: N815
        md5: str

    class CompleteMultipartUploadArtifactResponse(TypedDict):
        """Typed shape of a complete-multipart-upload response."""

        digest: str

    class DefaultSettings(TypedDict):
        """Keys of the `default_settings` mapping held by `Api`."""

        section: str
        git_remote: str
        ignore_globs: Optional[List[str]]
        base_url: Optional[str]
        root_dir: Optional[str]
        api_key: Optional[str]
        entity: Optional[str]
        project: Optional[str]
        _extra_http_headers: Optional[Mapping[str, str]]
        _proxies: Optional[Mapping[str, str]]

    # Type aliases used in string annotations elsewhere in the module.
    _Response = MutableMapping
    SweepState = Literal["RUNNING", "PAUSED", "CANCELED", "FINISHED"]
    Number = Union[int, float]
125
+
126
+ # class _MappingSupportsCopy(Protocol):
127
+ # def copy(self) -> "_MappingSupportsCopy": ...
128
+ # def keys(self) -> Iterable: ...
129
+ # def __getitem__(self, name: str) -> Any: ...
130
+
131
# Logger that receives the standard library's ``http.client`` debug output.
httpclient_logger = logging.getLogger("http.client")
if os.environ.get("WANDB_DEBUG"):
    httpclient_logger.setLevel(logging.DEBUG)


def check_httpclient_logger_handler() -> None:
    """Route ``http.client`` debug output through the logging framework.

    Does nothing unless the ``WANDB_DEBUG`` environment variable is set, and
    does nothing when ``httpclient_logger`` already has a handler. Otherwise
    it replaces the ``print`` name inside ``http.client`` with a function that
    logs at DEBUG level, turns on ``HTTPConnection.debuglevel``, and attaches
    the first handler of the ``"wandb"`` logger (if it has any) to
    ``httpclient_logger``.
    """
    # Only enable http.client logging if WANDB_DEBUG is set
    if not os.environ.get("WANDB_DEBUG"):
        return
    if httpclient_logger.handlers:
        return

    # Enable HTTPConnection debug logging to the logging framework
    level = logging.DEBUG

    def httpclient_log(*args: Any) -> None:
        # print() accepts arbitrary objects, so coerce each argument;
        # str.join() alone raises TypeError on anything that is not a str.
        httpclient_logger.log(level, " ".join(map(str, args)))

    # mask the print() built-in in the http.client module to use logging instead
    http.client.print = httpclient_log  # type: ignore[attr-defined]
    # enable debugging
    http.client.HTTPConnection.debuglevel = 1

    root_logger = logging.getLogger("wandb")
    if root_logger.handlers:
        httpclient_logger.addHandler(root_logger.handlers[0])
157
+
158
+
159
class _ThreadLocalData(threading.local):
    """Thread-local holder for an optional API context.

    Because this derives from ``threading.local``, every thread gets its own
    ``context`` attribute, initialised to ``None``.
    """

    # Context set for the current thread; None when nothing has been set.
    context: Optional[context.Context]

    def __init__(self) -> None:
        self.context = None
164
+
165
+
166
+ class _OrgNames(NamedTuple):
167
+ entity_name: str
168
+ display_name: str
169
+
170
+
171
def _match_org_with_fetched_org_entities(
    organization: str, orgs: Sequence[_OrgNames]
) -> str:
    """Resolve a user-supplied organization string to an org entity name.

    The first entry of ``orgs`` whose entity name or display name equals
    ``organization`` wins; a warning recommending the shorthand path format
    is emitted before its entity name is returned.

    Args:
        organization: The organization name to match.
        orgs: Candidate ``(entity_name, display_name)`` pairs.

    Returns:
        The ``entity_name`` of the first matching entry.

    Raises:
        ValueError: If nothing matches. The message names the expected
            display name when ``orgs`` has exactly one entry and is generic
            otherwise.
    """
    # Tuple membership compares against both the entity and display name.
    matched = next((names for names in orgs if organization in names), None)
    if matched is not None:
        wandb.termwarn(
            "Registries can be linked/fetched using a shorthand form without specifying the organization name. "
            "Try using shorthand path format: <my_registry_name>/<artifact_name> or "
            "just <my_registry_name> if fetching just the project."
        )
        return matched.entity_name

    if len(orgs) != 1:
        raise ValueError(
            "Personal entity belongs to multiple organizations "
            f"and cannot be linked/fetched with {organization!r}. "
            "Please update the target path with the correct organization name "
            "or use a team entity in the entity settings."
        )

    only_org = orgs[0]
    raise ValueError(
        f"Expecting the organization name or entity name to match {only_org.display_name!r} "
        f"and cannot be linked/fetched with {organization!r}. "
        "Please update the target path with the correct organization name."
    )
208
+
209
+
210
+ class Api:
211
+ """W&B Internal Api wrapper.
212
+
213
+ Note:
214
+ Settings are automatically overridden by looking for
215
+ a `wandb/settings` file in the current working directory or its parent
216
+ directory. If none can be found, we look in the current user's home
217
+ directory.
218
+
219
+ Args:
220
+ default_settings(dict, optional): If you aren't using a settings
221
+ file, or you wish to override the section to use in the settings file
222
+ Override the settings here.
223
+ """
224
+
225
+ HTTP_TIMEOUT = env.get_http_timeout(20)
226
+ FILE_PUSHER_TIMEOUT = env.get_file_pusher_timeout()
227
+ _global_context: context.Context
228
+ _local_data: _ThreadLocalData
229
+
230
def __init__(
    self,
    default_settings: Optional[
        Union[
            "wandb.sdk.wandb_settings.Settings",
            "wandb.sdk.internal.settings_static.SettingsStatic",
            Settings,
            dict,
        ]
    ] = None,
    load_settings: bool = True,
    retry_timedelta: datetime.timedelta = datetime.timedelta(  # noqa: B008 # okay because it's immutable
        days=7
    ),
    environ: MutableMapping = os.environ,
    retry_callback: Optional[Callable[[int, str], Any]] = None,
) -> None:
    """Set up settings, the GraphQL client, retry wrappers and upload session.

    Args:
        default_settings: Values merged (via ``update``) over the built-in
            defaults defined below. ``None`` means no overrides.
        load_settings: Forwarded to the legacy ``Settings`` constructor.
        retry_timedelta: Stored on the instance and passed as
            ``retry_timedelta`` to the GraphQL retry wrapper and to the
            upload retry decorators.
        environ: Mapping read for ``WANDB__EXTRA_HTTP_HEADERS`` and
            ``WANDB__PROXIES`` and handed to the ``env`` helpers.
        retry_callback: Stored on the instance and passed to the GraphQL
            retry wrapper.
    """
    self._environ = environ
    self._global_context = context.Context()
    self._local_data = _ThreadLocalData()
    self.default_settings: DefaultSettings = {
        "section": "default",
        "git_remote": "origin",
        "ignore_globs": [],
        "base_url": "https://api.wandb.ai",
        "root_dir": None,
        "api_key": None,
        "entity": None,
        "project": None,
        "_extra_http_headers": None,
        "_proxies": None,
    }
    self.retry_timedelta = retry_timedelta
    # todo: Old Settings do not follow the SupportsKeysAndGetItem Protocol
    default_settings = default_settings or {}
    self.default_settings.update(default_settings)  # type: ignore
    self.retry_uploads = 10
    self._settings = Settings(
        load_settings=load_settings,
        root_dir=self.default_settings.get("root_dir"),
    )
    self.git = GitRepo(remote=self.settings("git_remote"))
    # Mutable settings set by the _file_stream_api
    self.dynamic_settings = {
        "system_sample_seconds": 2,
        "system_samples": 15,
        "heartbeat_seconds": 30,
    }

    # todo: remove these hacky hacks after settings refactor is complete
    # keeping this code here to limit scope and so that it is easy to remove later
    configured_headers = self.settings("_extra_http_headers") or json.loads(
        self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
    )
    # Copy before mutating: the configured value can be the very mapping the
    # caller supplied through ``default_settings`` (possibly read-only), and
    # the updates below must not leak back into it.
    self._extra_http_headers = dict(configured_headers)
    self._extra_http_headers.update(_thread_local_api_settings.headers or {})

    auth = None
    # Read the property once; each access re-runs the token lookup.
    access_token = self.access_token
    if access_token is not None:
        self._extra_http_headers["Authorization"] = f"Bearer {access_token}"
    elif _thread_local_api_settings.cookies is None:
        auth = ("api", self.api_key or "")

    proxies = self.settings("_proxies") or json.loads(
        self._environ.get("WANDB__PROXIES", "{}")
    )

    self.client = Client(
        transport=GraphQLSession(
            headers={
                "User-Agent": self.user_agent,
                "X-WANDB-USERNAME": env.get_username(env=self._environ),
                "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
                **self._extra_http_headers,
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=auth,
            url=f"{self.settings('base_url')}/graphql",
            cookies=_thread_local_api_settings.cookies,
            proxies=proxies,
        )
    )

    self.retry_callback = retry_callback
    self._retry_gql = retry.Retry(
        self.execute,
        retry_timedelta=retry_timedelta,
        check_retry_fn=util.no_retry_auth,
        retryable_exceptions=(RetryError, requests.RequestException),
        retry_callback=retry_callback,
    )
    self._current_run_id: Optional[str] = None
    self._file_stream_api = None
    self._upload_file_session = requests.Session()
    if self.FILE_PUSHER_TIMEOUT:
        # Bake the configured timeout into every PUT made by this session.
        self._upload_file_session.put = functools.partial(  # type: ignore
            self._upload_file_session.put,
            timeout=self.FILE_PUSHER_TIMEOUT,
        )
    if proxies:
        self._upload_file_session.proxies.update(proxies)
    # This Retry class is initialized once for each Api instance, so this
    # defaults to retrying 1 million times per process or 7 days
    self.upload_file_retry = normalize_exceptions(
        retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file)
    )
    self.upload_multipart_file_chunk_retry = normalize_exceptions(
        retry.retriable(retry_timedelta=retry_timedelta)(
            self.upload_multipart_file_chunk
        )
    )
    self._client_id_mapping: Dict[str, str] = {}
    # Large file uploads to azure can optionally use their SDK
    self._azure_blob_module = util.get_module("azure.storage.blob")

    # Server capability caches; all start unset (None).
    self.query_types: Optional[List[str]] = None
    self.mutation_types: Optional[List[str]] = None
    self.server_info_types: Optional[List[str]] = None
    self.server_use_artifact_input_info: Optional[List[str]] = None
    self.server_create_artifact_input_info: Optional[List[str]] = None
    self.server_artifact_fields_info: Optional[List[str]] = None
    self.server_organization_type_fields_info: Optional[List[str]] = None
    self.server_supports_enabling_artifact_usage_tracking: Optional[bool] = None
    self._max_cli_version: Optional[str] = None
    self._server_settings_type: Optional[List[str]] = None
    self.fail_run_queue_item_input_info: Optional[List[str]] = None
    self.create_launch_agent_input_info: Optional[List[str]] = None
    self.server_create_run_queue_supports_drc: Optional[bool] = None
    self.server_create_run_queue_supports_priority: Optional[bool] = None
    self.server_supports_template_variables: Optional[bool] = None
    self.server_push_to_run_queue_supports_priority: Optional[bool] = None
363
+
364
def gql(self, *args: Any, **kwargs: Any) -> Any:
    """Run a GraphQL call through the retry wrapper.

    All positional and keyword arguments are forwarded unchanged; the
    ``cancel_event`` of the active context is added as ``retry_cancel_event``.

    Returns:
        Whatever the retry wrapper returns.
    """
    retrying_execute = self._retry_gql
    return retrying_execute(
        *args, retry_cancel_event=self.context.cancel_event, **kwargs
    )
371
+
372
def set_local_context(self, api_context: Optional[context.Context]) -> None:
    """Store ``api_context`` as the context held in ``_local_data``."""
    local_data = self._local_data
    local_data.context = api_context
374
+
375
def clear_local_context(self) -> None:
    """Reset the context held in ``_local_data`` back to ``None``."""
    local_data = self._local_data
    local_data.context = None
377
+
378
@property
def context(self) -> context.Context:
    """The context in effect: the one in ``_local_data`` if set, else the global one."""
    local_context = self._local_data.context
    if local_context:
        return local_context
    return self._global_context
381
+
382
def reauth(self) -> None:
    """Ensure the current api key is set in the transport.

    The session's basic-auth pair becomes ``("api", <key>)``; a missing key
    is sent as the empty string.
    """
    current_key = self.api_key or ""
    self.client.transport.session.auth = ("api", current_key)
385
+
386
def relocate(self) -> None:
    """Ensure the current api points to the right server.

    Rebuilds the transport URL from the current ``base_url`` setting.
    """
    base_url = self.settings("base_url")
    self.client.transport.url = f"{base_url}/graphql"
389
+
390
def execute(self, *args: Any, **kwargs: Any) -> "_Response":
    """Wrapper around execute that logs in cases of failure.

    Arguments are forwarded unchanged to ``self.client.execute``.

    Returns:
        The result of ``self.client.execute``.

    Raises:
        requests.exceptions.HTTPError: Always re-raised as the original
            exception. When it carries a response, the status code, the
            body and every parsed backend error message are reported first.
    """
    try:
        return self.client.execute(*args, **kwargs)  # type: ignore
    except requests.exceptions.HTTPError as err:
        response = err.response
        if response is None:
            # No response to report. Re-raise the original error instead of
            # asserting: an assert is stripped under -O and, when it fires,
            # replaces the HTTPError with an unrelated AssertionError.
            raise
        logger.error("%s response executing GraphQL.", response.status_code)
        logger.error(response.text)
        for error in parse_backend_error_messages(response):
            wandb.termerror(f"Error while calling W&B API: {error} ({response})")
        raise
402
+
403
def disabled(self) -> Union[str, bool]:
    """Return the ``disabled`` entry of the default settings section.

    ``False`` is passed as the fallback for a missing entry.
    """
    stored_settings = self._settings
    return stored_settings.get(  # type: ignore
        Settings.DEFAULT_SECTION, "disabled", fallback=False
    )
405
+
406
def set_current_run_id(self, run_id: str) -> None:
    """Remember ``run_id`` as the current run id of this instance."""
    setattr(self, "_current_run_id", run_id)
408
+
409
@property
def current_run_id(self) -> Optional[str]:
    """The run id last stored by ``set_current_run_id``, or ``None``."""
    run_id = self._current_run_id
    return run_id
412
+
413
@property
def user_agent(self) -> str:
    """Client identification string that embeds the installed wandb version."""
    version = wandb.__version__
    return f"W&B Internal Client {version}"
416
+
417
@property
def api_key(self) -> Optional[str]:
    """The API key to authenticate with.

    A key in the thread-local API settings wins outright. Otherwise the
    first truthy value is taken, in this order, from: the environment, the
    netrc entry for ``api_url``, the SageMaker secrets, and the ``api_key``
    entry of ``default_settings``.
    """
    thread_key = _thread_local_api_settings.api_key
    if thread_key:
        return thread_key

    netrc_auth = requests.utils.get_netrc_auth(self.api_url)
    # The key is the last element of the netrc auth tuple.
    netrc_key = netrc_auth[-1] if netrc_auth else None

    # Environment should take precedence
    env_key: Optional[str] = self._environ.get(env.API_KEY)
    sagemaker_key: Optional[str] = parse_sm_secrets().get(env.API_KEY)
    default_key: Optional[str] = self.default_settings.get("api_key")
    return env_key or netrc_key or sagemaker_key or default_key
431
+
432
@property
def access_token(self) -> Optional[str]:
    """Retrieves an access token for authentication.

    The path to an identity token file is read from the environment. When it
    is set, the identity token is handed to ``credentials.access_token``
    together with the base URL and the credentials file path, which attempts
    to exchange it for a temporary access token from the server and save it
    to the credentials file.

    Returns:
        Optional[str]: The access token if available, otherwise None if
        no identity token is supplied.

    Raises:
        AuthenticationError: If the path to the identity token is not found.
    """
    configured_path = self._environ.get(env.IDENTITY_TOKEN_FILE)
    if not configured_path:
        return None

    identity_token_path = Path(configured_path)
    if not identity_token_path.exists():
        raise AuthenticationError(
            f"Identity token file not found: {identity_token_path}"
        )

    server_url = self.settings("base_url")
    default_credentials_path = str(credentials.DEFAULT_WANDB_CREDENTIALS_FILE)
    credentials_path = env.get_credentials_file(
        default_credentials_path, self._environ
    )
    return credentials.access_token(
        server_url, identity_token_path, credentials_path
    )
460
+
461
@property
def api_url(self) -> str:
    """The ``base_url`` setting currently in effect."""
    base_url = self.settings("base_url")
    return base_url  # type: ignore
464
+
465
@property
def app_url(self) -> str:
    """The result of ``wandb.util.app_url`` applied to ``api_url``."""
    api_url = self.api_url
    return wandb.util.app_url(api_url)
468
+
469
@property
def default_entity(self) -> str:
    """The ``entity`` field of the mapping returned by ``viewer()``."""
    viewer_info = self.viewer()
    return viewer_info.get("entity")  # type: ignore
472
+
473
def settings(self, key: Optional[str] = None, section: Optional[str] = None) -> Any:
    """The settings overridden from the wandb/settings file.

    Values are layered in this order: ``default_settings``, then the items
    of the requested settings-file ``section``, then the four
    environment-aware keys ``entity``, ``project``, ``base_url`` and
    ``ignore_globs``.

    Args:
        key (str, optional): If provided only this setting is returned
        section (str, optional): If provided this section of the setting file is
            used, defaults to "default"

    Returns:
        The value stored under ``key`` when ``key`` is given, otherwise a
        dict with the current settings

            {
                "entity": "models",
                "base_url": "https://api.wandb.ai",
                "project": None
            }

    Raises:
        KeyError: If ``key`` is given but is not a known setting.
    """
    merged = self.default_settings.copy()
    merged.update(self._settings.items(section=section))  # type: ignore

    # Each of these keys is looked up in the DEFAULT section (not ``section``)
    # and then passed through its ``env`` helper along with the environ mapping.
    env_resolvers = (
        ("entity", env.get_entity),
        ("project", env.get_project),
        ("base_url", env.get_base_url),
        ("ignore_globs", env.get_ignore),
    )
    resolved = {}
    for name, resolve in env_resolvers:
        stored_value = self._settings.get(
            Settings.DEFAULT_SECTION,
            name,
            fallback=merged.get(name),
        )
        resolved[name] = resolve(stored_value, env=self._environ)
    merged.update(resolved)

    if key is None:
        return merged
    return merged[key]  # type: ignore
530
+
531
def clear_setting(
    self, key: str, globally: bool = False, persist: bool = False
) -> None:
    """Clear ``key`` from the default section of the settings store.

    Args:
        key: Name of the setting to clear.
        globally: Forwarded unchanged to the settings store.
        persist: Forwarded unchanged to the settings store.
    """
    default_section = Settings.DEFAULT_SECTION
    self._settings.clear(default_section, key, globally=globally, persist=persist)
537
+
538
def set_setting(
    self, key: str, value: Any, globally: bool = False, persist: bool = False
) -> None:
    """Store ``value`` under ``key`` in the default settings section.

    After writing to the settings store, ``entity`` and ``project`` are also
    pushed into this instance's environment mapping, and a change of
    ``base_url`` triggers ``self.relocate()``.

    Args:
        key: Name of the setting.
        value: New value for the setting.
        globally: Forwarded unchanged to the settings store.
        persist: Forwarded unchanged to the settings store.
    """
    default_section = Settings.DEFAULT_SECTION
    self._settings.set(
        default_section, key, value, globally=globally, persist=persist
    )

    if key == "base_url":
        self.relocate()
    elif key == "entity":
        env.set_entity(value, env=self._environ)
    elif key == "project":
        env.set_project(value, env=self._environ)
550
+
551
def parse_slug(
    self, slug: str, project: Optional[str] = None, run: Optional[str] = None
) -> Tuple[str, str]:
    """Parse a slug into a project and run.

    A slug containing ``/`` is split on it: the first two components become
    the project and the run, overriding the ``project`` and ``run``
    arguments. Otherwise the project falls back to the configured default
    and the run to, in order, ``run``, ``slug``, ``self.current_run_id`` and
    the run from the environment.

    Args:
        slug (str): The slug to parse
        project (str, optional): The project to use, if not provided it will be
            inferred from the slug
        run (str, optional): The run to use, if not provided it will be inferred
            from the slug

    Returns:
        A ``(project, run)`` tuple.

    Raises:
        CommError: If no project is given and no default project is configured.
    """
    if not slug or "/" not in slug:
        project = project or self.settings().get("project")
        if project is None:
            raise CommError("No default project configured.")
        run = run or slug or self.current_run_id or env.get_run(env=self._environ)
    else:
        # Only the first two components are used; extras are ignored.
        project, run = slug.split("/")[:2]
    assert run, "run must be specified"
    return project, run
577
+
578
@normalize_exceptions
def server_info_introspection(self) -> Tuple[List[str], List[str], List[str]]:
    """Probe the server schema for the fields of Query, Mutation and ServerInfo.

    The field names are cached on ``self.query_types``,
    ``self.mutation_types`` and ``self.server_info_types``; the server is
    only queried while any of them is still None.

    Returns:
        A tuple ``(query_types, server_info_types, mutation_types)`` of field
        name lists. Note that the mutation names come last.
    """
    query_string = """
        query ProbeServerCapabilities {
            QueryType: __type(name: "Query") {
                ...fieldData
            }
            MutationType: __type(name: "Mutation") {
                ...fieldData
            }
            ServerInfoType: __type(name: "ServerInfo") {
                ...fieldData
            }
        }

        fragment fieldData on __Type {
            fields {
                name
            }
        }
    """
    if (
        self.query_types is None
        or self.mutation_types is None
        or self.server_info_types is None
    ):
        res = self.gql(gql(query_string)) or {}

        def field_names(type_alias: str) -> List[str]:
            # `__type` resolves to null for a type the server does not
            # know, and dict.get() only applies its default to a *missing*
            # key, so both levels are guarded with `or`.
            type_info = res.get(type_alias) or {}
            return [
                field.get("name", "") for field in type_info.get("fields") or []
            ]

        self.query_types = field_names("QueryType")
        self.mutation_types = field_names("MutationType")
        self.server_info_types = field_names("ServerInfoType")
    return self.query_types, self.server_info_types, self.mutation_types
620
+
621
@normalize_exceptions
def server_settings_introspection(self) -> None:
    """Probe the server schema for the fields of the ServerSettings type.

    The field names are cached on ``self._server_settings_type``; the server
    is only queried while that attribute is still None. An empty response
    or an unknown type results in an empty list.
    """
    query_string = """
        query ProbeServerSettings {
            ServerSettingsType: __type(name: "ServerSettings") {
                ...fieldData
            }
        }

        fragment fieldData on __Type {
            fields {
                name
            }
        }
    """
    if self._server_settings_type is None:
        res = self.gql(gql(query_string)) or {}
        # `__type` resolves to null for a type the server does not know;
        # dict.get() would hand that None back instead of its default.
        type_info = res.get("ServerSettingsType") or {}
        self._server_settings_type = [
            field.get("name", "") for field in type_info.get("fields") or []
        ]
647
+
648
def server_use_artifact_input_introspection(self) -> List:
    """Return the input field names of the server's UseArtifactInput type.

    The result is cached on ``self.server_use_artifact_input_info``; the
    server is only queried while that attribute is still None.

    Returns:
        A list of input field names (empty when the server does not know the
        type).
    """
    query_string = """
        query ProbeServerUseArtifactInput {
            UseArtifactInputInfoType: __type(name: "UseArtifactInput") {
                name
                inputFields {
                    name
                }
            }
        }
    """

    if self.server_use_artifact_input_info is None:
        res = self.gql(gql(query_string)) or {}
        # `__type` resolves to null for a type the server does not know;
        # dict.get() would hand that None back instead of its default.
        type_info = res.get("UseArtifactInputInfoType") or {}
        self.server_use_artifact_input_info = [
            field.get("name", "") for field in type_info.get("inputFields") or []
        ]
    return self.server_use_artifact_input_info
670
+
671
@normalize_exceptions
def launch_agent_introspection(self) -> Optional[str]:
    """Probe whether the server schema defines a LaunchAgent type.

    Returns:
        The ``LaunchAgentType`` entry of the response when it is truthy,
        otherwise None.
    """
    probe = gql(
        """
        query LaunchAgentIntrospection {
            LaunchAgentType: __type(name: "LaunchAgent") {
                name
            }
        }
        """
    )

    response = self.gql(probe)
    launch_agent_type = response.get("LaunchAgentType")
    return launch_agent_type or None
685
+
686
@normalize_exceptions
def create_run_queue_introspection(self) -> Tuple[bool, bool, bool]:
    """Probe which run queue creation features the server supports.

    The two input-field checks are cached on
    ``self.server_create_run_queue_supports_drc`` and
    ``self.server_create_run_queue_supports_priority``; the
    CreateRunQueueInput type is only queried while either is still None.

    Returns:
        A tuple of three booleans: whether the ``createRunQueue`` mutation
        exists, whether its input accepts ``defaultResourceConfigID``, and
        whether its input accepts ``prioritizationMode``.

    Raises:
        CommError: If the introspection query returns no response.
    """
    _, _, mutations = self.server_info_introspection()
    query_string = """
        query ProbeCreateRunQueueInput {
            CreateRunQueueInputType: __type(name: "CreateRunQueueInput") {
                name
                inputFields {
                    name
                }
            }
        }
    """
    if (
        self.server_create_run_queue_supports_drc is None
        or self.server_create_run_queue_supports_priority is None
    ):
        res = self.gql(gql(query_string))
        if res is None:
            raise CommError("Could not get CreateRunQueue input from GQL.")
        # `__type` resolves to null for a type the server does not know, so
        # neither the type nor its field list can be assumed to be present.
        input_type = res.get("CreateRunQueueInputType") or {}
        input_field_names = {
            field.get("name") for field in input_type.get("inputFields") or []
        }
        self.server_create_run_queue_supports_drc = (
            "defaultResourceConfigID" in input_field_names
        )
        self.server_create_run_queue_supports_priority = (
            "prioritizationMode" in input_field_names
        )
    return (
        "createRunQueue" in mutations,
        self.server_create_run_queue_supports_drc,
        self.server_create_run_queue_supports_priority,
    )
724
+
725
@normalize_exceptions
def upsert_run_queue_introspection(self) -> bool:
    """Return whether the server exposes the ``upsertRunQueue`` mutation."""
    mutation_names = self.server_info_introspection()[2]
    return "upsertRunQueue" in mutation_names
729
+
730
@normalize_exceptions
def push_to_run_queue_introspection(self) -> Tuple[bool, bool]:
    """Probe which optional PushToRunQueueInput fields the server accepts.

    The results are cached on ``self.server_supports_template_variables``
    and ``self.server_push_to_run_queue_supports_priority``; the server is
    only queried while either is still None.

    Returns:
        A tuple of two booleans: whether the input accepts
        ``templateVariableValues`` and whether it accepts ``priority``.
    """
    query_string = """
        query ProbePushToRunQueueInput {
            PushToRunQueueInputType: __type(name: "PushToRunQueueInput") {
                name
                inputFields {
                    name
                }
            }
        }
    """

    if (
        self.server_supports_template_variables is None
        or self.server_push_to_run_queue_supports_priority is None
    ):
        res = self.gql(gql(query_string)) or {}
        # `__type` resolves to null for a type the server does not know, so
        # neither the type nor its field list can be assumed to be present.
        input_type = res.get("PushToRunQueueInputType") or {}
        input_field_names = {
            field.get("name") for field in input_type.get("inputFields") or []
        }
        self.server_supports_template_variables = (
            "templateVariableValues" in input_field_names
        )
        self.server_push_to_run_queue_supports_priority = (
            "priority" in input_field_names
        )

    return (
        self.server_supports_template_variables,
        self.server_push_to_run_queue_supports_priority,
    )
766
+
767
@normalize_exceptions
def create_default_resource_config_introspection(self) -> bool:
    """Return whether the server exposes the ``createDefaultResourceConfig`` mutation."""
    mutation_names = self.server_info_introspection()[2]
    return "createDefaultResourceConfig" in mutation_names
771
+
772
@normalize_exceptions
def fail_run_queue_item_introspection(self) -> bool:
    """Return whether the server exposes the ``failRunQueueItem`` mutation."""
    mutation_names = self.server_info_introspection()[2]
    return "failRunQueueItem" in mutation_names
776
+
777
@normalize_exceptions
def fail_run_queue_item_fields_introspection(self) -> List:
    """Return the input field names of the server's FailRunQueueItemInput type.

    A non-empty result is cached on ``self.fail_run_queue_item_input_info``
    and returned without querying the server again.

    Returns:
        A list of input field names (empty when the server does not know the
        type).
    """
    if self.fail_run_queue_item_input_info:
        return self.fail_run_queue_item_input_info
    query_string = """
        query ProbeServerFailRunQueueItemInput {
            FailRunQueueItemInputInfoType: __type(name:"FailRunQueueItemInput") {
                inputFields{
                    name
                }
            }
        }
    """

    res = self.gql(gql(query_string)) or {}

    # `__type` resolves to null for a type the server does not know;
    # dict.get() would hand that None back instead of its default.
    type_info = res.get("FailRunQueueItemInputInfoType") or {}
    self.fail_run_queue_item_input_info = [
        field.get("name", "") for field in type_info.get("inputFields") or []
    ]
    return self.fail_run_queue_item_input_info
801
+
802
@normalize_exceptions
def fail_run_queue_item(
    self,
    run_queue_item_id: str,
    message: str,
    stage: str,
    file_paths: Optional[List[str]] = None,
) -> bool:
    """Mark a run queue item as failed on the server.

    When the server's input type has a ``message`` field, the failure
    message, stage and (if given) file paths are sent as well; otherwise
    only the item id is sent.

    Args:
        run_queue_item_id: Id of the run queue item to fail.
        message: Failure message.
        stage: Stage in which the failure happened.
        file_paths: Optional file paths to attach to the failure.

    Returns:
        False when the server has no ``failRunQueueItem`` mutation, otherwise
        the ``success`` flag reported by the server.
    """
    detailed_mutation = """
        mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
            failRunQueueItem(
                input: {
                    runQueueItemId: $runQueueItemId
                    message: $message
                    stage: $stage
                    filePaths: $filePaths
                }
            ) {
                success
            }
        }
    """
    minimal_mutation = """
        mutation failRunQueueItem($runQueueItemId: ID!) {
            failRunQueueItem(
                input: {
                    runQueueItemId: $runQueueItemId
                }
            ) {
                success
            }
        }
    """

    if not self.fail_run_queue_item_introspection():
        return False

    variable_values: Dict[str, Union[str, Optional[List[str]]]] = {
        "runQueueItemId": run_queue_item_id,
    }
    if "message" in self.fail_run_queue_item_fields_introspection():
        mutation_string = detailed_mutation
        variable_values["message"] = message
        variable_values["stage"] = stage
        if file_paths is not None:
            variable_values["filePaths"] = file_paths
    else:
        mutation_string = minimal_mutation

    response = self.gql(gql(mutation_string), variable_values=variable_values)
    succeeded: bool = response["failRunQueueItem"]["success"]
    return succeeded
850
+
851
@normalize_exceptions
def update_run_queue_item_warning_introspection(self) -> bool:
    """Return whether the server exposes the ``updateRunQueueItemWarning`` mutation."""
    mutation_names = self.server_info_introspection()[2]
    return "updateRunQueueItemWarning" in mutation_names
855
+
856
@normalize_exceptions
def update_run_queue_item_warning(
    self,
    run_queue_item_id: str,
    message: str,
    stage: str,
    file_paths: Optional[List[str]] = None,
) -> bool:
    """Send a warning for a run queue item to the server.

    Args:
        run_queue_item_id: Id of the run queue item the warning belongs to.
        message: Warning message.
        stage: Stage in which the warning was raised.
        file_paths: Optional file paths to attach; sent as-is, including None.

    Returns:
        False when the server has no ``updateRunQueueItemWarning`` mutation,
        otherwise the ``success`` flag reported by the server.
    """
    if not self.update_run_queue_item_warning_introspection():
        return False

    warning_mutation = gql(
        """
        mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
            updateRunQueueItemWarning(
                input: {
                    runQueueItemId: $runQueueItemId
                    message: $message
                    stage: $stage
                    filePaths: $filePaths
                }
            ) {
                success
            }
        }
        """
    )
    variables = {
        "runQueueItemId": run_queue_item_id,
        "message": message,
        "stage": stage,
        "filePaths": file_paths,
    }
    response = self.gql(warning_mutation, variable_values=variables)
    succeeded: bool = response["updateRunQueueItemWarning"]["success"]
    return succeeded
893
+
894
@normalize_exceptions
def viewer(self) -> Dict[str, Any]:
    """Fetch the current viewer from the server.

    Returns:
        The ``viewer`` object of the response (id, entity, username, flags
        and team names), or an empty dict when the server returns none.
    """
    viewer_query = gql(
        """
        query Viewer{
            viewer {
                id
                entity
                username
                flags
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        """
    )
    response = self.gql(viewer_query)
    viewer_info = response.get("viewer")
    return viewer_info or {}
917
+
918
@normalize_exceptions
def max_cli_version(self) -> Optional[str]:
    """Return the maximum CLI version reported by the server, if any.

    The value is cached on ``self._max_cli_version`` once it is known.

    Returns:
        The ``max_cli_version`` entry of the server's ``cliVersionInfo``, or
        None when the server does not expose that information.
    """
    if self._max_cli_version is not None:
        return self._max_cli_version

    query_types, server_info_types, _ = self.server_info_introspection()
    cli_version_exists = (
        "serverInfo" in query_types and "cliVersionInfo" in server_info_types
    )
    if not cli_version_exists:
        return None

    _, server_info = self.viewer_server_info()
    # The server may answer `cliVersionInfo: null`; dict.get() would return
    # that None rather than its default, so fall back explicitly.
    cli_version_info = server_info.get("cliVersionInfo") or {}
    self._max_cli_version = cli_version_info.get("max_cli_version")
    return self._max_cli_version
935
+
936
@normalize_exceptions
def viewer_server_info(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Fetch the current viewer together with the available server info.

    The ``serverInfo`` selection is assembled from the fields the server
    schema actually exposes (``cliVersionInfo`` and/or
    ``latestLocalVersionInfo``) and is left out entirely when neither is
    available.

    Returns:
        A tuple ``(viewer, server_info)``; each element is an empty dict when
        the server returns nothing for it.
    """
    local_query = """
        latestLocalVersionInfo {
            outOfDate
            latestVersionString
            versionOnThisInstanceString
        }
    """
    query_template = """
        query Viewer{
            viewer {
                id
                entity
                username
                email
                flags
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
            _SERVER_INFO_QUERY_
        }
    """
    query_types, server_info_types, _ = self.server_info_introspection()
    has_server_info = "serverInfo" in query_types

    # Select each serverInfo field independently, so that the local version
    # info is not lost when only `cliVersionInfo` is missing from the schema.
    server_info_fields = []
    if has_server_info and "cliVersionInfo" in server_info_types:
        server_info_fields.append("cliVersionInfo")
    if has_server_info and "latestLocalVersionInfo" in server_info_types:
        server_info_fields.append(local_query)

    server_info_query = ""
    if server_info_fields:
        server_info_query = (
            "serverInfo {\n" + "\n".join(server_info_fields) + "\n}"
        )

    query_string = query_template.replace("_SERVER_INFO_QUERY_", server_info_query)
    res = self.gql(gql(query_string))
    return res.get("viewer") or {}, res.get("serverInfo") or {}
990
+
991
@normalize_exceptions
def list_projects(self, entity: Optional[str] = None) -> List[Dict[str, str]]:
    """List projects in W&B scoped by entity.

    The query asks the server for the first 10 projects only.

    Args:
        entity (str, optional): The entity to scope the projects to; falls
            back to the ``entity`` setting.

    Returns:
        [{"id","name","description"}]
    """
    projects_query = gql(
        """
        query EntityProjects($entity: String) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        """
    )
    variables = {"entity": entity or self.settings("entity")}
    response = self.gql(projects_query, variable_values=variables)
    projects: List[Dict[str, str]] = self._flatten_edges(response["models"])
    return projects
1022
+
1023
@normalize_exceptions
def project(self, project: str, entity: Optional[str] = None) -> "_Response":
    """Retrieve project.

    Args:
        project (str): The project to get details for
        entity (str, optional): The entity to scope this project to.

    Returns:
        The ``model`` object of the response, with the fields
        {"id","name","repo","dockerImage","description"}
    """
    details_query = gql(
        """
        query ProjectDetails($entity: String, $project: String) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        """
    )
    variables = {"entity": entity, "project": project}
    payload = self.gql(details_query, variable_values=variables)
    model: _Response = payload["model"]
    return model
1051
+
1052
@normalize_exceptions
def sweep(
    self,
    sweep: str,
    specs: str,
    project: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Any]:
    """Retrieve sweep.

    Args:
        sweep (str): The sweep to get details for
        specs (str): history specs, passed to ``sampledHistory`` of each run
        project (str, optional): The project to scope this sweep to; falls
            back to the ``project`` setting.
        entity (str, optional): The entity to scope this sweep to; falls
            back to the ``entity`` setting.

    Returns:
        The sweep object returned by the server, with its ``runs``
        connection flattened by ``self._flatten_edges``.

    Raises:
        ValueError: If the project or the sweep is not found.
    """
    sweep_query = gql(
        """
        query SweepWithRuns($entity: String, $project: String, $sweep: String!, $specs: [JSONString!]!) {
            project(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        """
    )
    entity = entity or self.settings("entity")
    project = project or self.settings("project")
    variables = {
        "entity": entity,
        "project": project,
        "sweep": sweep,
        "specs": specs,
    }
    response = self.gql(sweep_query, variable_values=variables)

    project_obj = response["project"]
    if project_obj is None or project_obj["sweep"] is None:
        raise ValueError(f"Sweep {entity}/{project}/{sweep} not found")

    sweep_data: Dict[str, Any] = project_obj["sweep"]
    if sweep_data:
        sweep_data["runs"] = self._flatten_edges(sweep_data["runs"])
    return sweep_data
1128
+
1129
@normalize_exceptions
def list_runs(
    self, project: str, entity: Optional[str] = None
) -> List[Dict[str, str]]:
    """List runs in W&B scoped by project.

    The query asks the server for the first 10 runs only.

    Args:
        project (str): The project to scope the runs to; falls back to the
            ``project`` setting when empty.
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting.

    Returns:
        [{"id","name","displayName","description"}]
    """
    runs_query = gql(
        """
        query ProjectRuns($model: String!, $entity: String) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        """
    )
    variables = {
        "entity": entity or self.settings("entity"),
        "model": project or self.settings("project"),
    }
    response = self.gql(runs_query, variable_values=variables)
    return self._flatten_edges(response["model"]["buckets"])
1169
+
1170
@normalize_exceptions
def run_config(
    self, project: str, run: Optional[str] = None, entity: Optional[str] = None
) -> Tuple[str, Dict[str, Any], Optional[str], Dict[str, Any]]:
    """Get the relevant configs for a run.

    Args:
        project (str): The project to download, (can include bucket)
        run (str, optional): The run to download
        entity (str, optional): The entity to scope this project to.

    Returns:
        A tuple ``(commit, config, patch, metadata)``: the run's commit, its
        parsed JSON config, the text of the diff file (None when the run has
        none) and the parsed metadata file (empty dict when the run has none).

    Raises:
        CommError: If the project or the run is not found.
    """
    check_httpclient_logger_handler()

    query = gql(
        """
        query RunConfigs(
            $name: String!,
            $entity: String,
            $run: String!,
            $pattern: String!,
            $includeConfig: Boolean!,
        ) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config @include(if: $includeConfig)
                    commit @include(if: $includeConfig)
                    files(pattern: $pattern) {
                        pageInfo {
                            hasNextPage
                            endCursor
                        }
                        edges {
                            node {
                                name
                                directUrl
                            }
                        }
                    }
                }
            }
        }
        """
    )

    variable_values: Dict[str, Any] = {
        "name": project,
        "run": run,
        "entity": entity,
        "includeConfig": True,
    }

    commit: str = ""
    config: Dict[str, Any] = {}
    patch: Optional[str] = None
    metadata: Dict[str, Any] = {}

    # If we use the `names` parameter on the `files` node, then the server
    # will helpfully give us an 'open' file handle to the files that don't
    # exist. This is so that we can upload data to it. However, in this
    # case, we just want to download that file and not upload to it, so
    # let's instead query for the files that do exist using `pattern`
    # (with no wildcards).
    #
    # Unfortunately we're unable to construct a single pattern that matches
    # our 2 files, we would need something like regex for that.
    for filename in [DIFF_FNAME, METADATA_FNAME]:
        variable_values["pattern"] = filename
        response = self.gql(query, variable_values=variable_values)
        model = response["model"]
        # A missing run yields `bucket: null`; report it as "not found"
        # instead of failing on the subscript below.
        if model is None or model["bucket"] is None:
            raise CommError(f"Run {entity}/{project}/{run} not found")
        run_obj: Dict = model["bucket"]
        # we only need to fetch this config once
        if variable_values["includeConfig"]:
            commit = run_obj["commit"]
            config = json.loads(run_obj["config"] or "{}")
            variable_values["includeConfig"] = False
        if run_obj["files"] is not None:
            for file_edge in run_obj["files"]["edges"]:
                name = file_edge["node"]["name"]
                url = file_edge["node"]["directUrl"]
                res = requests.get(url)
                res.raise_for_status()
                if name == METADATA_FNAME:
                    metadata = res.json()
                elif name == DIFF_FNAME:
                    patch = res.text

    return commit, config, patch, metadata
1258
+
1259
@normalize_exceptions
def run_resume_status(
    self, entity: str, project_name: str, name: str
) -> Optional[Dict[str, Any]]:
    """Check if a run exists and get resume information.

    As a side effect, a successful lookup stores ``project_name`` as the
    ``project`` setting and, when the response carries one, the project's
    entity name as the ``entity`` setting.

    Args:
        entity (str): The entity to scope this project to.
        project_name (str): The project to download, (can include bucket)
        name (str): The run to download

    Returns:
        The ``bucket`` object of the response, or None when the response has
        no project or no ``bucket`` entry.
    """
    # Pulling wandbConfig.start_time is required so that we can determine if a run has actually started
    resume_query = gql(
        """
        query RunResumeStatus($project: String, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                    tags
                    wandbConfig(keys: ["t"])
                }
            }
        }
        """
    )

    variables = {
        "entity": entity,
        "project": project_name,
        "name": name,
    }
    response = self.gql(resume_query, variable_values=variables)

    if "model" not in response:
        return None
    project = response["model"]
    if "bucket" not in (project or {}):
        return None

    self.set_setting("project", project_name)
    if "entity" in project:
        self.set_setting("entity", project["entity"]["name"])

    bucket: Dict[str, Any] = project["bucket"]
    return bucket
1321
+
1322
@normalize_exceptions
def check_stop_requested(
    self, project_name: str, entity_name: str, run_id: str
) -> bool:
    """Return the ``stopped`` flag the server reports for a run.

    Args:
        project_name: Name of the project containing the run.
        entity_name: Name of the entity owning the project.
        run_id: Name of the run to check.

    Returns:
        The run's ``stopped`` value, or False when the project or the run is
        missing from the response.
    """
    stop_query = gql(
        """
        query RunStoppedStatus($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        """
    )

    variables = {
        "projectName": project_name,
        "entityName": entity_name,
        "runId": run_id,
    }
    response = self.gql(stop_query, variable_values=variables)

    project = response.get("project")
    run = project.get("run") if project else None
    if not run:
        return False

    stopped: bool = run["stopped"]
    return stopped
1356
+
1357
def format_project(self, project: str) -> str:
    """Normalize a project name.

    The name is lower-cased, every run of non-word characters is replaced by
    a single ``-``, and leading/trailing ``-`` and ``_`` are stripped.
    """
    lowered = project.lower()
    dashed = re.sub(r"\W+", "-", lowered)
    return dashed.strip("-_")
1359
+
1360
@normalize_exceptions
def upsert_project(
    self,
    project: str,
    id: Optional[str] = None,
    description: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Any]:
    """Create a new project, or update an existing one.

    The project name is normalized with ``self.format_project`` before it is
    sent to the server.

    Args:
        project (str): The project to create
        id (str, optional): Id sent to the server as the ``id`` input
        description (str, optional): A description of this project
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting.

    Returns:
        The ``model`` object of the mutation result (name and description).
    """
    upsert_mutation = gql(
        """
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        """
    )
    variables = {
        "name": self.format_project(project),
        "entity": entity or self.settings("entity"),
        "description": description,
        "id": id,
    }
    response = self.gql(upsert_mutation, variable_values=variables)
    # TODO(jhr): Commenting out 'repo' field for cling, add back
    #   'description': description, 'repo': self.git.remote_url, 'id': id})
    model: Dict[str, Any] = response["upsertModel"]["model"]
    return model
1400
+
1401
@normalize_exceptions
def entity_is_team(self, entity: str) -> bool:
    """Return whether the named entity is a team.

    Args:
        entity: Name of the entity to look up.

    Returns:
        The entity's ``isTeam`` value as reported by the server.

    Raises:
        Exception: If the server returns no entity for this name.
    """
    team_query = gql(
        """
        query EntityIsTeam($entity: String!) {
            entity(name: $entity) {
                id
                isTeam
            }
        }
        """
    )
    variables = {"entity": entity}

    response = self.gql(team_query, variables)
    if response.get("entity") is None:
        raise Exception(
            f"Error fetching entity {entity} "
            "check that you have access to this entity"
        )

    team_flag: bool = response["entity"]["isTeam"]
    return team_flag
1426
+
1427
@normalize_exceptions
def get_project_run_queues(self, entity: str, project: str) -> List[Dict[str, str]]:
    """List the run queues of a project.

    Args:
        entity: Name of the entity owning the project.
        project: Name of the project.

    Returns:
        The ``runQueues`` list of the project (id, name, createdBy, access).

    Raises:
        Exception: If the server returns no project for this entity/project.
    """
    queues_query = gql(
        """
        query ProjectRunQueues($entity: String!, $projectName: String!){
            project(entityName: $entity, name: $projectName) {
                runQueues {
                    id
                    name
                    createdBy
                    access
                }
            }
        }
        """
    )
    variables = {
        "projectName": project,
        "entity": entity,
    }

    response = self.gql(queues_query, variables)
    if response.get("project") is None:
        # circular dependency: (LAUNCH_DEFAULT_PROJECT = model-registry)
        # The default launch project is left out of the error message.
        target = entity if project == "model-registry" else f"{entity}/{project}"
        raise Exception(
            f"Error fetching run queues for {target} "
            "check that you have access to this entity and project"
        )

    run_queues: List[Dict[str, str]] = response["project"]["runQueues"]
    return run_queues
1466
+
1467
@normalize_exceptions
def create_default_resource_config(
    self,
    entity: str,
    resource: str,
    config: str,
    template_variables: Optional[Dict[str, Union[float, int, str]]],
) -> Optional[Dict[str, Any]]:
    """Create a default resource configuration on the server.

    Template variables are only sent when the server supports them; in that
    case a missing ``template_variables`` is sent as an empty JSON object.

    Args:
        entity: Name of the entity the configuration belongs to.
        resource: Resource name sent as the ``resource`` input.
        config: The configuration, sent as a JSON string.
        template_variables: Optional template variables, JSON-encoded before
            sending.

    Returns:
        The ``createDefaultResourceConfig`` object of the response
        (``defaultResourceConfigID`` and ``success``).

    Raises:
        Exception: If the server has no ``createDefaultResourceConfig``
            mutation.
        UnsupportedError: If template variables are given but the server does
            not support them.
    """
    if not self.create_default_resource_config_introspection():
        # Same exception type as before, but with a message that says why.
        raise Exception(
            "server does not support creating default resource "
            "configurations, please update the server instance"
        )
    supports_template_vars, _ = self.push_to_run_queue_introspection()

    mutation_params = """
        $entityName: String!,
        $resource: String!,
        $config: JSONString!
    """
    mutation_inputs = """
        entityName: $entityName,
        resource: $resource,
        config: $config
    """
    variable_values = {
        "entityName": entity,
        "resource": resource,
        "config": config,
    }

    if supports_template_vars:
        mutation_params += ", $templateVariables: JSONString"
        mutation_inputs += ", templateVariables: $templateVariables"
        variable_values["templateVariables"] = (
            json.dumps(template_variables)
            if template_variables is not None
            else "{}"
        )
    elif template_variables is not None:
        raise UnsupportedError(
            "server does not support template variables, please update server instance to >=0.46"
        )

    query = gql(
        f"""
        mutation createDefaultResourceConfig(
            {mutation_params}
        ) {{
            createDefaultResourceConfig(
                input: {{
                    {mutation_inputs}
                }}
            ) {{
                defaultResourceConfigID
                success
            }}
        }}
        """
    )

    result: Optional[Dict[str, Any]] = self.gql(query, variable_values)[
        "createDefaultResourceConfig"
    ]
    return result
1531
+
1532
@normalize_exceptions
def create_run_queue(
    self,
    entity: str,
    project: str,
    queue_name: str,
    access: str,
    prioritization_mode: Optional[str] = None,
    config_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Create a run queue on the server.

    The mutation includes ``prioritizationMode`` only when the server's
    input type supports it.

    Args:
        entity: Name of the entity owning the queue.
        project: Name of the project the queue is created in.
        queue_name: Name of the new queue.
        access: Access type of the queue.
        prioritization_mode: Optional prioritization mode.
        config_id: Optional id of a default resource configuration.

    Returns:
        The ``createRunQueue`` object of the response (``success`` and
        ``queueID``).

    Raises:
        UnsupportedError: If the server cannot create run queues, or if
            ``config_id`` or ``prioritization_mode`` is given but not
            supported by the server.
    """
    (
        can_create_queue,
        supports_drc,
        supports_prioritization,
    ) = self.create_run_queue_introspection()

    if not can_create_queue:
        raise UnsupportedError(
            "run queue creation is not supported by this version of "
            "wandb server. Consider updating to the latest version."
        )
    if config_id is not None and not supports_drc:
        raise UnsupportedError(
            "default resource configurations are not supported by this version "
            "of wandb server. Consider updating to the latest version."
        )
    if prioritization_mode is not None and not supports_prioritization:
        raise UnsupportedError(
            "launch prioritization is not supported by this version of "
            "wandb server. Consider updating to the latest version."
        )

    mutation_with_priority = """
        mutation createRunQueue(
            $entity: String!,
            $project: String!,
            $queueName: String!,
            $access: RunQueueAccessType!,
            $prioritizationMode: RunQueuePrioritizationMode,
            $defaultResourceConfigID: ID,
        ) {
            createRunQueue(
                input: {
                    entityName: $entity,
                    projectName: $project,
                    queueName: $queueName,
                    access: $access,
                    prioritizationMode: $prioritizationMode
                    defaultResourceConfigID: $defaultResourceConfigID
                }
            ) {
                success
                queueID
            }
        }
    """
    mutation_without_priority = """
        mutation createRunQueue(
            $entity: String!,
            $project: String!,
            $queueName: String!,
            $access: RunQueueAccessType!,
            $defaultResourceConfigID: ID,
        ) {
            createRunQueue(
                input: {
                    entityName: $entity,
                    projectName: $project,
                    queueName: $queueName,
                    access: $access,
                    defaultResourceConfigID: $defaultResourceConfigID
                }
            ) {
                success
                queueID
            }
        }
    """

    # Keys are inserted in the same order as the mutation's parameters.
    variable_values = {
        "entity": entity,
        "project": project,
        "queueName": queue_name,
        "access": access,
    }
    if supports_prioritization:
        query = gql(mutation_with_priority)
        variable_values["prioritizationMode"] = prioritization_mode
    else:
        query = gql(mutation_without_priority)
    variable_values["defaultResourceConfigID"] = config_id

    response = self.gql(query, variable_values)
    result: Optional[Dict[str, Any]] = response["createRunQueue"]
    return result
1635
+
1636
@normalize_exceptions
def upsert_run_queue(
    self,
    queue_name: str,
    entity: str,
    resource_type: str,
    resource_config: dict,
    project: str = LAUNCH_DEFAULT_PROJECT,
    prioritization_mode: Optional[str] = None,
    template_variables: Optional[dict] = None,
    external_links: Optional[dict] = None,
) -> Optional[Dict[str, Any]]:
    """Create or update a run queue on the server.

    ``resource_config`` is always JSON-encoded; ``template_variables`` and
    ``external_links`` are JSON-encoded when non-empty and sent as null
    otherwise.

    Args:
        queue_name: Name of the queue.
        entity: Name of the entity owning the queue.
        resource_type: Resource type of the queue.
        resource_config: Resource configuration of the queue.
        project: Name of the project the queue lives in.
        prioritization_mode: Optional prioritization mode.
        template_variables: Optional template variables.
        external_links: Optional external links.

    Returns:
        The ``upsertRunQueue`` object of the response (``success`` and
        ``configSchemaValidationErrors``).

    Raises:
        UnsupportedError: If the server has no ``upsertRunQueue`` mutation.
    """
    if not self.upsert_run_queue_introspection():
        raise UnsupportedError(
            "upserting run queues is not supported by this version of "
            "wandb server. Consider updating to the latest version."
        )

    upsert_mutation = gql(
        """
        mutation upsertRunQueue(
            $entityName: String!
            $projectName: String!
            $queueName: String!
            $resourceType: String!
            $resourceConfig: JSONString!
            $templateVariables: JSONString
            $prioritizationMode: RunQueuePrioritizationMode
            $externalLinks: JSONString
            $clientMutationId: String
        ) {
            upsertRunQueue(
                input: {
                    entityName: $entityName
                    projectName: $projectName
                    queueName: $queueName
                    resourceType: $resourceType
                    resourceConfig: $resourceConfig
                    templateVariables: $templateVariables
                    prioritizationMode: $prioritizationMode
                    externalLinks: $externalLinks
                    clientMutationId: $clientMutationId
                }
            ) {
                success
                configSchemaValidationErrors
            }
        }
        """
    )

    serialized_config = json.dumps(resource_config)
    serialized_template_variables = None
    if template_variables:
        serialized_template_variables = json.dumps(template_variables)
    serialized_links = None
    if external_links:
        serialized_links = json.dumps(external_links)

    variable_values = {
        "entityName": entity,
        "projectName": project,
        "queueName": queue_name,
        "resourceType": resource_type,
        "resourceConfig": serialized_config,
        "templateVariables": serialized_template_variables,
        "prioritizationMode": prioritization_mode,
        "externalLinks": serialized_links,
    }
    response: Dict[str, Any] = self.gql(upsert_mutation, variable_values)
    return response["upsertRunQueue"]
1699
+
1700
@normalize_exceptions
def push_to_run_queue_by_name(
    self,
    entity: str,
    project: str,
    queue_name: str,
    run_spec: str,
    template_variables: Optional[Dict[str, Union[int, float, str]]],
    priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Queryless mutation, should be used before legacy fallback method.

    Pushes ``run_spec`` onto the queue identified by entity/project/name
    with the ``pushToRunQueueByName`` mutation, so no queue lookup query is
    needed first.

    Args:
        entity: Entity name sent as ``entityName``.
        project: Project name sent as ``projectName``.
        queue_name: Queue name sent as ``queueName``.
        run_spec: JSON string sent as ``runSpec``.
        template_variables: When not ``None``, sent JSON-encoded as
            ``templateVariableValues``.
        priority: When not ``None``, sent as ``priority``.

    Returns:
        The ``pushToRunQueueByName`` payload (with ``runSpec`` decoded from
        JSON when present), or ``None`` when the push did not succeed.

    Raises:
        UnsupportedError: If ``priority`` or ``template_variables`` is given
            but the corresponding server-support flag on ``self`` is falsy.
    """
    self.push_to_run_queue_introspection()

    mutation_params = """
        $entityName: String!,
        $projectName: String!,
        $queueName: String!,
        $runSpec: JSONString!
    """

    mutation_input = """
        entityName: $entityName,
        projectName: $projectName,
        queueName: $queueName,
        runSpec: $runSpec
    """

    variables: Dict[str, Any] = {
        "entityName": entity,
        "projectName": project,
        "queueName": queue_name,
        "runSpec": run_spec,
    }

    # Optional inputs are only added to the mutation text when they are
    # actually supplied, so the request stays valid on servers whose schema
    # lacks those fields.
    if self.server_push_to_run_queue_supports_priority:
        if priority is not None:
            variables["priority"] = priority
            mutation_params += ", $priority: Int"
            mutation_input += ", priority: $priority"
    else:
        if priority is not None:
            raise UnsupportedError(
                "server does not support priority, please update server instance to >=0.46"
            )

    if self.server_supports_template_variables:
        if template_variables is not None:
            variables.update(
                {"templateVariableValues": json.dumps(template_variables)}
            )
            mutation_params += ", $templateVariableValues: JSONString"
            mutation_input += ", templateVariableValues: $templateVariableValues"
    else:
        if template_variables is not None:
            raise UnsupportedError(
                "server does not support template variables, please update server instance to >=0.46"
            )

    mutation = gql(
        f"""
        mutation pushToRunQueueByName(
            {mutation_params}
        ) {{
            pushToRunQueueByName(
                input: {{
                    {mutation_input}
                }}
            ) {{
                runQueueItemId
                runSpec
            }}
        }}
        """
    )

    try:
        result: Optional[Dict[str, Any]] = self.gql(
            mutation, variables, check_retry_fn=util.no_retry_4xx
        ).get("pushToRunQueueByName")
        if not result:
            return None

        if result.get("runSpec"):
            # The server returns the spec as a JSON string; hand callers
            # the decoded object instead.
            result["runSpec"] = json.loads(str(result["runSpec"]))

        return result
    except Exception as e:
        # Only a schema error about the missing ``runSpec`` payload field
        # triggers the retry below; any other failure is reported as None so
        # the caller can fall back to the legacy method.
        if (
            'Cannot query field "runSpec" on type "PushToRunQueueByNamePayload"'
            not in str(e)
        ):
            return None

    mutation_no_runspec = gql(
        """
        mutation pushToRunQueueByName(
            $entityName: String!,
            $projectName: String!,
            $queueName: String!,
            $runSpec: JSONString!,
        ) {
            pushToRunQueueByName(
                input: {
                    entityName: $entityName,
                    projectName: $projectName,
                    queueName: $queueName,
                    runSpec: $runSpec
                }
            ) {
                runQueueItemId
            }
        }
        """
    )

    try:
        result = self.gql(
            mutation_no_runspec, variables, check_retry_fn=util.no_retry_4xx
        ).get("pushToRunQueueByName")
    except Exception:
        # Best effort: a failed retry is reported as "not pushed".
        result = None

    return result
1823
+
1824
@normalize_exceptions
def push_to_run_queue(
    self,
    queue_name: str,
    launch_spec: Dict[str, str],
    template_variables: Optional[dict],
    project_queue: str,
    priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """Push a launch spec onto a run queue.

    First tries ``push_to_run_queue_by_name``. If that yields nothing and no
    ``priority`` was requested, falls back to the legacy path: look the queue
    up by name (creating the ``default`` queue when it is missing) and push
    with the id-based ``pushToRunQueue`` mutation.

    Args:
        queue_name: Name of the target queue.
        launch_spec: Spec to enqueue; ``queue_entity`` (or else ``entity``)
            selects the entity that owns the queue.
        template_variables: Optional template variable values, sent
            JSON-encoded when not ``None``.
        project_queue: Project that holds the queue.
        priority: Optional priority; not available on the legacy path.

    Returns:
        The push payload from the server, or ``None`` when the item could
        not be queued.

    Raises:
        UnsupportedError: If template variables are given but
            ``server_supports_template_variables`` is falsy.
        CommError: If the legacy mutation returns no ``pushToRunQueue`` data.
    """
    self.push_to_run_queue_introspection()
    entity = launch_spec.get("queue_entity") or launch_spec["entity"]
    run_spec = json.dumps(launch_spec)

    pushed = self.push_to_run_queue_by_name(
        entity, project_queue, queue_name, run_spec, template_variables, priority
    )
    if pushed:
        return pushed

    if priority is not None:
        # Cannot proceed with legacy method if priority is set
        return None

    # Legacy Method
    def _is_usable(queue: Dict[str, Any]) -> bool:
        # The name must match and the user must have access to the queue.
        if queue["name"] != queue_name:
            return False
        # TODO: User created queues in the UI have USER access
        return (
            queue["access"] in ["PROJECT", "USER"]
            or queue["createdBy"] == self.default_entity
        )

    candidates = [
        queue
        for queue in self.get_project_run_queues(entity, project_queue)
        if _is_usable(queue)
    ]

    if len(candidates) > 1:
        wandb.termerror(
            f"Unable to push to run queue {queue_name}. More than one queue found with this name."
        )
        return None

    if candidates:
        queue_id = candidates[0]["id"]
    elif queue_name == "default":
        # in the case of a missing default queue. create it
        wandb.termlog(
            f"No default queue existing for entity: {entity} in project: {project_queue}, creating one."
        )
        created = self.create_run_queue(
            launch_spec["entity"],
            project_queue,
            queue_name,
            access="PROJECT",
        )
        if created is None or created.get("queueID") is None:
            wandb.termerror(
                f"Unable to create default queue for entity: {entity} on project: {project_queue}. Run could not be added to a queue"
            )
            return None
        queue_id = created["queueID"]
    else:
        if project_queue == "model-registry":
            _msg = f"Unable to push to run queue {queue_name}. Queue not found."
        else:
            _msg = f"Unable to push to run queue {project_queue}/{queue_name}. Queue not found."
        wandb.termwarn(_msg)
        return None

    legacy_spec = json.dumps(launch_spec)
    variables = {"queueID": queue_id, "runSpec": legacy_spec}

    mutation_params = """
        $queueID: ID!,
        $runSpec: JSONString!
    """
    mutation_input = """
        queueID: $queueID,
        runSpec: $runSpec
    """

    wants_templates = template_variables is not None
    if self.server_supports_template_variables:
        if wants_templates:
            mutation_params += ", $templateVariableValues: JSONString"
            mutation_input += ", templateVariableValues: $templateVariableValues"
            variables.update(
                {"templateVariableValues": json.dumps(template_variables)}
            )
    elif wants_templates:
        raise UnsupportedError(
            "server does not support template variables, please update server instance to >=0.46"
        )

    mutation = gql(
        f"""
        mutation pushToRunQueue(
            {mutation_params}
        ) {{
            pushToRunQueue(
                input: {{{mutation_input}}}
            ) {{
                runQueueItemId
            }}
        }}
        """
    )

    response = self.gql(mutation, variable_values=variables)
    if not response.get("pushToRunQueue"):
        raise CommError(f"Error pushing run queue item to queue {queue_name}.")

    queued: Optional[Dict[str, Any]] = response["pushToRunQueue"]
    return queued
1939
+
1940
@normalize_exceptions
def pop_from_run_queue(
    self,
    queue_name: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    agent_id: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Pop an item from a run queue via the ``popFromRunQueue`` mutation.

    Args:
        queue_name: Name of the queue to pop from.
        entity: Entity name sent as ``entityName``.
        project: Project name sent as ``projectName``.
        agent_id: Value sent as ``launchAgentId``.

    Returns:
        The ``popFromRunQueue`` entry of the response; the mutation requests
        ``runQueueItemId`` and ``runSpec``.
    """
    pop_mutation = gql(
        """
        mutation popFromRunQueue($entity: String!, $project: String!, $queueName: String!, $launchAgentId: ID) {
            popFromRunQueue(input: {
                entityName: $entity,
                projectName: $project,
                queueName: $queueName,
                launchAgentId: $launchAgentId
            }) {
                runQueueItemId
                runSpec
            }
        }
        """
    )
    payload = {
        "entity": entity,
        "project": project,
        "queueName": queue_name,
        "launchAgentId": agent_id,
    }
    response = self.gql(pop_mutation, variable_values=payload)
    popped: Optional[Dict[str, Any]] = response["popFromRunQueue"]
    return popped
1974
+
1975
@normalize_exceptions
def ack_run_queue_item(self, item_id: str, run_id: Optional[str] = None) -> bool:
    """Acknowledge a run queue item via the ``ackRunQueueItem`` mutation.

    Args:
        item_id: Id of the run queue item to acknowledge.
        run_id: Run identifier; it is passed through ``str()`` and sent as
            ``runName``.

    Returns:
        The ``success`` flag reported by the server.

    Raises:
        CommError: If the server reports the acknowledgement as unsuccessful.
    """
    ack_mutation = gql(
        """
        mutation ackRunQueueItem($itemId: ID!, $runId: String!) {
            ackRunQueueItem(input: { runQueueItemId: $itemId, runName: $runId }) {
                success
            }
        }
        """
    )
    payload = {"itemId": item_id, "runId": str(run_id)}
    response = self.gql(ack_mutation, variable_values=payload)

    acknowledged: bool = response["ackRunQueueItem"]["success"]
    if not acknowledged:
        raise CommError(
            "Error acking run queue item. Item may have already been acknowledged by another process"
        )
    return acknowledged
1995
+
1996
@normalize_exceptions
def create_launch_agent_fields_introspection(self) -> List:
    """Return the input field names of the server's ``CreateLaunchAgentInput``.

    The result is stored on ``self.create_launch_agent_input_info`` and
    reused on later calls as long as the stored value is truthy.

    Returns:
        A list of field names. When the server does not know the type, or
        reports no ``inputFields`` for it, the list is ``[""]``.
    """
    if self.create_launch_agent_input_info:
        return self.create_launch_agent_input_info

    query_string = """
        query ProbeServerCreateLaunchAgentInput {
            CreateLaunchAgentInputInfoType: __type(name:"CreateLaunchAgentInput") {
                inputFields{
                    name
                }
            }
        }
    """

    query = gql(query_string)
    res = self.gql(query)

    # ``__type`` resolves to null for a type the server does not define, and
    # ``inputFields`` is null for anything that is not an input object. In
    # both cases the key is present with a None value, so a ``.get`` default
    # alone would not protect against AttributeError/TypeError here.
    type_info = res.get("CreateLaunchAgentInputInfoType") or {}
    input_fields = type_info.get("inputFields", [{}])
    if input_fields is None:
        input_fields = [{}]

    self.create_launch_agent_input_info = [
        field.get("name", "") for field in input_fields
    ]
    return self.create_launch_agent_input_info
2020
+
2021
@normalize_exceptions
def create_launch_agent(
    self,
    entity: str,
    project: str,
    queues: List[str],
    agent_config: Dict[str, Any],
    version: str,
    gorilla_agent_support: bool,
) -> dict:
    """Register a launch agent that polls the given queues.

    Args:
        entity: Entity name sent as ``entityName``.
        project: Project whose run queues are looked up.
        queues: Names of the queues the agent should poll.
        agent_config: Sent JSON-encoded as ``agentConfig`` when the server's
            input type lists that field.
        version: Sent as ``version`` when the server's input type lists
            that field.
        gorilla_agent_support: When falsy, no mutation is sent.

    Returns:
        The ``createLaunchAgent`` payload, or
        ``{"success": True, "launchAgentId": None}`` when
        ``gorilla_agent_support`` is falsy.

    Raises:
        CommError: If the project has no queues and the default queue cannot
            be created, or if the number of matching queues differs from the
            number requested.
    """
    available_queues = self.get_project_run_queues(entity, project)
    if not available_queues:
        # create default queue if it doesn't already exist
        created = self.create_run_queue(entity, project, "default", access="PROJECT")
        if created is None or created.get("queueID") is None:
            raise CommError(
                f"Unable to create default queue for {entity}/{project}. No queues for agent to poll"
            )
        available_queues = [{"id": created["queueID"], "name": "default"}]

    # filter to poll specified queues
    polling_queue_ids = []
    for queue in available_queues:
        if queue["name"] in queues:
            polling_queue_ids.append(queue["id"])

    if len(polling_queue_ids) != len(queues):
        requested = ", ".join(queues)
        known = ",".join([queue["name"] for queue in available_queues])
        raise CommError(
            f"Could not start launch agent: Not all of requested queues ({requested}) found. "
            f"Available queues for this project: {known}"
        )

    if not gorilla_agent_support:
        # gorilla doesn't support launch agents: answer with a stub payload
        # that carries no server-side agent id
        return {
            "success": True,
            "launchAgentId": None,
        }

    payload = {
        "entity": entity,
        "project": project,
        "queues": polling_queue_ids,
        "hostname": socket.gethostname(),
    }

    mutation_params = """
        $entity: String!,
        $project: String!,
        $queues: [ID!]!,
        $hostname: String!
    """

    mutation_input = """
        entityName: $entity,
        projectName: $project,
        runQueues: $queues,
        hostname: $hostname
    """

    # Optional inputs are only sent when the server's input type lists them.
    if "agentConfig" in self.create_launch_agent_fields_introspection():
        payload["agentConfig"] = json.dumps(agent_config)
        mutation_params += ", $agentConfig: JSONString"
        mutation_input += ", agentConfig: $agentConfig"
    if "version" in self.create_launch_agent_fields_introspection():
        payload["version"] = version
        mutation_params += ", $version: String"
        mutation_input += ", version: $version"

    mutation = gql(
        f"""
        mutation createLaunchAgent(
            {mutation_params}
        ) {{
            createLaunchAgent(
                input: {{
                    {mutation_input}
                }}
            ) {{
                launchAgentId
            }}
        }}
        """
    )
    created_agent: dict = self.gql(mutation, payload)["createLaunchAgent"]
    return created_agent
2109
+
2110
@normalize_exceptions
def update_launch_agent_status(
    self,
    agent_id: str,
    status: str,
    gorilla_agent_support: bool,
) -> dict:
    """Report a launch agent's status via the ``updateLaunchAgent`` mutation.

    Args:
        agent_id: Value sent as ``launchAgentId``.
        status: Value sent as ``agentStatus``.
        gorilla_agent_support: When falsy, no request is made.

    Returns:
        The ``updateLaunchAgent`` payload, or ``{"success": True}`` when
        ``gorilla_agent_support`` is falsy.
    """
    if not gorilla_agent_support:
        # if gorilla doesn't support launch agents, this is a no-op
        return {"success": True}

    status_mutation = gql(
        """
        mutation updateLaunchAgent($agentId: ID!, $agentStatus: String){
            updateLaunchAgent(
                input: {
                    launchAgentId: $agentId
                    agentStatus: $agentStatus
                }
            ) {
                success
            }
        }
        """
    )
    payload = {"agentId": agent_id, "agentStatus": status}
    response = self.gql(status_mutation, payload)
    updated: dict = response["updateLaunchAgent"]
    return updated
2143
+
2144
@normalize_exceptions
def get_launch_agent(self, agent_id: str, gorilla_agent_support: bool) -> dict:
    """Fetch a launch agent record by id.

    Args:
        agent_id: Id passed to the ``launchAgent`` query.
        gorilla_agent_support: When falsy, no request is made.

    Returns:
        The ``launchAgent`` entry of the response, or a stub
        ``{"id": None, "name": "", "stopPolling": False}`` when
        ``gorilla_agent_support`` is falsy.
    """
    if not gorilla_agent_support:
        stub = {
            "id": None,
            "name": "",
            "stopPolling": False,
        }
        return stub

    agent_query = gql(
        """
        query LaunchAgent($agentId: ID!) {
            launchAgent(id: $agentId) {
                id
                name
                runQueues
                hostname
                agentStatus
                stopPolling
                heartbeatAt
            }
        }
        """
    )
    response = self.gql(agent_query, {"agentId": agent_id})
    agent: dict = response["launchAgent"]
    return agent
2172
+
2173
@normalize_exceptions
def upsert_run(
    self,
    id: Optional[str] = None,
    name: Optional[str] = None,
    project: Optional[str] = None,
    host: Optional[str] = None,
    group: Optional[str] = None,
    tags: Optional[List[str]] = None,
    config: Optional[dict] = None,
    description: Optional[str] = None,
    entity: Optional[str] = None,
    state: Optional[str] = None,
    display_name: Optional[str] = None,
    notes: Optional[str] = None,
    repo: Optional[str] = None,
    job_type: Optional[str] = None,
    program_path: Optional[str] = None,
    commit: Optional[str] = None,
    sweep_name: Optional[str] = None,
    summary_metrics: Optional[str] = None,
    num_retries: Optional[int] = None,
) -> Tuple[dict, bool, Optional[List]]:
    """Create or update a run through the ``upsertBucket`` mutation.

    Args:
        id (str, optional): The existing run to update
        name (str, optional): The name of the run to create
        project (str, optional): The name of the project; when falsy it is
            derived with ``util.auto_project_name(program_path)``
        host (str, optional): The name of the host; not sent when the
            ``anonymous`` setting equals ``"true"``
        group (str, optional): Name of the group this run is a part of
        tags (list, optional): A list of tags to apply to the run
        config (dict, optional): The latest config params, sent JSON-encoded
            when truthy
        description (str, optional): A description; empty or whitespace-only
            values are sent as null
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting
        state (str, optional): State of the program.
        display_name (str, optional): The display name
        notes (str, optional): Notes about this run
        repo (str, optional): Url of the program's repository.
        job_type (str, optional): Type of job, e.g 'train'.
        program_path (str, optional): Path to the program.
        commit (str, optional): The Git SHA to associate the run with
        sweep_name (str, optional): The name of the sweep this run is a part of
        summary_metrics (str, optional): The JSON summary metrics
        num_retries (int, optional): Number of retries, forwarded to
            ``self.gql`` only when not ``None``

    Returns:
        A tuple ``(bucket, inserted, server_messages)`` taken from the
        ``upsertBucket`` payload; ``server_messages`` is ``None`` when
        ``self._server_settings_type`` is falsy.
    """
    mutation_template = """
    mutation UpsertBucket(
        $id: String,
        $name: String,
        $project: String,
        $entity: String,
        $groupName: String,
        $description: String,
        $displayName: String,
        $notes: String,
        $commit: String,
        $config: JSONString,
        $host: String,
        $debug: Boolean,
        $program: String,
        $repo: String,
        $jobType: String,
        $state: String,
        $sweep: String,
        $tags: [String!],
        $summaryMetrics: JSONString,
    ) {
        upsertBucket(input: {
            id: $id,
            name: $name,
            groupName: $groupName,
            modelName: $project,
            entityName: $entity,
            description: $description,
            displayName: $displayName,
            notes: $notes,
            config: $config,
            commit: $commit,
            host: $host,
            debug: $debug,
            jobProgram: $program,
            jobRepo: $repo,
            jobType: $jobType,
            state: $state,
            sweep: $sweep,
            tags: $tags,
            summaryMetrics: $summaryMetrics,
        }) {
            bucket {
                id
                name
                displayName
                description
                config
                sweepName
                project {
                    id
                    name
                    entity {
                        id
                        name
                    }
                }
                historyLineCount
            }
            inserted
            _Server_Settings_
        }
    }
    """
    self.server_settings_introspection()

    # The serverSettings selection is only spliced into the mutation when
    # introspection found the corresponding type on the server.
    if self._server_settings_type:
        settings_fragment = """
        serverSettings {
            serverMessages{
                utfText
                plainText
                htmlText
                messageType
                messageLevel
            }
        }
        """
    else:
        settings_fragment = ""

    mutation = gql(mutation_template.replace("_Server_Settings_", settings_fragment))

    serialized_config = json.dumps(config) if config else None
    if not description or description.isspace():
        description = None

    gql_kwargs = {}
    if num_retries is not None:
        gql_kwargs["num_retries"] = num_retries

    payload = {
        "id": id,
        "entity": entity or self.settings("entity"),
        "name": name,
        "project": project or util.auto_project_name(program_path),
        "groupName": group,
        "tags": tags,
        "description": description,
        "config": serialized_config,
        "commit": commit,
        "displayName": display_name,
        "notes": notes,
        "host": None if self.settings().get("anonymous") == "true" else host,
        "debug": env.is_debug(env=self._environ),
        "repo": repo,
        "program": program_path,
        "jobType": job_type,
        "state": state,
        "sweep": sweep_name,
        "summaryMetrics": summary_metrics,
    }

    # retry conflict errors for 2 minutes, default to no_auth_retry
    retry_policy = util.make_check_retry_fn(
        check_fn=util.check_retry_conflict_or_gone,
        check_timedelta=datetime.timedelta(minutes=2),
        fallback_retry_fn=util.no_retry_auth,
    )

    response = self.gql(
        mutation,
        variable_values=payload,
        check_retry_fn=retry_policy,
        **gql_kwargs,
    )

    upserted = response["upsertBucket"]
    bucket: Dict[str, Dict[str, Dict[str, str]]] = upserted["bucket"]

    # Remember the project/entity the server resolved for this run.
    resolved_project: Dict[str, Dict[str, str]] = bucket.get("project", {})
    if resolved_project:
        self.set_setting("project", resolved_project["name"])
        resolved_entity = resolved_project.get("entity", {})
        if resolved_entity:
            self.set_setting("entity", resolved_entity["name"])

    server_messages = None
    if self._server_settings_type:
        server_settings = response["upsertBucket"].get("serverSettings", {})
        server_messages = server_settings.get("serverMessages", [])

    return (
        response["upsertBucket"]["bucket"],
        response["upsertBucket"]["inserted"],
        server_messages,
    )
2371
+
2372
@normalize_exceptions
def rewind_run(
    self,
    run_name: str,
    metric_name: str,
    metric_value: float,
    program_path: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    num_retries: Optional[int] = None,
) -> dict:
    """Rewind a run to an earlier point via the ``rewindRun`` mutation.

    Args:
        run_name (str): The name of the run to rewind
        metric_name (str): The name of the metric to rewind to
        metric_value (float): The value of the metric to rewind to
        program_path (str, optional): Path to the program; used to derive a
            project name when ``project`` is falsy
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting
        project (str, optional): The name of the project
        num_retries (int, optional): Number of retries, forwarded to
            ``self.gql`` only when not ``None``

    Returns:
        The ``rewoundRun`` object from the response (an empty dict when it
        is absent), shaped like:

            {
                "id": "run_id",
                "name": "run_name",
                "displayName": "run_display_name",
                "description": "run_description",
                "config": "stringified_run_config_json",
                "sweepName": "run_sweep_name",
                "project": {
                    "id": "project_id",
                    "name": "project_name",
                    "entity": {
                        "id": "entity_id",
                        "name": "entity_name"
                    }
                },
                "historyLineCount": 100,
            }
    """
    rewind_mutation = gql(
        """
    mutation RewindRun($runName: String!, $entity: String, $project: String, $metricName: String!, $metricValue: Float!) {
        rewindRun(input: {runName: $runName, entityName: $entity, projectName: $project, metricName: $metricName, metricValue: $metricValue}) {
            rewoundRun {
                id
                name
                displayName
                description
                config
                sweepName
                project {
                    id
                    name
                    entity {
                        id
                        name
                    }
                }
                historyLineCount
            }
        }
    }
    """
    )

    gql_kwargs = {}
    if num_retries is not None:
        gql_kwargs["num_retries"] = num_retries

    payload = {
        "runName": run_name,
        "entity": entity or self.settings("entity"),
        "project": project or util.auto_project_name(program_path),
        "metricName": metric_name,
        "metricValue": metric_value,
    }

    # retry conflict errors for 2 minutes, default to no_auth_retry
    retry_policy = util.make_check_retry_fn(
        check_fn=util.check_retry_conflict_or_gone,
        check_timedelta=datetime.timedelta(minutes=2),
        fallback_retry_fn=util.no_retry_auth,
    )

    response = self.gql(
        rewind_mutation,
        variable_values=payload,
        check_retry_fn=retry_policy,
        **gql_kwargs,
    )

    rewind_payload = response.get("rewindRun", {})
    rewound: Dict[str, Dict[str, Dict[str, str]]] = rewind_payload.get(
        "rewoundRun", {}
    )

    # Remember the project/entity the server resolved for this run.
    resolved_project: Dict[str, Dict[str, str]] = rewound.get("project", {})
    if resolved_project:
        self.set_setting("project", resolved_project["name"])
        resolved_entity = resolved_project.get("entity", {})
        if resolved_entity:
            self.set_setting("entity", resolved_entity["name"])

    return rewound
2478
+
2479
@normalize_exceptions
def get_run_info(
    self,
    entity: str,
    project: str,
    name: str,
) -> dict:
    """Fetch the ``runInfo`` record of a run.

    Args:
        entity: Entity name passed to the query.
        project: Project name passed to the query.
        name: Run name passed to the query.

    Returns:
        The ``runInfo`` object of the run (program, args, os, python, colab,
        executable, codeSaved, cpuCount, gpuCount, gpu and git remote/commit
        are requested).

    Raises:
        CommError: If the response contains no project, or no run within it.
    """
    info_query = gql(
        """
        query RunInfo($project: String!, $entity: String!, $name: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $name) {
                    runInfo {
                        program
                        args
                        os
                        python
                        colab
                        executable
                        codeSaved
                        cpuCount
                        gpuCount
                        gpu
                        git {
                            remote
                            commit
                        }
                    }
                }
            }
        }
        """
    )
    res = self.gql(info_query, {"project": project, "entity": entity, "name": name})

    if res.get("project") is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this project exists and you have access to this entity and project"
        )
    if res["project"].get("run") is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this run id exists"
        )

    run_info: dict = res["project"]["run"]["runInfo"]
    return run_info
2528
+
2529
@normalize_exceptions
def get_run_state(self, entity: str, project: str, name: str) -> str:
    """Fetch the ``state`` field of a run.

    Args:
        entity: Entity name passed to the query.
        project: Project name passed to the query.
        name: Run name passed to the query.

    Returns:
        The run's ``state`` value from the response.

    Raises:
        CommError: If the response contains no project, or no run within it.
    """
    state_query = gql(
        """
        query RunState(
            $project: String!,
            $entity: String!,
            $name: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $name) {
                    state
                }
            }
        }
        """
    )
    payload = {"project": project, "entity": entity, "name": name}
    res = self.gql(state_query, payload)

    project_missing = res.get("project") is None
    if project_missing or res["project"].get("run") is None:
        raise CommError(f"Error fetching run state for {entity}/{project}/{name}.")

    state: str = res["project"]["run"]["state"]
    return state
2555
+
2556
@normalize_exceptions
def create_run_files_introspection(self) -> bool:
    """Report whether the server exposes the ``createRunFiles`` mutation.

    Returns:
        True when ``"createRunFiles"`` is in the third value returned by
        ``server_info_introspection()``.
    """
    _first, _second, available_mutations = self.server_info_introspection()
    supported = "createRunFiles" in available_mutations
    return supported
2560
+
2561
@normalize_exceptions
def upload_urls(
    self,
    project: str,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    description: Optional[str] = None,
) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
    """Generate temporary resumable upload urls.

    Uses the ``createRunFiles`` mutation when the server offers it and
    delegates to ``legacy_upload_urls`` otherwise.

    Args:
        project (str): The project to upload to
        files (list or dict): The filenames to upload; for a dict, its keys
            are used
        run (str, optional): The run to upload to; falls back to
            ``self.current_run_id``
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting
        description (str, optional): description, only forwarded to the
            legacy query

    Returns:
        A tuple ``(run_id, upload_headers, file_info)``:
            run_id: id of run we uploaded files to
            upload_headers: A list of headers to use when uploading files.
            file_info: A dict mapping each file name to its record, e.g.
                {
                    "weights.h5": { "name": "weights.h5", "uploadUrl": "https://weights.url" },
                    "model.json": { "name": "model.json", "uploadUrl": "https://model.json" }
                }

    Raises:
        AssertionError: If no run or no entity can be determined.
        CommError: If the server returns no run id.
    """
    run_name = run or self.current_run_id
    assert run_name, "run must be specified"
    entity = entity or self.settings("entity")
    assert entity, "entity must be specified"

    if not self.create_run_files_introspection():
        # Servers without createRunFiles only offer the older query.
        return self.legacy_upload_urls(project, files, run, entity, description)

    create_files_mutation = gql(
        """
        mutation CreateRunFiles($entity: String!, $project: String!, $run: String!, $files: [String!]!) {
            createRunFiles(input: {entityName: $entity, projectName: $project, runName: $run, files: $files}) {
                runID
                uploadHeaders
                files {
                    name
                    uploadUrl
                }
            }
        }
        """
    )

    payload = {
        "project": project,
        "run": run_name,
        "entity": entity,
        "files": list(files),
    }
    created = self.gql(create_files_mutation, variable_values=payload)["createRunFiles"]

    run_id = created["runID"]
    if not run_id:
        raise CommError(
            f"Error uploading files to {entity}/{project}/{run_name}. Check that this project exists and you have access to this entity and project"
        )

    files_by_name = {}
    for file_record in created["files"]:
        files_by_name[file_record["name"]] = file_record
    return run_id, created["uploadHeaders"], files_by_name
2635
+
2636
def legacy_upload_urls(
    self,
    project: str,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    description: Optional[str] = None,
) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
    """Generate temporary resumable upload urls.

    A new mutation createRunFiles was introduced after 0.15.4.
    This function is used to support older versions.

    Args:
        project: Project name passed to the query as ``name``.
        files: The filenames to upload; for a dict, its keys are used.
        run: The run to upload to; falls back to ``self.current_run_id``.
        entity: The entity to scope the project to; falls back to the
            ``entity`` setting.
        description: Passed to the query as ``description``.

    Returns:
        A tuple ``(run_id, upload_headers, file_info)`` where ``file_info``
        maps each file name to its node, with the node's ``url`` key renamed
        to ``uploadUrl``.

    Raises:
        AssertionError: If no run can be determined.
        CommError: If the project or the run is missing from the response.
    """
    query = gql(
        """
        query RunUploadUrls($name: String!, $files: [String]!, $entity: String, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        uploadHeaders
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """
    )
    run_id = run or self.current_run_id
    assert run_id, "run must be specified"
    entity = entity or self.settings("entity")
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run_id,
            "entity": entity,
            "files": list(files),
            "description": description,
        },
    )

    # ``model`` is null when the project lookup fails; subscripting it used
    # to raise TypeError instead of the CommError raised for a missing run.
    model = query_result["model"]
    run_obj = model["bucket"] if model else None
    if not run_obj:
        raise CommError(f"Run does not exist {entity}/{project}/{run_id}.")

    for file_node in run_obj["files"]["edges"]:
        file = file_node["node"]
        # we previously used "url" field but now use "uploadUrl"
        # replace the "url" field with "uploadUrl for downstream compatibility
        if "url" in file and "uploadUrl" not in file:
            file["uploadUrl"] = file.pop("url")

    result = {file["name"]: file for file in self._flatten_edges(run_obj["files"])}
    return run_obj["id"], run_obj["files"]["uploadHeaders"], result
2699
+
2700
@normalize_exceptions
def download_urls(
    self,
    project: str,
    run: Optional[str] = None,
    entity: Optional[str] = None,
) -> Dict[str, Dict[str, str]]:
    """Generate download urls.

    Args:
        project (str): The project to download
        run (str): The run to download from; falls back to
            ``self.current_run_id``
        entity (str, optional): The entity to scope this project to; falls
            back to the ``entity`` setting

    Returns:
        A dict mapping file names to their records

            {
                'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
            }

    Raises:
        AssertionError: If no run can be determined.
        CommError: If the project or the run is missing from the response.
    """
    query = gql(
        """
        query RunDownloadUrls($name: String!, $entity: String, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """
    )
    run = run or self.current_run_id
    assert run, "run must be specified"
    entity = entity or self.settings("entity")
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run,
            "entity": entity,
        },
    )

    # Either level can be null: ``model`` when the project lookup fails,
    # ``bucket`` when the run is missing. Both mean there is nothing to
    # download, so report them the same way instead of failing with a
    # TypeError on the null bucket.
    model = query_result["model"]
    if model is None or model.get("bucket") is None:
        raise CommError(f"Run does not exist {entity}/{project}/{run}.")

    files = self._flatten_edges(model["bucket"]["files"])
    return {file["name"]: file for file in files if file}
2757
+
2758
@normalize_exceptions
def download_url(
    self,
    project: str,
    file_name: str,
    run: Optional[str] = None,
    entity: Optional[str] = None,
) -> Optional[Dict[str, str]]:
    """Generate the download url for a single file of a run.

    Args:
        project (str): The project to download from
        file_name (str): The name of the file to download
        run (str, optional): The run to download from. Defaults to the current run id.
        entity (str, optional): The entity to scope this project to. Defaults to
            the "entity" setting.

    Returns:
        The metadata dict of the file, or None when the project, the run or the
        file cannot be found (a file entry without "updatedAt" counts as not found)

            { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }

    """
    query = gql(
        """
    query RunDownloadUrl($name: String!, $fileName: String!, $entity: String, $run: String!) {
        model(name: $name, entityName: $entity) {
            bucket(name: $run) {
                files(names: [$fileName]) {
                    edges {
                        node {
                            name
                            url
                            md5
                            updatedAt
                        }
                    }
                }
            }
        }
    }
    """
    )
    run = run or self.current_run_id
    assert run, "run must be specified"
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run,
            "fileName": file_name,
            "entity": entity or self.settings("entity"),
        },
    )
    model = query_result["model"]
    # A null `bucket` (unknown run) is treated like a null `model`: nothing to download.
    bucket = model["bucket"] if model else None
    if not bucket:
        return None
    files = self._flatten_edges(bucket["files"])
    return files[0] if len(files) > 0 and files[0].get("updatedAt") else None
2816
+
2817
@normalize_exceptions
def download_file(self, url: str) -> Tuple[int, requests.Response]:
    """Initiate a streaming download.

    Args:
        url (str): The url to download

    Returns:
        A tuple of the content length (0 when the server sends no
        "content-length" header) and the streaming response

    Raises:
        requests.exceptions.HTTPError: If the server answers with an error status.
    """
    check_httpclient_logger_handler()

    # Work on a copy: the Authorization header added below must not leak into
    # the shared thread-local header dict.
    http_headers = dict(_thread_local_api_settings.headers or {})

    auth = None
    if self.access_token is not None:
        http_headers["Authorization"] = f"Bearer {self.access_token}"
    elif _thread_local_api_settings.cookies is None:
        # No token and no cookies: fall back to basic auth with the API key.
        auth = ("api", self.api_key or "")

    response = requests.get(
        url,
        auth=auth,
        cookies=_thread_local_api_settings.cookies or {},
        headers=http_headers,
        stream=True,
    )
    response.raise_for_status()
    return int(response.headers.get("content-length", 0)), response
2846
+
2847
@normalize_exceptions
def download_write_file(
    self,
    metadata: Dict[str, str],
    out_dir: Optional[str] = None,
) -> Tuple[str, Optional[requests.Response]]:
    """Download a file from a run and write it to wandb/.

    Args:
        metadata (obj): The metadata object for the file to download. Comes from Api.download_urls().
            Must provide the "name", "md5" and "url" keys.
        out_dir (str, optional): The directory to write the file to. Defaults to
            the "wandb_dir" setting.

    Returns:
        A tuple of the file's local path and the streaming response. The streaming response is None if the file
        already existed and was up-to-date.
    """
    filename = metadata["name"]
    path = os.path.join(out_dir or self.settings("wandb_dir"), filename)
    # Compare against the destination path, i.e. the file this method writes,
    # not against a same-named file in the current working directory.
    if self.file_current(path, B64MD5(metadata["md5"])):
        return path, None

    _, response = self.download_file(metadata["url"])

    with util.fsync_open(path, "wb") as file:
        for data in response.iter_content(chunk_size=1024):
            file.write(data)

    return path, response
2875
+
2876
def upload_file_azure(
    self, url: str, file: Any, extra_headers: Dict[str, str]
) -> None:
    """Upload a file to azure.

    Args:
        url: The blob url to upload to
        file: The data to upload; it must support `len()`
        extra_headers: Request headers; "Content-MD5" (base64) and
            "Content-Type" are forwarded as blob content settings

    Raises:
        requests.exceptions.RequestException: If azure answered with an error response.
        requests.exceptions.ConnectionError: If the failure carries no response.
    """
    from azure.core.exceptions import AzureError  # type: ignore

    # Configure the client without retries so our existing logic can handle them
    client = self._azure_blob_module.BlobClient.from_blob_url(
        url, retry_policy=self._azure_blob_module.LinearRetry(retry_total=0)
    )
    try:
        content_md5 = extra_headers.get("Content-MD5")
        md5: Optional[bytes] = (
            base64.b64decode(content_md5) if content_md5 is not None else None
        )
        content_settings = self._azure_blob_module.ContentSettings(
            content_md5=md5,
            content_type=extra_headers.get("Content-Type"),
        )
        client.upload_blob(
            file,
            max_concurrency=4,
            length=len(file),
            overwrite=True,
            content_settings=content_settings,
        )
    except AzureError as e:
        # The attribute can exist but be None; only translate the response
        # when there really is one.
        azure_response = getattr(e, "response", None)
        if azure_response is not None:
            response = requests.models.Response()
            response.status_code = azure_response.status_code
            response.headers = azure_response.headers
            raise requests.exceptions.RequestException(
                e.message, response=response
            ) from e
        raise requests.exceptions.ConnectionError(e.message) from e
2910
+
2911
def upload_multipart_file_chunk(
    self,
    url: str,
    upload_chunk: bytes,
    extra_headers: Optional[Dict[str, str]] = None,
) -> Optional[requests.Response]:
    """Upload a file chunk to S3 with failure resumption.

    Args:
        url: The url to upload the chunk to
        upload_chunk: The bytes of the chunk to upload
        extra_headers: A dictionary of extra headers to send with the request

    Returns:
        The `requests` library response object

    Raises:
        retry.TransientError: If the failure is considered retryable (selected
            status codes, timeouts, connection errors, S3 "RequestTimeout").
    """
    check_httpclient_logger_handler()
    # Bound up front (as in `upload_file`) so the final `return` is always defined.
    response: Optional[requests.Response] = None
    try:
        if env.is_debug(env=self._environ):
            logger.debug("upload_file: %s", url)
        response = self._upload_file_session.put(
            url, data=upload_chunk, headers=extra_headers
        )
        if env.is_debug(env=self._environ):
            logger.debug("upload_file: %s complete", url)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error(f"upload_file exception {url}: {e}")
        request_headers = e.request.headers if e.request is not None else ""
        logger.error(f"upload_file request headers: {request_headers!r}")
        response_content = e.response.content if e.response is not None else ""
        logger.error(f"upload_file response body: {response_content!r}")
        status_code = e.response.status_code if e.response is not None else 0
        # S3 reports retryable request timeouts out-of-band
        is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
            response_content
        )
        # Retry errors from cloud storage or local network issues
        if (
            status_code in (308, 408, 409, 429, 500, 502, 503, 504)
            or isinstance(
                e,
                (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
            )
            or is_aws_retryable
        ):
            _e = retry.TransientError(exc=e)
            raise _e.with_traceback(sys.exc_info()[2])
        else:
            wandb._sentry.reraise(e)
    return response
2962
+
2963
def upload_file(
    self,
    url: str,
    file: IO[bytes],
    callback: Optional["ProgressFn"] = None,
    extra_headers: Optional[Dict[str, str]] = None,
) -> Optional[requests.Response]:
    """Upload a file to W&B with failure resumption.

    Args:
        url: The url to upload to
        file: The open binary file to upload
        callback: A callback which is passed the number of
            bytes uploaded since the last time it was called, used to report progress
        extra_headers: A dictionary of extra headers to send with the request

    Returns:
        The `requests` library response object, or None when the upload went
        through the azure SDK path

    Raises:
        retry.TransientError: If the failure is considered retryable (selected
            status codes, timeouts, connection errors, S3 "RequestTimeout").
    """
    check_httpclient_logger_handler()
    # Copy so the caller's header dict is never shared with the request.
    extra_headers = extra_headers.copy() if extra_headers else {}
    response: Optional[requests.Response] = None
    progress = Progress(file, callback=callback)
    try:
        # Azure blobs go through the SDK when it is available.
        if "x-ms-blob-type" in extra_headers and self._azure_blob_module:
            self.upload_file_azure(url, progress, extra_headers)
        else:
            if "x-ms-blob-type" in extra_headers:
                wandb.termwarn(
                    "Azure uploads over 256MB require the azure SDK, install with pip install wandb[azure]",
                    repeat=False,
                )
            if env.is_debug(env=self._environ):
                logger.debug("upload_file: %s", url)
            response = self._upload_file_session.put(
                url, data=progress, headers=extra_headers
            )
            if env.is_debug(env=self._environ):
                logger.debug("upload_file: %s complete", url)
            response.raise_for_status()
    except requests.exceptions.RequestException as e:
        logger.error(f"upload_file exception {url}: {e}")
        request_headers = e.request.headers if e.request is not None else ""
        logger.error(f"upload_file request headers: {request_headers}")
        response_content = e.response.content if e.response is not None else ""
        logger.error(f"upload_file response body: {response_content!r}")
        status_code = e.response.status_code if e.response is not None else 0
        # S3 reports retryable request timeouts out-of-band
        is_aws_retryable = (
            "x-amz-meta-md5" in extra_headers
            and status_code == 400
            and "RequestTimeout" in str(response_content)
        )
        # We need to rewind the file for the next retry (the file passed in is `seek`'ed to 0)
        progress.rewind()
        # Retry errors from cloud storage or local network issues
        if (
            status_code in (308, 408, 409, 429, 500, 502, 503, 504)
            or isinstance(
                e,
                (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
            )
            or is_aws_retryable
        ):
            # Wrap in TransientError while keeping the original traceback.
            _e = retry.TransientError(exc=e)
            raise _e.with_traceback(sys.exc_info()[2])
        else:
            wandb._sentry.reraise(e)

    return response
3033
+
3034
@normalize_exceptions
def register_agent(
    self,
    host: str,
    sweep_id: Optional[str] = None,
    project_name: Optional[str] = None,
    entity: Optional[str] = None,
) -> dict:
    """Register a new sweep agent with the server.

    Args:
        host (str): hostname
        sweep_id (str): sweep id
        project_name: (str): project that contains the sweep; falls back to
            the "project" setting when None
        entity: (str): entity that contains the sweep; falls back to the
            "entity" setting when None

    Returns:
        The "agent" object of the createAgent response (the query selects its id).
    """
    create_agent = gql(
        """
    mutation CreateAgent(
        $host: String!
        $projectName: String,
        $entityName: String,
        $sweep: String!
    ) {
        createAgent(input: {
            host: $host,
            projectName: $projectName,
            entityName: $entityName,
            sweep: $sweep,
        }) {
            agent {
                id
            }
        }
    }
    """
    )
    # Only a missing (None) value triggers the settings fallback.
    entity_name = self.settings("entity") if entity is None else entity
    if project_name is None:
        project_name = self.settings("project")

    variables = {
        "host": host,
        "entityName": entity_name,
        "projectName": project_name,
        "sweep": sweep_id,
    }
    response = self.gql(
        create_agent,
        variable_values=variables,
        check_retry_fn=util.no_retry_4xx,
    )
    agent: dict = response["createAgent"]["agent"]
    return agent
3088
+
3089
def agent_heartbeat(
    self, agent_id: str, metrics: dict, run_states: dict
) -> List[Dict[str, Any]]:
    """Notify server about agent state, receive commands.

    Args:
        agent_id (str): agent_id
        metrics (dict): system metrics
        run_states (dict): run_id: state mapping

    Returns:
        List of commands to execute. An empty list is returned when the
        request fails; the failure is logged instead of raised.

    Raises:
        ValueError: If agent_id is None.
    """
    mutation = gql(
        """
    mutation Heartbeat(
        $id: ID!,
        $metrics: JSONString,
        $runState: JSONString
    ) {
        agentHeartbeat(input: {
            id: $id,
            metrics: $metrics,
            runState: $runState
        }) {
            agent {
                id
            }
            commands
        }
    }
    """
    )

    if agent_id is None:
        raise ValueError("Cannot call heartbeat with an unregistered agent.")

    try:
        response = self.gql(
            mutation,
            variable_values={
                "id": agent_id,
                "metrics": json.dumps(metrics),
                "runState": json.dumps(run_states),
            },
            timeout=60,
        )
    except Exception as e:
        # GQL raises exceptions with stringified python dictionaries :/
        try:
            message = ast.literal_eval(e.args[0])["message"]
        except (IndexError, KeyError, TypeError, ValueError, SyntaxError):
            # Other failures (timeouts, connection errors, ...) carry no such
            # payload; fall back to the exception text rather than raising a
            # secondary error from inside this handler.
            message = str(e)
        logger.error("Error communicating with W&B: %s", message)
        return []
    else:
        result: List[Dict[str, Any]] = json.loads(
            response["agentHeartbeat"]["commands"]
        )
        return result
3145
+
3146
@staticmethod
def _validate_config_and_fill_distribution(config: dict) -> dict:
    """Return a copy of a sweep config with missing distributions filled in.

    For every parameter that defines both "min" and "max" but no
    "distribution", the distribution is set to "int_uniform" when both bounds
    are ints and to "uniform" when both are floats. The input is not modified.

    Raises:
        ValueError: If the bounds of such a parameter are neither both ints
            nor both floats.
    """
    # TODO(dag): deprecate this in favor of jsonschema validation once
    # apiVersion 2 is released and local controller is integrated with
    # wandb/client.

    # Deep-copy so the caller's config is left untouched, then cast to a plain
    # dict: a sweepconfig is a dict subclass that does not serialize cleanly
    # to yaml and breaks graphql.
    config = dict(deepcopy(config))

    if "parameters" not in config:
        # still shows an anaconda warning, but doesn't error
        return config

    for name, spec in config["parameters"].items():
        if not ("min" in spec and "max" in spec):
            continue
        if "distribution" in spec:
            continue

        lower, upper = spec["min"], spec["max"]
        if isinstance(lower, int) and isinstance(upper, int):
            spec["distribution"] = "int_uniform"
        elif isinstance(lower, float) and isinstance(upper, float):
            spec["distribution"] = "uniform"
        else:
            raise ValueError(
                f"Parameter {name} is ambiguous, please specify bounds as both floats (for a float_"
                "uniform distribution) or ints (for an int_uniform distribution)."
            )
    return config
3186
+
3187
@normalize_exceptions
def upsert_sweep(
    self,
    config: dict,
    controller: Optional[str] = None,
    launch_scheduler: Optional[str] = None,
    scheduler: Optional[str] = None,
    obj_id: Optional[str] = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    state: Optional[str] = None,
    prior_runs: Optional[List[str]] = None,
    template_variable_values: Optional[Dict[str, Any]] = None,
) -> Tuple[str, List[str]]:
    """Upsert a sweep object.

    Several variants of the mutation are tried in turn, newest schema first,
    until one succeeds. On success the "project" and "entity" settings are
    updated from the returned sweep when the response contains them.

    Args:
        config (dict): sweep config (will be converted to yaml)
        controller (str): controller to use
        launch_scheduler (str): launch scheduler to use
        scheduler (str): scheduler to use
        obj_id (str): object id
        project (str): project to use; defaults to the "project" setting
        entity (str): entity to use; defaults to the "entity" setting
        state (str): state; only sent when truthy
        prior_runs (list): IDs of existing runs to add to the sweep
        template_variable_values (dict): template variable values

    Returns:
        A tuple of the sweep name and the config validation warnings (an empty
        list when the response does not include them).

    Raises:
        UsageError: Re-raised immediately, without trying older mutations.
        Exception: The error of the last attempted mutation when all fail.
    """
    project_query = """
    project {
        id
        name
        entity {
            id
            name
        }
    }
    """
    mutation_str = """
    mutation UpsertSweep(
        $id: ID,
        $config: String,
        $description: String,
        $entityName: String,
        $projectName: String,
        $controller: JSONString,
        $scheduler: JSONString,
        $state: String,
        $priorRunsFilters: JSONString,
    ) {
        upsertSweep(input: {
            id: $id,
            config: $config,
            description: $description,
            entityName: $entityName,
            projectName: $projectName,
            controller: $controller,
            scheduler: $scheduler,
            state: $state,
            priorRunsFilters: $priorRunsFilters,
        }) {
            sweep {
                name
                _PROJECT_QUERY_
            }
            configValidationWarnings
        }
    }
    """
    # TODO(jhr): we need protocol versioning to know schema is not supported
    # for now we will just try both new and old query
    # mutation 5 adds both launchScheduler and templateVariableValues
    mutation_5 = gql(
        mutation_str.replace(
            "$controller: JSONString,",
            "$controller: JSONString,$launchScheduler: JSONString, $templateVariableValues: JSONString,",
        )
        .replace(
            "controller: $controller,",
            "controller: $controller,launchScheduler: $launchScheduler,templateVariableValues: $templateVariableValues,",
        )
        .replace("_PROJECT_QUERY_", project_query)
    )
    # launchScheduler was introduced in core v0.14.0
    mutation_4 = gql(
        mutation_str.replace(
            "$controller: JSONString,",
            "$controller: JSONString,$launchScheduler: JSONString,",
        )
        .replace(
            "controller: $controller,",
            "controller: $controller,launchScheduler: $launchScheduler",
        )
        .replace("_PROJECT_QUERY_", project_query)
    )

    # mutation 3 maps to backend that can support CLI version of at least 0.10.31
    mutation_3 = gql(mutation_str.replace("_PROJECT_QUERY_", project_query))
    # mutation 2 drops the configValidationWarnings field
    mutation_2 = gql(
        mutation_str.replace("_PROJECT_QUERY_", project_query).replace(
            "configValidationWarnings", ""
        )
    )
    # mutation 1 additionally drops the project sub-query
    mutation_1 = gql(
        mutation_str.replace("_PROJECT_QUERY_", "").replace(
            "configValidationWarnings", ""
        )
    )

    # TODO(dag): replace this with a query for protocol versioning
    mutations = [mutation_5, mutation_4, mutation_3, mutation_2, mutation_1]

    config = self._validate_config_and_fill_distribution(config)

    # Silly, but attr-dicts like EasyDicts don't serialize correctly to yaml.
    # This sanitizes them with a round trip pass through json to get a regular dict.
    config_str = yaml.dump(
        json.loads(json.dumps(config)), Dumper=util.NonOctalStringDumper
    )
    filters = None
    if prior_runs:
        # Match any of the given runs by name.
        filters = json.dumps({"$or": [{"name": r} for r in prior_runs]})

    err: Optional[Exception] = None
    for mutation in mutations:
        try:
            variables = {
                "id": obj_id,
                "config": config_str,
                "description": config.get("description"),
                "entityName": entity or self.settings("entity"),
                "projectName": project or self.settings("project"),
                "controller": controller,
                "launchScheduler": launch_scheduler,
                "templateVariableValues": json.dumps(template_variable_values),
                "scheduler": scheduler,
                "priorRunsFilters": filters,
            }
            if state:
                variables["state"] = state

            response = self.gql(
                mutation,
                variable_values=variables,
                check_retry_fn=util.no_retry_4xx,
            )
        except UsageError as e:
            raise e
        except Exception as e:
            # graphql schema exception is generic
            # remember it and fall through to the next (older) mutation
            err = e
            continue
        # success: clear any error left over from a newer mutation
        err = None
        break
    if err:
        raise err

    sweep: Dict[str, Dict[str, Dict]] = response["upsertSweep"]["sweep"]
    project_obj: Dict[str, Dict] = sweep.get("project", {})
    if project_obj:
        self.set_setting("project", project_obj["name"])
        entity_obj: dict = project_obj.get("entity", {})
        if entity_obj:
            self.set_setting("entity", entity_obj["name"])

    warnings = response["upsertSweep"].get("configValidationWarnings", [])
    return response["upsertSweep"]["sweep"]["name"], warnings
3353
+
3354
@normalize_exceptions
def create_anonymous_api_key(self) -> str:
    """Create a new anonymous entity and return the name of its API key."""
    create_key = gql(
        """
    mutation CreateAnonymousApiKey {
        createAnonymousEntity(input: {}) {
            apiKey {
                name
            }
        }
    }
    """
    )

    response = self.gql(create_key, variable_values={})
    api_key = response["createAnonymousEntity"]["apiKey"]
    return str(api_key["name"])
3372
+
3373
@staticmethod
def file_current(fname: str, md5: B64MD5) -> bool:
    """Tell whether `fname` is an existing file whose base64 md5 equals `md5`."""
    if not os.path.isfile(fname):
        return False
    return md5_file_b64(fname) == md5
3377
+
3378
@normalize_exceptions
def pull(
    self, project: str, run: Optional[str] = None, entity: Optional[str] = None
) -> "List[requests.Response]":
    """Download the files of a run from W&B.

    Args:
        project (str): The project to download
        run (str, optional): The run to download from
        entity (str, optional): The entity to scope this project to. Defaults to wandb models

    Returns:
        A list with the truthy `requests` responses of the downloads; files
        whose download yields no response are not included.
    """
    project, run = self.parse_slug(project, run=run)
    file_metadata = self.download_urls(project, run, entity)

    responses = []
    for metadata in file_metadata.values():
        _path, response = self.download_write_file(metadata)
        if not response:
            continue
        responses.append(response)
    return responses
3401
+
3402
def get_project(self) -> str:
    """Return the project from the default settings, else from `settings("project")`."""
    default_project: str = self.default_settings.get("project")
    if default_project:
        return default_project
    fallback: str = self.settings("project")
    return fallback
3405
+
3406
@normalize_exceptions
def push(
    self,
    files: Union[List[str], Dict[str, IO]],
    run: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    description: Optional[str] = None,
    force: bool = True,
    progress: Union[TextIO, Literal[False]] = False,
) -> "List[Optional[requests.Response]]":
    """Uploads multiple files to W&B.

    Args:
        files (list or dict): The filenames to upload, when dict the values are open files
        run (str, optional): The run to upload to. Defaults to the current run id.
        entity (str, optional): The entity to scope this project to. Defaults to wandb models
        project (str, optional): The name of the project to upload to. Defaults to the one in settings.
        description (str, optional): The description of the changes.
            Accepted but not used by this method.
        force (bool, optional): Whether to prevent push if git has uncommitted changes.
            Accepted but not used by this method.
        progress (callable, or stream): If callable, will be called with (chunk_bytes,
            total_bytes) as argument. If TextIO, renders a progress bar to it.

    Returns:
        A list of `requests.Response` objects. Files that cannot be opened are
        reported on stdout and skipped.

    Raises:
        CommError: If no project is configured.
    """
    if project is None:
        project = self.get_project()
    if project is None:
        raise CommError("No project configured.")
    if run is None:
        run = self.current_run_id

    # TODO(adrian): we use a retriable version of self.upload_file() so
    # will never retry self.upload_urls() here. Instead, maybe we should
    # make push itself retriable.
    _, upload_headers, result = self.upload_urls(
        project,
        files,
        run,
        entity,
    )
    # Each upload header is a single "key:value" string.
    extra_headers = {}
    for upload_header in upload_headers:
        key, val = upload_header.split(":", 1)
        extra_headers[key] = val
    responses = []
    for file_name, file_info in result.items():
        file_url = file_info["uploadUrl"]

        # If the upload URL is relative, fill it in with the base URL,
        # since it's a proxied file store like the on-prem VM.
        if file_url.startswith("/"):
            file_url = f"{self.api_url}{file_url}"

        try:
            # To handle Windows paths
            # TODO: this doesn't handle absolute paths...
            normal_name = os.path.join(*file_name.split("/"))
            open_file = (
                files[file_name]
                if isinstance(files, dict)
                else open(normal_name, "rb")
            )
        except OSError:
            print(f"{file_name} does not exist")
            continue
        try:
            if progress is False:
                # Use the resolved url here too, so relative upload urls work
                # with and without progress reporting.
                responses.append(
                    self.upload_file_retry(
                        file_url, open_file, extra_headers=extra_headers
                    )
                )
            elif callable(progress):
                responses.append(  # type: ignore
                    self.upload_file_retry(
                        file_url, open_file, progress, extra_headers=extra_headers
                    )
                )
            else:
                length = os.fstat(open_file.fileno()).st_size
                with click.progressbar(  # type: ignore
                    file=progress,
                    length=length,
                    label=f"Uploading file: {file_name}",
                    fill_char=click.style("&", fg="green"),
                ) as bar:
                    responses.append(
                        self.upload_file_retry(
                            file_url,
                            open_file,
                            lambda bites, _: bar.update(bites),
                            extra_headers=extra_headers,
                        )
                    )
        finally:
            # Close the handle even when the upload raises.
            open_file.close()
    return responses
3504
+
3505
def link_artifact(
    self,
    client_id: str,
    server_id: str,
    portfolio_name: str,
    entity: str,
    project: str,
    aliases: Sequence[str],
    organization: str,
) -> Dict[str, Any]:
    """Link an artifact to a portfolio and return the `linkArtifact` payload.

    The artifact is identified by `server_id` when it is set, otherwise by
    `client_id`. For artifact registry projects the org entity name is
    resolved from `entity` and `organization` and used as the entity name.

    Raises:
        ValueError: If the org entity name cannot be resolved (the message is
            also printed with `wandb.termerror`).
    """
    template = """
    mutation LinkArtifact(
        $artifactPortfolioName: String!,
        $entityName: String!,
        $projectName: String!,
        $aliases: [ArtifactAliasInput!],
        ID_TYPE
    ) {
        linkArtifact(input: {
            artifactPortfolioName: $artifactPortfolioName,
            entityName: $entityName,
            projectName: $projectName,
            aliases: $aliases,
            ID_VALUE
        }) {
            versionIndex
        }
    }
    """

    org_entity = ""
    if is_artifact_registry_project(project):
        try:
            org_entity = self._resolve_org_entity_name(
                entity=entity, organization=organization
            )
        except ValueError as e:
            wandb.termerror(str(e))
            raise

    # Fill the ID placeholders with whichever identifier is available;
    # the server id takes precedence over the client id.
    if server_id:
        template = template.replace("ID_TYPE", "$artifactID: ID").replace(
            "ID_VALUE", "artifactID: $artifactID"
        )
    elif client_id:
        template = template.replace("ID_TYPE", "$clientID: ID").replace(
            "ID_VALUE", "clientID: $clientID"
        )

    alias_inputs = [
        {"alias": alias, "artifactCollectionName": portfolio_name}
        for alias in aliases
    ]
    variable_values = {
        "clientID": client_id,
        "artifactID": server_id,
        "artifactPortfolioName": portfolio_name,
        "entityName": org_entity or entity,
        "projectName": project,
        "aliases": alias_inputs,
    }

    response = self.gql(gql(template), variable_values=variable_values)
    link_artifact: Dict[str, Any] = response["linkArtifact"]
    return link_artifact
3572
+
3573
def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str:
    """Resolve the name of the org entity to use for `entity`.

    The organization parameter may be empty, an org's display name, or an org
    entity name.

    If the server doesn't support fetching the org entity, the organization
    parameter is returned as is, or an error is raised if it is empty.
    Otherwise the organizations of `entity` are fetched; a given organization
    is matched against them, and without one the entity must belong to a
    single organization.

    Raises:
        ValueError: If `entity` is empty, if the organization is required but
            missing, or if the entity belongs to multiple organizations and
            none was specified.
    """
    if not entity:
        raise ValueError("Entity name is required to resolve org entity name.")

    server_has_org_entity = (
        "orgEntity" in self.server_organization_type_introspection()
    )
    if not server_has_org_entity:
        if not organization:
            raise ValueError(
                "Fetching Registry artifacts without inputting an organization "
                "is unavailable for your server version. "
                "Please upgrade your server to 0.50.0 or later."
            )
        # Server doesn't support fetching org entity to validate,
        # assume org entity is correctly inputted
        return organization

    candidates = self._fetch_orgs_and_org_entities_from_entity(entity)
    if organization:
        return _match_org_with_fetched_org_entities(organization, candidates)

    # If no input organization provided, error if entity belongs to multiple orgs because we
    # cannot determine which one to use.
    if len(candidates) > 1:
        raise ValueError(
            f"Personal entity {entity!r} belongs to multiple organizations "
            "and cannot be used without specifying the organization name. "
            "Please specify the organization in the Registry path or use a team entity in the entity settings."
        )
    only_org = candidates[0]
    return only_org.entity_name
3612
+
3613
def _fetch_orgs_and_org_entities_from_entity(self, entity: str) -> List[_OrgNames]:
    """Fetches organization entity names and display names for a given entity.

    Args:
        entity (str): Entity name to lookup. Can be either a personal or team entity.

    Returns:
        List[_OrgNames]: List of _OrgNames tuples. (_OrgNames(entity_name, display_name))

    Raises:
        ValueError: If entity is not found, has no organizations, or other validation errors.
    """
    query = gql(
        """
    query FetchOrgEntityFromEntity($entityName: String!) {
        entity(name: $entityName) {
            organization {
                name
                orgEntity {
                    name
                }
            }
            user {
                organizations {
                    name
                    orgEntity {
                        name
                    }
                }
            }
        }
    }
    """
    )
    response = self.gql(
        query,
        variable_values={
            "entityName": entity,
        },
    )

    # A null `entity` means the lookup found nothing; report it as the
    # documented ValueError instead of failing on a subscript of None.
    entity_data = response["entity"]
    if not entity_data:
        raise ValueError(f"Unable to find an organization under entity {entity!r}.")

    # Parse organization from response
    entity_resp = entity_data["organization"]
    user_resp = entity_data["user"]
    # Check for organization under team/org entity type
    if entity_resp:
        org_name = entity_resp.get("name")
        org_entity_name = entity_resp.get("orgEntity") and entity_resp[
            "orgEntity"
        ].get("name")
        if not org_name or not org_entity_name:
            raise ValueError(
                f"Unable to find an organization under entity {entity!r}."
            )
        return [_OrgNames(entity_name=org_entity_name, display_name=org_name)]
    # Check for organization under personal entity type, where a user can belong to multiple orgs
    elif user_resp:
        # `organizations` may be present but null, so guard before iterating.
        orgs = user_resp.get("organizations") or []
        org_entities_return = [
            _OrgNames(
                entity_name=org["orgEntity"]["name"], display_name=org["name"]
            )
            for org in orgs
            if org.get("orgEntity") and org.get("name")
        ]
        if not org_entities_return:
            raise ValueError(
                f"Unable to resolve an organization associated with personal entity: {entity!r}. "
                "This could be because its a personal entity that doesn't belong to any organizations. "
                "Please specify the organization in the Registry path or use a team entity in the entity settings."
            )
        return org_entities_return
    else:
        raise ValueError(f"Unable to find an organization under entity {entity!r}.")
3687
+
3688
def use_artifact(
    self,
    artifact_id: str,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    run_name: Optional[str] = None,
    use_as: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Send the `useArtifact` mutation for an artifact and a run.

    Entity, project and run fall back to the configured settings and the
    current run id when not given. The `usedAs` input is only included in
    the mutation when the server's introspection result lists it.

    Returns:
        The artifact dict from the response, or None when the response
        carries no artifact.
    """
    template = """
        mutation UseArtifact(
            $entityName: String!,
            $projectName: String!,
            $runName: String!,
            $artifactID: ID!,
            _USED_AS_TYPE_
        ) {
            useArtifact(input: {
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                artifactID: $artifactID,
                _USED_AS_VALUE_
            }) {
                artifact {
                    id
                    digest
                    description
                    state
                    createdAt
                    metadata
                }
            }
        }
        """

    # Fill (or blank out) the two placeholders depending on server support.
    supports_used_as = "usedAs" in self.server_use_artifact_input_introspection()
    used_as_type = "$usedAs: String" if supports_used_as else ""
    used_as_value = "usedAs: $usedAs" if supports_used_as else ""
    mutation = gql(
        template.replace("_USED_AS_TYPE_", used_as_type).replace(
            "_USED_AS_VALUE_", used_as_value
        )
    )

    variables = {
        "entityName": entity_name or self.settings("entity"),
        "projectName": project_name or self.settings("project"),
        "runName": run_name or self.current_run_id,
        "artifactID": artifact_id,
        "usedAs": use_as,
    }
    response = self.gql(mutation, variable_values=variables)

    used_artifact: Dict[str, Any] = response["useArtifact"]["artifact"]
    if not used_artifact:
        return None
    return used_artifact
3754
+
3755
+ # Fetch fields available in backend of Organization type
3756
def server_organization_type_introspection(self) -> List[str]:
    """Return the field names the server exposes on the `Organization` type.

    The introspection query runs only while
    `self.server_organization_type_fields_info` is None; the parsed list is
    stored on that attribute and returned directly on later calls.

    Returns:
        The list of field names; empty when the server reports no such
        type or no fields.
    """
    if self.server_organization_type_fields_info is None:
        query = gql(
            """
            query ProbeServerOrganization {
                OrganizationInfoType: __type(name:"Organization") {
                    fields {
                        name
                    }
                }
            }
            """
        )
        res = self.gql(query)
        # `__type` is null for a type the server does not know, and
        # dict.get's default is not used for a key whose value is None,
        # so guard both levels with `or`.
        type_info = (res or {}).get("OrganizationInfoType") or {}
        input_fields = type_info.get("fields") or []
        self.server_organization_type_fields_info = [
            field["name"] for field in input_fields if "name" in field
        ]

    return self.server_organization_type_fields_info
3776
+
3777
+ # Fetch input arguments for the "artifact" endpoint on the "Project" type
3778
def server_project_type_introspection(self) -> bool:
    """Report whether `Project.artifact` accepts an `enableTracking` argument.

    The answer is cached on
    `self.server_supports_enabling_artifact_usage_tracking`; the
    introspection query is only sent while that attribute is None.

    Returns:
        True when the `artifact` field of the `Project` type lists an
        argument named `enableTracking`, False otherwise.
    """
    if self.server_supports_enabling_artifact_usage_tracking is not None:
        return self.server_supports_enabling_artifact_usage_tracking

    query = gql(
        """
        query ProbeServerProjectInfo {
            ProjectInfoType: __type(name:"Project") {
                fields {
                    name
                    args {
                        name
                    }
                }
            }
        }
        """
    )
    res = self.gql(query)
    # Guard against explicit nulls: dict.get's default only covers a
    # missing key, not a key whose value is None.
    type_info = (res or {}).get("ProjectInfoType") or {}
    input_fields = type_info.get("fields") or []
    artifact_args: List[Dict[str, str]] = next(
        (
            field.get("args") or []
            for field in input_fields
            if field.get("name") == "artifact"
        ),
        [],
    )
    self.server_supports_enabling_artifact_usage_tracking = any(
        arg.get("name") == "enableTracking" for arg in artifact_args
    )

    return self.server_supports_enabling_artifact_usage_tracking
3811
+
3812
def create_artifact_type(
    self,
    artifact_type_name: str,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    description: Optional[str] = None,
) -> Optional[str]:
    """Send the `createArtifactType` mutation and return the new type's id.

    Entity and project fall back to the configured settings when not given.
    """
    mutation = gql(
        """
        mutation CreateArtifactType(
            $entityName: String!,
            $projectName: String!,
            $artifactTypeName: String!,
            $description: String
        ) {
            createArtifactType(input: {
                entityName: $entityName,
                projectName: $projectName,
                name: $artifactTypeName,
                description: $description
            }) {
                artifactType {
                    id
                }
            }
        }
        """
    )
    variables = {
        "entityName": entity_name or self.settings("entity"),
        "projectName": project_name or self.settings("project"),
        "artifactTypeName": artifact_type_name,
        "description": description,
    }
    response = self.gql(mutation, variable_values=variables)
    type_id: Optional[str] = response["createArtifactType"]["artifactType"]["id"]
    return type_id
3853
+
3854
def server_artifact_introspection(self) -> List[str]:
    """Return the field names the server exposes on the `Artifact` type.

    The introspection query runs only while
    `self.server_artifact_fields_info` is None; the parsed list is stored
    on that attribute and returned directly on later calls.

    Returns:
        The list of field names; empty when the server reports no such
        type or no fields.
    """
    if self.server_artifact_fields_info is None:
        query = gql(
            """
            query ProbeServerArtifact {
                ArtifactInfoType: __type(name:"Artifact") {
                    fields {
                        name
                    }
                }
            }
            """
        )
        res = self.gql(query)
        # dict.get's default is skipped when the key exists with a None
        # value (a null `__type`), so coalesce explicitly.
        type_info = (res or {}).get("ArtifactInfoType") or {}
        input_fields = type_info.get("fields") or []
        self.server_artifact_fields_info = [
            field["name"] for field in input_fields if "name" in field
        ]

    return self.server_artifact_fields_info
3874
+
3875
def server_create_artifact_introspection(self) -> List[str]:
    """Return the input field names of the server's `CreateArtifactInput` type.

    The introspection query runs only while
    `self.server_create_artifact_input_info` is None; the parsed list is
    stored on that attribute and returned directly on later calls.

    Returns:
        The list of input field names; empty when the server reports no
        such type or no input fields.
    """
    if self.server_create_artifact_input_info is None:
        query = gql(
            """
            query ProbeServerCreateArtifactInput {
                CreateArtifactInputInfoType: __type(name:"CreateArtifactInput") {
                    inputFields{
                        name
                    }
                }
            }
            """
        )
        res = self.gql(query)
        # dict.get's default is skipped when the key exists with a None
        # value (a null `__type`), so coalesce explicitly.
        type_info = (res or {}).get("CreateArtifactInputInfoType") or {}
        input_fields = type_info.get("inputFields") or []
        self.server_create_artifact_input_info = [
            field["name"] for field in input_fields if "name" in field
        ]

    return self.server_create_artifact_input_info
3897
+
3898
def _get_create_artifact_mutation(
    self,
    fields: List,
    history_step: Optional[int],
    distributed_id: Optional[str],
) -> str:
    """Build the `CreateArtifact` mutation text for the given server fields.

    Optional variable declarations and input values are spliced into the
    template depending on which names appear in `fields`, whether
    `history_step` is something other than 0/None, and whether
    `distributed_id` is truthy.

    Args:
        fields: Names of the input fields the server supports.
        history_step: History step; declared only when supported and
            not in `[0, None]`.
        distributed_id: Distributed id; declared whenever truthy.

    Returns:
        The mutation as a string with both placeholders substituted.
    """
    # (enabled, variable declaration, input value) in emission order.
    optional_inputs = [
        (
            "historyStep" in fields and history_step not in [0, None],
            "$historyStep: Int64!,",
            "historyStep: $historyStep,",
        ),
        (
            bool(distributed_id),
            "$distributedID: String,",
            "distributedID: $distributedID,",
        ),
        ("clientID" in fields, "$clientID: ID,", "clientID: $clientID,"),
        (
            "sequenceClientID" in fields,
            "$sequenceClientID: ID,",
            "sequenceClientID: $sequenceClientID,",
        ),
        # Constant input value: no variable needs to be declared for it.
        (
            "enableDigestDeduplication" in fields,
            "",
            "enableDigestDeduplication: true,",
        ),
        (
            "ttlDurationSeconds" in fields,
            "$ttlDurationSeconds: Int64,",
            "ttlDurationSeconds: $ttlDurationSeconds,",
        ),
        ("tags" in fields, "$tags: [TagInput!],", "tags: $tags,"),
    ]
    extra_types = "".join(decl for on, decl, _ in optional_inputs if on)
    extra_values = "".join(value for on, _, value in optional_inputs if on)

    template = """
        mutation CreateArtifact(
            $artifactTypeName: String!,
            $artifactCollectionNames: [String!],
            $entityName: String!,
            $projectName: String!,
            $runName: String,
            $description: String,
            $digest: String!,
            $aliases: [ArtifactAliasInput!],
            $metadata: JSONString,
            _CREATE_ARTIFACT_ADDITIONAL_TYPE_
        ) {
            createArtifact(input: {
                artifactTypeName: $artifactTypeName,
                artifactCollectionNames: $artifactCollectionNames,
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                description: $description,
                digest: $digest,
                digestAlgorithm: MANIFEST_MD5,
                aliases: $aliases,
                metadata: $metadata,
                _CREATE_ARTIFACT_ADDITIONAL_VALUE_
            }) {
                artifact {
                    id
                    state
                    artifactSequence {
                        id
                        latestArtifact {
                            id
                            versionIndex
                        }
                    }
                }
            }
        }
        """

    with_types = template.replace("_CREATE_ARTIFACT_ADDITIONAL_TYPE_", extra_types)
    return with_types.replace("_CREATE_ARTIFACT_ADDITIONAL_VALUE_", extra_values)
3978
+
3979
def create_artifact(
    self,
    artifact_type_name: str,
    artifact_collection_name: str,
    digest: str,
    client_id: Optional[str] = None,
    sequence_client_id: Optional[str] = None,
    entity_name: Optional[str] = None,
    project_name: Optional[str] = None,
    run_name: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[Dict] = None,
    ttl_duration_seconds: Optional[int] = None,
    aliases: Optional[List[Dict[str, str]]] = None,
    tags: Optional[List[Dict[str, str]]] = None,
    distributed_id: Optional[str] = None,
    is_user_created: Optional[bool] = False,
    history_step: Optional[int] = None,
) -> Tuple[Dict, Dict]:
    """Send the `createArtifact` mutation and return the created artifact.

    The mutation text is tailored to the server via the two introspection
    helpers. A TTL is dropped (with a warning) when the server's Artifact
    type lacks `ttlIsInherited`; a warning is also printed when tags are
    given but the Artifact type lacks `tags`.

    Returns:
        A `(artifact, latest)` tuple: the artifact dict from the response
        and the `latestArtifact` entry of its sequence (read with `.get`).
    """
    input_fields = self.server_create_artifact_introspection()
    artifact_fields = self.server_artifact_introspection()

    if ttl_duration_seconds and "ttlIsInherited" not in artifact_fields:
        wandb.termwarn(
            "Server not compatible with setting Artifact TTLs, please upgrade the server to use Artifact TTL"
        )
        # ttlDurationSeconds is only usable if ttlIsInherited is also present
        ttl_duration_seconds = None
    if tags and "tags" not in artifact_fields:
        wandb.termwarn(
            "Server not compatible with Artifact tags. "
            "To use Artifact tags, please upgrade the server to v0.85 or higher."
        )

    mutation_text = self._get_create_artifact_mutation(
        input_fields, history_step, distributed_id
    )

    entity_name = entity_name or self.settings("entity")
    project_name = project_name or self.settings("project")
    # Only runs created by the SDK itself default to the current run id.
    if not is_user_created:
        run_name = run_name or self.current_run_id

    serialized_metadata = (
        json.dumps(util.make_safe_for_json(metadata)) if metadata else None
    )
    variables = {
        "entityName": entity_name,
        "projectName": project_name,
        "runName": run_name,
        "artifactTypeName": artifact_type_name,
        "artifactCollectionNames": [artifact_collection_name],
        "clientID": client_id,
        "sequenceClientID": sequence_client_id,
        "digest": digest,
        "description": description,
        "aliases": list(aliases or []),
        "tags": list(tags or []),
        "metadata": serialized_metadata,
        "ttlDurationSeconds": ttl_duration_seconds,
        "distributedID": distributed_id,
        "historyStep": history_step,
    }
    response = self.gql(gql(mutation_text), variable_values=variables)

    created = response["createArtifact"]["artifact"]
    latest = created["artifactSequence"].get("latestArtifact")
    return created, latest
4049
+
4050
def commit_artifact(self, artifact_id: str) -> "_Response":
    """Send the `commitArtifact` mutation for `artifact_id`.

    The request is issued with a 60 second timeout and the raw response
    is returned unchanged.
    """
    commit_mutation = gql(
        """
        mutation CommitArtifact(
            $artifactID: ID!,
        ) {
            commitArtifact(input: {
                artifactID: $artifactID,
            }) {
                artifact {
                    id
                    digest
                }
            }
        }
        """
    )
    variables = {"artifactID": artifact_id}
    result: _Response = self.gql(
        commit_mutation, variable_values=variables, timeout=60
    )
    return result
4074
+
4075
def complete_multipart_upload_artifact(
    self,
    artifact_id: str,
    storage_path: str,
    completed_parts: List[Dict[str, Any]],
    upload_id: Optional[str],
    complete_multipart_action: str = "Complete",
) -> Optional[str]:
    """Send the `completeMultipartUploadArtifact` mutation.

    Returns:
        The `digest` value from the mutation response.
    """
    mutation = gql(
        """
        mutation CompleteMultipartUploadArtifact(
            $completeMultipartAction: CompleteMultipartAction!,
            $completedParts: [UploadPartsInput!]!,
            $artifactID: ID!
            $storagePath: String!
            $uploadID: String!
        ) {
            completeMultipartUploadArtifact(
                input: {
                    completeMultipartAction: $completeMultipartAction,
                    completedParts: $completedParts,
                    artifactID: $artifactID,
                    storagePath: $storagePath
                    uploadID: $uploadID
                }
            ) {
                digest
            }
        }
        """
    )
    variables = {
        "completeMultipartAction": complete_multipart_action,
        "artifactID": artifact_id,
        "storagePath": storage_path,
        "completedParts": completed_parts,
        "uploadID": upload_id,
    }
    response = self.gql(mutation, variable_values=variables)
    upload_digest: Optional[str] = response["completeMultipartUploadArtifact"][
        "digest"
    ]
    return upload_digest
4118
+
4119
def create_artifact_manifest(
    self,
    name: str,
    digest: str,
    artifact_id: Optional[str],
    base_artifact_id: Optional[str] = None,
    entity: Optional[str] = None,
    project: Optional[str] = None,
    run: Optional[str] = None,
    include_upload: bool = True,
    type: str = "FULL",
) -> Tuple[str, Dict[str, Any]]:
    """Send the `createArtifactManifest` mutation.

    The `$type` variable and `type` input are only written into the
    mutation when `type` differs from "FULL". Entity, project and run fall
    back to the configured settings and the current run id.

    Returns:
        A `(manifest_id, file)` tuple taken from the response's
        `artifactManifest` entry.
    """
    non_default_type = type != "FULL"
    type_declaration = "$type: ArtifactManifestType = FULL" if non_default_type else ""
    type_input = "type: $type" if non_default_type else ""

    mutation = gql(
        """
        mutation CreateArtifactManifest(
            $name: String!,
            $digest: String!,
            $artifactID: ID!,
            $baseArtifactID: ID,
            $entityName: String!,
            $projectName: String!,
            $runName: String!,
            $includeUpload: Boolean!,
            {type_declaration}
        ) {{
            createArtifactManifest(input: {{
                name: $name,
                digest: $digest,
                artifactID: $artifactID,
                baseArtifactID: $baseArtifactID,
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                {type_input}
            }}) {{
                artifactManifest {{
                    id
                    file {{
                        id
                        name
                        displayName
                        uploadUrl @include(if: $includeUpload)
                        uploadHeaders @include(if: $includeUpload)
                    }}
                }}
            }}
        }}
        """.format(
            type_declaration=type_declaration,
            type_input=type_input,
        )
    )

    variables = {
        "name": name,
        "digest": digest,
        "artifactID": artifact_id,
        "baseArtifactID": base_artifact_id,
        "entityName": entity or self.settings("entity"),
        "projectName": project or self.settings("project"),
        "runName": run or self.current_run_id,
        "includeUpload": include_upload,
        "type": type,
    }
    response = self.gql(mutation, variable_values=variables)

    manifest = response["createArtifactManifest"]["artifactManifest"]
    return (manifest["id"], manifest["file"])
4194
+
4195
def update_artifact_manifest(
    self,
    artifact_manifest_id: str,
    base_artifact_id: Optional[str] = None,
    digest: Optional[str] = None,
    include_upload: Optional[bool] = True,
) -> Tuple[str, Dict[str, Any]]:
    """Send the `updateArtifactManifest` mutation.

    Returns:
        A `(manifest_id, file)` tuple taken from the response's
        `artifactManifest` entry.
    """
    mutation = gql(
        """
        mutation UpdateArtifactManifest(
            $artifactManifestID: ID!,
            $digest: String,
            $baseArtifactID: ID,
            $includeUpload: Boolean!,
        ) {
            updateArtifactManifest(input: {
                artifactManifestID: $artifactManifestID,
                digest: $digest,
                baseArtifactID: $baseArtifactID,
            }) {
                artifactManifest {
                    id
                    file {
                        id
                        name
                        displayName
                        uploadUrl @include(if: $includeUpload)
                        uploadHeaders @include(if: $includeUpload)
                    }
                }
            }
        }
        """
    )
    variables = {
        "artifactManifestID": artifact_manifest_id,
        "digest": digest,
        "baseArtifactID": base_artifact_id,
        "includeUpload": include_upload,
    }
    response = self.gql(mutation, variable_values=variables)

    manifest = response["updateArtifactManifest"]["artifactManifest"]
    return (manifest["id"], manifest["file"])
4244
+
4245
def update_artifact_metadata(
    self, artifact_id: str, metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """Set the metadata of the given artifact version.

    The metadata dict is JSON-encoded before being sent.

    Returns:
        The `artifact` entry of the `updateArtifact` response.
    """
    mutation = gql(
        """
        mutation UpdateArtifact(
            $artifactID: ID!,
            $metadata: JSONString,
        ) {
            updateArtifact(input: {
                artifactID: $artifactID,
                metadata: $metadata,
            }) {
                artifact {
                    id
                }
            }
        }
        """
    )
    variables = {
        "artifactID": artifact_id,
        "metadata": json.dumps(metadata),
    }
    result = self.gql(mutation, variable_values=variables)
    return result["updateArtifact"]["artifact"]
4274
+
4275
def _resolve_client_id(
    self,
    client_id: str,
) -> Optional[str]:
    """Look up the server id mapped to a client id.

    Results found in `self._client_id_mapping` are returned without a
    request; otherwise the `clientIDMapping` query is sent and a non-None
    server id is stored in that mapping before being returned.

    Returns:
        The server id, or None when the response holds no mapping.
    """
    known_ids = self._client_id_mapping
    if client_id in known_ids:
        return known_ids[client_id]

    query = gql(
        """
        query ClientIDMapping($clientID: ID!) {
            clientIDMapping(clientID: $clientID) {
                serverID
            }
        }
    """
    )
    response = self.gql(query, variable_values={"clientID": client_id})
    if response is None:
        return None

    mapping_entry = response.get("clientIDMapping")
    if mapping_entry is None:
        return None

    resolved = mapping_entry.get("serverID")
    if resolved is not None:
        # Remember successful lookups so the query is not repeated.
        self._client_id_mapping[client_id] = resolved
    return resolved
4305
+
4306
def server_create_artifact_file_spec_input_introspection(self) -> List:
    """Return the input field names of `CreateArtifactFileSpecInput`.

    Unlike the other introspection helpers in this file, the result is not
    cached: every call sends the introspection query.

    Returns:
        One entry per reported input field (an empty string for a field
        without a name); empty when the server reports no such type.
    """
    query = gql(
        """
        query ProbeServerCreateArtifactFileSpecInput {
            CreateArtifactFileSpecInputInfoType: __type(name:"CreateArtifactFileSpecInput") {
                inputFields{
                    name
                }
            }
        }
        """
    )
    res = self.gql(query)
    # dict.get's default is skipped when the key exists with a None value
    # (a null `__type`), so coalesce explicitly instead of relying on it.
    type_info = (res or {}).get("CreateArtifactFileSpecInputInfoType") or {}
    input_fields = type_info.get("inputFields") or []
    create_artifact_file_spec_input_info = [
        field.get("name", "") for field in input_fields
    ]
    return create_artifact_file_spec_input_info
4326
+
4327
@normalize_exceptions
def create_artifact_files(
    self, artifact_files: Iterable["CreateArtifactFileSpecInput"]
) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
    """Send the `createArtifactFiles` mutation for the given file specs.

    Multipart-upload fields are requested only when the server's
    `CreateArtifactFileSpecInput` type lists `uploadPartsInput`.

    Returns:
        A dict mapping each returned node's `displayName` to the node.
    """
    template = """
        mutation CreateArtifactFiles(
            $storageLayout: ArtifactStorageLayout!
            $artifactFiles: [CreateArtifactFileSpecInput!]!
        ) {
            createArtifactFiles(input: {
                artifactFiles: $artifactFiles,
                storageLayout: $storageLayout,
            }) {
                files {
                    edges {
                        node {
                            id
                            name
                            displayName
                            uploadUrl
                            uploadHeaders
                            _MULTIPART_UPLOAD_FIELDS_
                            artifact {
                                id
                            }
                        }
                    }
                }
            }
        }
        """
    multipart_fields = """
        storagePath
        uploadMultipartUrls {
            uploadID
            uploadUrlParts {
                partNumber
                uploadUrl
            }
        }
        """

    # TODO: we should use constants here from interface/artifacts.py
    # but probably don't want the dependency. We're going to remove
    # this setting in a future release, so I'm just hard-coding the strings.
    storage_layout = "V1" if env.get_use_v1_artifacts() else "V2"

    spec_fields = self.server_create_artifact_file_spec_input_introspection()
    extra_fields = multipart_fields if "uploadPartsInput" in spec_fields else ""
    mutation = gql(template.replace("_MULTIPART_UPLOAD_FIELDS_", extra_fields))

    response = self.gql(
        mutation,
        variable_values={
            "storageLayout": storage_layout,
            "artifactFiles": list(artifact_files),
        },
    )

    edges = response["createArtifactFiles"]["files"]["edges"]
    return {edge["node"]["displayName"]: edge["node"] for edge in edges}
4400
+
4401
@normalize_exceptions
def notify_scriptable_run_alert(
    self,
    title: str,
    text: str,
    level: Optional[str] = None,
    wait_duration: Optional["Number"] = None,
) -> bool:
    """Send the `notifyScriptableRunAlert` mutation for the current run.

    Entity, project and run are taken from the configured settings and the
    current run id.

    Args:
        title: Alert title.
        text: Alert body.
        level: Alert severity. When None, the variable is left out of the
            request so the mutation's declared default (`INFO`) applies.
        wait_duration: Value passed as the `waitDuration` variable.

    Returns:
        The `success` flag from the mutation response.
    """
    mutation = gql(
        """
        mutation NotifyScriptableRunAlert(
            $entityName: String!,
            $projectName: String!,
            $runName: String!,
            $title: String!,
            $text: String!,
            $severity: AlertSeverity = INFO,
            $waitDuration: Duration
        ) {
            notifyScriptableRunAlert(input: {
                entityName: $entityName,
                projectName: $projectName,
                runName: $runName,
                title: $title,
                text: $text,
                severity: $severity,
                waitDuration: $waitDuration
            }) {
                success
            }
        }
        """
    )

    variables = {
        "entityName": self.settings("entity"),
        "projectName": self.settings("project"),
        "runName": self.current_run_id,
        "title": title,
        "text": text,
        "waitDuration": wait_duration,
    }
    # A GraphQL variable default is only used when the variable is absent;
    # sending an explicit null would override `= INFO` with null.
    if level is not None:
        variables["severity"] = level

    response = self.gql(mutation, variable_values=variables)
    success: bool = response["notifyScriptableRunAlert"]["success"]
    return success
4449
+
4450
def get_sweep_state(
    self, sweep: str, entity: Optional[str] = None, project: Optional[str] = None
) -> "SweepState":
    """Return the `state` entry of the sweep looked up via `self.sweep`.

    The lookup is made with `specs="{}"`.
    """
    sweep_info = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
    current_state: SweepState = sweep_info["state"]
    return current_state
4457
+
4458
def set_sweep_state(
    self,
    sweep: str,
    state: "SweepState",
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Change a sweep's state through the `upsertSweep` mutation.

    The sweep's current state is fetched first and the transition is
    checked before anything is sent.

    Raises:
        AssertionError: If `state` is not one of RUNNING, PAUSED,
            CANCELED or FINISHED.
        Exception: If pausing a sweep that is neither paused nor running,
            or moving to a non-RUNNING state a sweep that is not running,
            paused or pending.
    """
    assert state in ("RUNNING", "PAUSED", "CANCELED", "FINISHED")

    sweep_info = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
    current = sweep_info["state"].upper()

    if state == "PAUSED":
        if current not in ("PAUSED", "RUNNING"):
            raise Exception(f"Cannot pause {current.lower()} sweep.")
    elif state != "RUNNING":
        if current not in ("RUNNING", "PAUSED", "PENDING"):
            raise Exception(f"Sweep already {current.lower()}.")

    mutation = gql(
        """
        mutation UpsertSweep(
            $id: ID,
            $state: String,
            $entityName: String,
            $projectName: String
        ) {
            upsertSweep(input: {
                id: $id,
                state: $state,
                entityName: $entityName,
                projectName: $projectName
            }){
                sweep {
                    name
                }
            }
        }
        """
    )
    variables = {
        "id": sweep_info["id"],
        "state": state,
        "entityName": entity or self.settings("entity"),
        "projectName": project or self.settings("project"),
    }
    self.gql(mutation, variable_values=variables)
4503
+
4504
def stop_sweep(
    self,
    sweep: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Finish the sweep to stop running new runs and let currently running runs finish.

    Delegates to `set_sweep_state` with the state "FINISHED".
    """
    target_state = "FINISHED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
4514
+
4515
def cancel_sweep(
    self,
    sweep: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Cancel the sweep to kill all running runs and stop running new runs.

    Delegates to `set_sweep_state` with the state "CANCELED".
    """
    target_state = "CANCELED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
4525
+
4526
def pause_sweep(
    self,
    sweep: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Pause the sweep to temporarily stop running new runs.

    Delegates to `set_sweep_state` with the state "PAUSED".
    """
    target_state = "PAUSED"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
4536
+
4537
def resume_sweep(
    self,
    sweep: str,
    entity: Optional[str] = None,
    project: Optional[str] = None,
) -> None:
    """Resume the sweep to continue running new runs.

    Delegates to `set_sweep_state` with the state "RUNNING".
    """
    target_state = "RUNNING"
    self.set_sweep_state(
        sweep=sweep, state=target_state, entity=entity, project=project
    )
4547
+
4548
def _status_request(self, url: str, length: int) -> requests.Response:
    """Ask google how much we've uploaded.

    Issues an empty PUT to `url` whose `Content-Range` header is
    `bytes */<length>` and returns the response.
    """
    check_httpclient_logger_handler()
    status_headers = {
        "Content-Length": "0",
        "Content-Range": f"bytes */{length}",
    }
    return requests.put(url=url, headers=status_headers)
4555
+
4556
def _flatten_edges(self, response: "_Response") -> List[Dict]:
    """Return an array from the nested graphql relay structure.

    Collects the `node` entry of every item under `response["edges"]`,
    in order.
    """
    nodes = []
    for edge in response["edges"]:
        nodes.append(edge["node"])
    return nodes
4559
+
4560
@normalize_exceptions
def stop_run(
    self,
    run_id: str,
) -> bool:
    """Send the `stopRun` mutation for `run_id`.

    Returns:
        The `success` value of the mutation response (read with `.get`).
    """
    mutation = gql(
        """
        mutation stopRun($id: ID!) {
            stopRun(input: {
                id: $id
            }) {
                clientMutationId
                success
            }
        }
        """
    )
    response = self.gql(mutation, variable_values={"id": run_id})
    stopped: bool = response["stopRun"].get("success")
    return stopped