wandb-0.17.0rc1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

Files changed (841)
  1. package_readme.md +95 -0
  2. wandb/__init__.py +253 -0
  3. wandb/__main__.py +3 -0
  4. wandb/_globals.py +19 -0
  5. wandb/agents/__init__.py +0 -0
  6. wandb/agents/pyagent.py +364 -0
  7. wandb/analytics/__init__.py +3 -0
  8. wandb/analytics/sentry.py +266 -0
  9. wandb/apis/__init__.py +48 -0
  10. wandb/apis/attrs.py +40 -0
  11. wandb/apis/importers/__init__.py +1 -0
  12. wandb/apis/importers/internals/internal.py +386 -0
  13. wandb/apis/importers/internals/protocols.py +99 -0
  14. wandb/apis/importers/internals/util.py +78 -0
  15. wandb/apis/importers/mlflow.py +254 -0
  16. wandb/apis/importers/validation.py +108 -0
  17. wandb/apis/importers/wandb.py +1598 -0
  18. wandb/apis/internal.py +228 -0
  19. wandb/apis/normalize.py +89 -0
  20. wandb/apis/paginator.py +81 -0
  21. wandb/apis/public/__init__.py +34 -0
  22. wandb/apis/public/api.py +1045 -0
  23. wandb/apis/public/artifacts.py +851 -0
  24. wandb/apis/public/const.py +4 -0
  25. wandb/apis/public/files.py +195 -0
  26. wandb/apis/public/history.py +149 -0
  27. wandb/apis/public/jobs.py +634 -0
  28. wandb/apis/public/projects.py +162 -0
  29. wandb/apis/public/query_generator.py +166 -0
  30. wandb/apis/public/reports.py +469 -0
  31. wandb/apis/public/runs.py +803 -0
  32. wandb/apis/public/sweeps.py +240 -0
  33. wandb/apis/public/teams.py +198 -0
  34. wandb/apis/public/users.py +136 -0
  35. wandb/apis/reports/__init__.py +7 -0
  36. wandb/apis/reports/v1/__init__.py +30 -0
  37. wandb/apis/reports/v1/_blocks.py +1406 -0
  38. wandb/apis/reports/v1/_helpers.py +70 -0
  39. wandb/apis/reports/v1/_panels.py +1282 -0
  40. wandb/apis/reports/v1/_templates.py +478 -0
  41. wandb/apis/reports/v1/blocks.py +27 -0
  42. wandb/apis/reports/v1/helpers.py +2 -0
  43. wandb/apis/reports/v1/mutations.py +66 -0
  44. wandb/apis/reports/v1/panels.py +17 -0
  45. wandb/apis/reports/v1/report.py +268 -0
  46. wandb/apis/reports/v1/runset.py +144 -0
  47. wandb/apis/reports/v1/templates.py +7 -0
  48. wandb/apis/reports/v1/util.py +406 -0
  49. wandb/apis/reports/v1/validators.py +131 -0
  50. wandb/apis/reports/v2/__init__.py +20 -0
  51. wandb/apis/reports/v2/blocks.py +25 -0
  52. wandb/apis/reports/v2/expr_parsing.py +257 -0
  53. wandb/apis/reports/v2/gql.py +68 -0
  54. wandb/apis/reports/v2/interface.py +1911 -0
  55. wandb/apis/reports/v2/internal.py +867 -0
  56. wandb/apis/reports/v2/metrics.py +6 -0
  57. wandb/apis/reports/v2/panels.py +15 -0
  58. wandb/beta/workflows.py +283 -0
  59. wandb/bin/wandb-core +0 -0
  60. wandb/catboost/__init__.py +9 -0
  61. wandb/cli/__init__.py +0 -0
  62. wandb/cli/cli.py +2890 -0
  63. wandb/data_types.py +2068 -0
  64. wandb/docker/__init__.py +342 -0
  65. wandb/docker/auth.py +436 -0
  66. wandb/docker/wandb-entrypoint.sh +33 -0
  67. wandb/docker/www_authenticate.py +94 -0
  68. wandb/env.py +496 -0
  69. wandb/errors/__init__.py +46 -0
  70. wandb/errors/term.py +95 -0
  71. wandb/errors/util.py +57 -0
  72. wandb/fastai/__init__.py +9 -0
  73. wandb/filesync/__init__.py +0 -0
  74. wandb/filesync/dir_watcher.py +403 -0
  75. wandb/filesync/stats.py +100 -0
  76. wandb/filesync/step_checksum.py +145 -0
  77. wandb/filesync/step_prepare.py +199 -0
  78. wandb/filesync/step_upload.py +392 -0
  79. wandb/filesync/upload_job.py +218 -0
  80. wandb/integration/__init__.py +0 -0
  81. wandb/integration/catboost/__init__.py +5 -0
  82. wandb/integration/catboost/catboost.py +178 -0
  83. wandb/integration/cohere/__init__.py +3 -0
  84. wandb/integration/cohere/cohere.py +21 -0
  85. wandb/integration/cohere/resolver.py +347 -0
  86. wandb/integration/diffusers/__init__.py +3 -0
  87. wandb/integration/diffusers/autologger.py +76 -0
  88. wandb/integration/diffusers/pipeline_resolver.py +50 -0
  89. wandb/integration/diffusers/resolvers/__init__.py +9 -0
  90. wandb/integration/diffusers/resolvers/multimodal.py +882 -0
  91. wandb/integration/diffusers/resolvers/utils.py +102 -0
  92. wandb/integration/fastai/__init__.py +249 -0
  93. wandb/integration/gym/__init__.py +85 -0
  94. wandb/integration/huggingface/__init__.py +3 -0
  95. wandb/integration/huggingface/huggingface.py +18 -0
  96. wandb/integration/huggingface/resolver.py +213 -0
  97. wandb/integration/keras/__init__.py +14 -0
  98. wandb/integration/keras/callbacks/__init__.py +5 -0
  99. wandb/integration/keras/callbacks/metrics_logger.py +130 -0
  100. wandb/integration/keras/callbacks/model_checkpoint.py +200 -0
  101. wandb/integration/keras/callbacks/tables_builder.py +226 -0
  102. wandb/integration/keras/keras.py +1080 -0
  103. wandb/integration/kfp/__init__.py +6 -0
  104. wandb/integration/kfp/helpers.py +28 -0
  105. wandb/integration/kfp/kfp_patch.py +324 -0
  106. wandb/integration/kfp/wandb_logging.py +182 -0
  107. wandb/integration/langchain/__init__.py +3 -0
  108. wandb/integration/langchain/wandb_tracer.py +48 -0
  109. wandb/integration/lightgbm/__init__.py +239 -0
  110. wandb/integration/lightning/__init__.py +0 -0
  111. wandb/integration/lightning/fabric/__init__.py +3 -0
  112. wandb/integration/lightning/fabric/logger.py +762 -0
  113. wandb/integration/magic.py +556 -0
  114. wandb/integration/metaflow/__init__.py +3 -0
  115. wandb/integration/metaflow/metaflow.py +383 -0
  116. wandb/integration/openai/__init__.py +3 -0
  117. wandb/integration/openai/fine_tuning.py +454 -0
  118. wandb/integration/openai/openai.py +22 -0
  119. wandb/integration/openai/resolver.py +240 -0
  120. wandb/integration/prodigy/__init__.py +3 -0
  121. wandb/integration/prodigy/prodigy.py +299 -0
  122. wandb/integration/sacred/__init__.py +117 -0
  123. wandb/integration/sagemaker/__init__.py +12 -0
  124. wandb/integration/sagemaker/auth.py +28 -0
  125. wandb/integration/sagemaker/config.py +49 -0
  126. wandb/integration/sagemaker/files.py +3 -0
  127. wandb/integration/sagemaker/resources.py +34 -0
  128. wandb/integration/sb3/__init__.py +3 -0
  129. wandb/integration/sb3/sb3.py +153 -0
  130. wandb/integration/tensorboard/__init__.py +10 -0
  131. wandb/integration/tensorboard/log.py +358 -0
  132. wandb/integration/tensorboard/monkeypatch.py +185 -0
  133. wandb/integration/tensorflow/__init__.py +5 -0
  134. wandb/integration/tensorflow/estimator_hook.py +54 -0
  135. wandb/integration/torch/__init__.py +0 -0
  136. wandb/integration/ultralytics/__init__.py +11 -0
  137. wandb/integration/ultralytics/bbox_utils.py +208 -0
  138. wandb/integration/ultralytics/callback.py +524 -0
  139. wandb/integration/ultralytics/classification_utils.py +83 -0
  140. wandb/integration/ultralytics/mask_utils.py +202 -0
  141. wandb/integration/ultralytics/pose_utils.py +104 -0
  142. wandb/integration/xgboost/__init__.py +11 -0
  143. wandb/integration/xgboost/xgboost.py +189 -0
  144. wandb/integration/yolov8/__init__.py +0 -0
  145. wandb/integration/yolov8/yolov8.py +284 -0
  146. wandb/jupyter.py +501 -0
  147. wandb/keras/__init__.py +19 -0
  148. wandb/lightgbm/__init__.py +9 -0
  149. wandb/magic.py +3 -0
  150. wandb/mpmain/__init__.py +0 -0
  151. wandb/mpmain/__main__.py +1 -0
  152. wandb/old/__init__.py +0 -0
  153. wandb/old/core.py +131 -0
  154. wandb/old/settings.py +173 -0
  155. wandb/old/summary.py +435 -0
  156. wandb/plot/__init__.py +19 -0
  157. wandb/plot/bar.py +42 -0
  158. wandb/plot/confusion_matrix.py +99 -0
  159. wandb/plot/histogram.py +36 -0
  160. wandb/plot/line.py +40 -0
  161. wandb/plot/line_series.py +88 -0
  162. wandb/plot/pr_curve.py +135 -0
  163. wandb/plot/roc_curve.py +117 -0
  164. wandb/plot/scatter.py +32 -0
  165. wandb/plots/__init__.py +6 -0
  166. wandb/plots/explain_text.py +36 -0
  167. wandb/plots/heatmap.py +81 -0
  168. wandb/plots/named_entity.py +43 -0
  169. wandb/plots/part_of_speech.py +50 -0
  170. wandb/plots/plot_definitions.py +768 -0
  171. wandb/plots/precision_recall.py +121 -0
  172. wandb/plots/roc.py +103 -0
  173. wandb/plots/utils.py +195 -0
  174. wandb/proto/__init__.py +0 -0
  175. wandb/proto/v3/__init__.py +0 -0
  176. wandb/proto/v3/wandb_base_pb2.py +54 -0
  177. wandb/proto/v3/wandb_internal_pb2.py +1586 -0
  178. wandb/proto/v3/wandb_server_pb2.py +207 -0
  179. wandb/proto/v3/wandb_settings_pb2.py +111 -0
  180. wandb/proto/v3/wandb_telemetry_pb2.py +105 -0
  181. wandb/proto/v4/__init__.py +0 -0
  182. wandb/proto/v4/wandb_base_pb2.py +29 -0
  183. wandb/proto/v4/wandb_internal_pb2.py +354 -0
  184. wandb/proto/v4/wandb_server_pb2.py +62 -0
  185. wandb/proto/v4/wandb_settings_pb2.py +44 -0
  186. wandb/proto/v4/wandb_telemetry_pb2.py +40 -0
  187. wandb/proto/wandb_base_pb2.py +8 -0
  188. wandb/proto/wandb_deprecated.py +37 -0
  189. wandb/proto/wandb_internal_codegen.py +83 -0
  190. wandb/proto/wandb_internal_pb2.py +8 -0
  191. wandb/proto/wandb_server_pb2.py +8 -0
  192. wandb/proto/wandb_settings_pb2.py +8 -0
  193. wandb/proto/wandb_telemetry_pb2.py +8 -0
  194. wandb/py.typed +0 -0
  195. wandb/sacred/__init__.py +3 -0
  196. wandb/sdk/__init__.py +37 -0
  197. wandb/sdk/artifacts/__init__.py +0 -0
  198. wandb/sdk/artifacts/artifact.py +2317 -0
  199. wandb/sdk/artifacts/artifact_download_logger.py +43 -0
  200. wandb/sdk/artifacts/artifact_file_cache.py +229 -0
  201. wandb/sdk/artifacts/artifact_instance_cache.py +15 -0
  202. wandb/sdk/artifacts/artifact_manifest.py +72 -0
  203. wandb/sdk/artifacts/artifact_manifest_entry.py +205 -0
  204. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  205. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +90 -0
  206. wandb/sdk/artifacts/artifact_saver.py +270 -0
  207. wandb/sdk/artifacts/artifact_state.py +11 -0
  208. wandb/sdk/artifacts/artifact_ttl.py +7 -0
  209. wandb/sdk/artifacts/exceptions.py +56 -0
  210. wandb/sdk/artifacts/staging.py +25 -0
  211. wandb/sdk/artifacts/storage_handler.py +60 -0
  212. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  213. wandb/sdk/artifacts/storage_handlers/azure_handler.py +194 -0
  214. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +195 -0
  215. wandb/sdk/artifacts/storage_handlers/http_handler.py +113 -0
  216. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +135 -0
  217. wandb/sdk/artifacts/storage_handlers/multi_handler.py +54 -0
  218. wandb/sdk/artifacts/storage_handlers/s3_handler.py +300 -0
  219. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +68 -0
  220. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +133 -0
  221. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  222. wandb/sdk/artifacts/storage_layout.py +6 -0
  223. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  224. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  225. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +411 -0
  226. wandb/sdk/artifacts/storage_policy.py +83 -0
  227. wandb/sdk/backend/__init__.py +0 -0
  228. wandb/sdk/backend/backend.py +240 -0
  229. wandb/sdk/data_types/__init__.py +0 -0
  230. wandb/sdk/data_types/_dtypes.py +911 -0
  231. wandb/sdk/data_types/_private.py +10 -0
  232. wandb/sdk/data_types/base_types/__init__.py +0 -0
  233. wandb/sdk/data_types/base_types/json_metadata.py +55 -0
  234. wandb/sdk/data_types/base_types/media.py +313 -0
  235. wandb/sdk/data_types/base_types/wb_value.py +274 -0
  236. wandb/sdk/data_types/helper_types/__init__.py +0 -0
  237. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +293 -0
  238. wandb/sdk/data_types/helper_types/classes.py +159 -0
  239. wandb/sdk/data_types/helper_types/image_mask.py +233 -0
  240. wandb/sdk/data_types/histogram.py +96 -0
  241. wandb/sdk/data_types/html.py +115 -0
  242. wandb/sdk/data_types/image.py +687 -0
  243. wandb/sdk/data_types/molecule.py +241 -0
  244. wandb/sdk/data_types/object_3d.py +363 -0
  245. wandb/sdk/data_types/plotly.py +82 -0
  246. wandb/sdk/data_types/saved_model.py +444 -0
  247. wandb/sdk/data_types/trace_tree.py +438 -0
  248. wandb/sdk/data_types/utils.py +180 -0
  249. wandb/sdk/data_types/video.py +245 -0
  250. wandb/sdk/integration_utils/__init__.py +0 -0
  251. wandb/sdk/integration_utils/auto_logging.py +239 -0
  252. wandb/sdk/integration_utils/data_logging.py +475 -0
  253. wandb/sdk/interface/__init__.py +0 -0
  254. wandb/sdk/interface/constants.py +4 -0
  255. wandb/sdk/interface/interface.py +949 -0
  256. wandb/sdk/interface/interface_queue.py +59 -0
  257. wandb/sdk/interface/interface_relay.py +53 -0
  258. wandb/sdk/interface/interface_shared.py +550 -0
  259. wandb/sdk/interface/interface_sock.py +61 -0
  260. wandb/sdk/interface/message_future.py +27 -0
  261. wandb/sdk/interface/message_future_poll.py +50 -0
  262. wandb/sdk/interface/router.py +118 -0
  263. wandb/sdk/interface/router_queue.py +44 -0
  264. wandb/sdk/interface/router_relay.py +39 -0
  265. wandb/sdk/interface/router_sock.py +36 -0
  266. wandb/sdk/interface/summary_record.py +67 -0
  267. wandb/sdk/internal/__init__.py +0 -0
  268. wandb/sdk/internal/context.py +89 -0
  269. wandb/sdk/internal/datastore.py +297 -0
  270. wandb/sdk/internal/file_pusher.py +184 -0
  271. wandb/sdk/internal/file_stream.py +708 -0
  272. wandb/sdk/internal/flow_control.py +263 -0
  273. wandb/sdk/internal/handler.py +909 -0
  274. wandb/sdk/internal/internal.py +417 -0
  275. wandb/sdk/internal/internal_api.py +4202 -0
  276. wandb/sdk/internal/internal_util.py +100 -0
  277. wandb/sdk/internal/job_builder.py +554 -0
  278. wandb/sdk/internal/profiler.py +78 -0
  279. wandb/sdk/internal/progress.py +111 -0
  280. wandb/sdk/internal/run.py +25 -0
  281. wandb/sdk/internal/sample.py +70 -0
  282. wandb/sdk/internal/sender.py +1642 -0
  283. wandb/sdk/internal/sender_config.py +197 -0
  284. wandb/sdk/internal/settings_static.py +83 -0
  285. wandb/sdk/internal/system/__init__.py +0 -0
  286. wandb/sdk/internal/system/assets/__init__.py +27 -0
  287. wandb/sdk/internal/system/assets/aggregators.py +37 -0
  288. wandb/sdk/internal/system/assets/asset_registry.py +20 -0
  289. wandb/sdk/internal/system/assets/cpu.py +163 -0
  290. wandb/sdk/internal/system/assets/disk.py +210 -0
  291. wandb/sdk/internal/system/assets/gpu.py +414 -0
  292. wandb/sdk/internal/system/assets/gpu_amd.py +230 -0
  293. wandb/sdk/internal/system/assets/gpu_apple.py +177 -0
  294. wandb/sdk/internal/system/assets/interfaces.py +207 -0
  295. wandb/sdk/internal/system/assets/ipu.py +177 -0
  296. wandb/sdk/internal/system/assets/memory.py +166 -0
  297. wandb/sdk/internal/system/assets/network.py +125 -0
  298. wandb/sdk/internal/system/assets/open_metrics.py +299 -0
  299. wandb/sdk/internal/system/assets/tpu.py +154 -0
  300. wandb/sdk/internal/system/assets/trainium.py +398 -0
  301. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  302. wandb/sdk/internal/system/system_info.py +247 -0
  303. wandb/sdk/internal/system/system_monitor.py +229 -0
  304. wandb/sdk/internal/tb_watcher.py +518 -0
  305. wandb/sdk/internal/thread_local_settings.py +18 -0
  306. wandb/sdk/internal/update.py +113 -0
  307. wandb/sdk/internal/writer.py +206 -0
  308. wandb/sdk/launch/__init__.py +6 -0
  309. wandb/sdk/launch/_launch.py +348 -0
  310. wandb/sdk/launch/_launch_add.py +257 -0
  311. wandb/sdk/launch/_project_spec.py +578 -0
  312. wandb/sdk/launch/agent/__init__.py +5 -0
  313. wandb/sdk/launch/agent/agent.py +878 -0
  314. wandb/sdk/launch/agent/config.py +299 -0
  315. wandb/sdk/launch/agent/job_status_tracker.py +53 -0
  316. wandb/sdk/launch/agent/run_queue_item_file_saver.py +47 -0
  317. wandb/sdk/launch/builder/__init__.py +0 -0
  318. wandb/sdk/launch/builder/abstract.py +89 -0
  319. wandb/sdk/launch/builder/build.py +706 -0
  320. wandb/sdk/launch/builder/docker_builder.py +193 -0
  321. wandb/sdk/launch/builder/kaniko_builder.py +579 -0
  322. wandb/sdk/launch/builder/noop.py +58 -0
  323. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +187 -0
  324. wandb/sdk/launch/create_job.py +520 -0
  325. wandb/sdk/launch/environment/abstract.py +29 -0
  326. wandb/sdk/launch/environment/aws_environment.py +297 -0
  327. wandb/sdk/launch/environment/azure_environment.py +105 -0
  328. wandb/sdk/launch/environment/gcp_environment.py +335 -0
  329. wandb/sdk/launch/environment/local_environment.py +66 -0
  330. wandb/sdk/launch/errors.py +19 -0
  331. wandb/sdk/launch/git_reference.py +109 -0
  332. wandb/sdk/launch/loader.py +249 -0
  333. wandb/sdk/launch/registry/abstract.py +48 -0
  334. wandb/sdk/launch/registry/anon.py +29 -0
  335. wandb/sdk/launch/registry/azure_container_registry.py +124 -0
  336. wandb/sdk/launch/registry/elastic_container_registry.py +192 -0
  337. wandb/sdk/launch/registry/google_artifact_registry.py +219 -0
  338. wandb/sdk/launch/registry/local_registry.py +67 -0
  339. wandb/sdk/launch/runner/__init__.py +0 -0
  340. wandb/sdk/launch/runner/abstract.py +195 -0
  341. wandb/sdk/launch/runner/kubernetes_monitor.py +441 -0
  342. wandb/sdk/launch/runner/kubernetes_runner.py +893 -0
  343. wandb/sdk/launch/runner/local_container.py +299 -0
  344. wandb/sdk/launch/runner/local_process.py +99 -0
  345. wandb/sdk/launch/runner/sagemaker_runner.py +422 -0
  346. wandb/sdk/launch/runner/vertex_runner.py +229 -0
  347. wandb/sdk/launch/sweeps/__init__.py +39 -0
  348. wandb/sdk/launch/sweeps/scheduler.py +738 -0
  349. wandb/sdk/launch/sweeps/scheduler_sweep.py +91 -0
  350. wandb/sdk/launch/sweeps/utils.py +316 -0
  351. wandb/sdk/launch/utils.py +866 -0
  352. wandb/sdk/launch/wandb_reference.py +138 -0
  353. wandb/sdk/lib/__init__.py +5 -0
  354. wandb/sdk/lib/_settings_toposort_generate.py +159 -0
  355. wandb/sdk/lib/_settings_toposort_generated.py +244 -0
  356. wandb/sdk/lib/_wburls_generate.py +25 -0
  357. wandb/sdk/lib/_wburls_generated.py +22 -0
  358. wandb/sdk/lib/apikey.py +258 -0
  359. wandb/sdk/lib/capped_dict.py +26 -0
  360. wandb/sdk/lib/config_util.py +101 -0
  361. wandb/sdk/lib/console.py +39 -0
  362. wandb/sdk/lib/deprecate.py +42 -0
  363. wandb/sdk/lib/disabled.py +190 -0
  364. wandb/sdk/lib/exit_hooks.py +54 -0
  365. wandb/sdk/lib/file_stream_utils.py +118 -0
  366. wandb/sdk/lib/filenames.py +64 -0
  367. wandb/sdk/lib/filesystem.py +372 -0
  368. wandb/sdk/lib/fsm.py +174 -0
  369. wandb/sdk/lib/gitlib.py +239 -0
  370. wandb/sdk/lib/gql_request.py +65 -0
  371. wandb/sdk/lib/handler_util.py +21 -0
  372. wandb/sdk/lib/hashutil.py +62 -0
  373. wandb/sdk/lib/import_hooks.py +275 -0
  374. wandb/sdk/lib/ipython.py +146 -0
  375. wandb/sdk/lib/json_util.py +80 -0
  376. wandb/sdk/lib/lazyloader.py +63 -0
  377. wandb/sdk/lib/mailbox.py +460 -0
  378. wandb/sdk/lib/module.py +69 -0
  379. wandb/sdk/lib/paths.py +106 -0
  380. wandb/sdk/lib/preinit.py +42 -0
  381. wandb/sdk/lib/printer.py +313 -0
  382. wandb/sdk/lib/proto_util.py +69 -0
  383. wandb/sdk/lib/redirect.py +840 -0
  384. wandb/sdk/lib/reporting.py +99 -0
  385. wandb/sdk/lib/retry.py +289 -0
  386. wandb/sdk/lib/run_moment.py +78 -0
  387. wandb/sdk/lib/runid.py +12 -0
  388. wandb/sdk/lib/server.py +52 -0
  389. wandb/sdk/lib/sock_client.py +291 -0
  390. wandb/sdk/lib/sparkline.py +45 -0
  391. wandb/sdk/lib/telemetry.py +100 -0
  392. wandb/sdk/lib/timed_input.py +133 -0
  393. wandb/sdk/lib/timer.py +19 -0
  394. wandb/sdk/lib/tracelog.py +255 -0
  395. wandb/sdk/lib/wburls.py +46 -0
  396. wandb/sdk/service/__init__.py +0 -0
  397. wandb/sdk/service/_startup_debug.py +22 -0
  398. wandb/sdk/service/port_file.py +53 -0
  399. wandb/sdk/service/server.py +119 -0
  400. wandb/sdk/service/server_sock.py +276 -0
  401. wandb/sdk/service/service.py +266 -0
  402. wandb/sdk/service/service_base.py +50 -0
  403. wandb/sdk/service/service_sock.py +70 -0
  404. wandb/sdk/service/streams.py +426 -0
  405. wandb/sdk/verify/__init__.py +0 -0
  406. wandb/sdk/verify/verify.py +501 -0
  407. wandb/sdk/wandb_alerts.py +12 -0
  408. wandb/sdk/wandb_config.py +319 -0
  409. wandb/sdk/wandb_helper.py +54 -0
  410. wandb/sdk/wandb_init.py +1220 -0
  411. wandb/sdk/wandb_login.py +339 -0
  412. wandb/sdk/wandb_manager.py +222 -0
  413. wandb/sdk/wandb_metric.py +110 -0
  414. wandb/sdk/wandb_require.py +92 -0
  415. wandb/sdk/wandb_require_helpers.py +44 -0
  416. wandb/sdk/wandb_run.py +4191 -0
  417. wandb/sdk/wandb_settings.py +1973 -0
  418. wandb/sdk/wandb_setup.py +335 -0
  419. wandb/sdk/wandb_summary.py +150 -0
  420. wandb/sdk/wandb_sweep.py +114 -0
  421. wandb/sdk/wandb_sync.py +75 -0
  422. wandb/sdk/wandb_watch.py +128 -0
  423. wandb/sklearn/__init__.py +37 -0
  424. wandb/sklearn/calculate/__init__.py +32 -0
  425. wandb/sklearn/calculate/calibration_curves.py +125 -0
  426. wandb/sklearn/calculate/class_proportions.py +68 -0
  427. wandb/sklearn/calculate/confusion_matrix.py +92 -0
  428. wandb/sklearn/calculate/decision_boundaries.py +40 -0
  429. wandb/sklearn/calculate/elbow_curve.py +55 -0
  430. wandb/sklearn/calculate/feature_importances.py +67 -0
  431. wandb/sklearn/calculate/learning_curve.py +64 -0
  432. wandb/sklearn/calculate/outlier_candidates.py +69 -0
  433. wandb/sklearn/calculate/residuals.py +86 -0
  434. wandb/sklearn/calculate/silhouette.py +118 -0
  435. wandb/sklearn/calculate/summary_metrics.py +62 -0
  436. wandb/sklearn/plot/__init__.py +35 -0
  437. wandb/sklearn/plot/classifier.py +331 -0
  438. wandb/sklearn/plot/clusterer.py +142 -0
  439. wandb/sklearn/plot/regressor.py +121 -0
  440. wandb/sklearn/plot/shared.py +91 -0
  441. wandb/sklearn/utils.py +183 -0
  442. wandb/sync/__init__.py +3 -0
  443. wandb/sync/sync.py +443 -0
  444. wandb/testing/relay.py +859 -0
  445. wandb/trigger.py +29 -0
  446. wandb/util.py +1925 -0
  447. wandb/vendor/__init__.py +0 -0
  448. wandb/vendor/gql-0.2.0/setup.py +40 -0
  449. wandb/vendor/gql-0.2.0/tests/__init__.py +0 -0
  450. wandb/vendor/gql-0.2.0/tests/starwars/__init__.py +0 -0
  451. wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py +96 -0
  452. wandb/vendor/gql-0.2.0/tests/starwars/schema.py +146 -0
  453. wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py +293 -0
  454. wandb/vendor/gql-0.2.0/tests/starwars/test_query.py +355 -0
  455. wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py +171 -0
  456. wandb/vendor/gql-0.2.0/tests/test_client.py +31 -0
  457. wandb/vendor/gql-0.2.0/tests/test_transport.py +89 -0
  458. wandb/vendor/gql-0.2.0/wandb_gql/__init__.py +4 -0
  459. wandb/vendor/gql-0.2.0/wandb_gql/client.py +75 -0
  460. wandb/vendor/gql-0.2.0/wandb_gql/dsl.py +152 -0
  461. wandb/vendor/gql-0.2.0/wandb_gql/gql.py +10 -0
  462. wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py +0 -0
  463. wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py +6 -0
  464. wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py +15 -0
  465. wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py +46 -0
  466. wandb/vendor/gql-0.2.0/wandb_gql/utils.py +21 -0
  467. wandb/vendor/graphql-core-1.1/setup.py +86 -0
  468. wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py +287 -0
  469. wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py +6 -0
  470. wandb/vendor/graphql-core-1.1/wandb_graphql/error/base.py +42 -0
  471. wandb/vendor/graphql-core-1.1/wandb_graphql/error/format_error.py +11 -0
  472. wandb/vendor/graphql-core-1.1/wandb_graphql/error/located_error.py +29 -0
  473. wandb/vendor/graphql-core-1.1/wandb_graphql/error/syntax_error.py +36 -0
  474. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__init__.py +26 -0
  475. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py +311 -0
  476. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executor.py +398 -0
  477. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py +0 -0
  478. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py +53 -0
  479. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py +22 -0
  480. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py +32 -0
  481. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py +7 -0
  482. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py +35 -0
  483. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py +6 -0
  484. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py +0 -0
  485. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py +66 -0
  486. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py +252 -0
  487. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py +151 -0
  488. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py +7 -0
  489. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py +57 -0
  490. wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py +145 -0
  491. wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py +60 -0
  492. wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py +0 -0
  493. wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py +1349 -0
  494. wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py +19 -0
  495. wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py +435 -0
  496. wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py +30 -0
  497. wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py +779 -0
  498. wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py +193 -0
  499. wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py +18 -0
  500. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py +222 -0
  501. wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py +82 -0
  502. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__init__.py +0 -0
  503. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py +17 -0
  504. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/contain_subset.py +28 -0
  505. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/default_ordered_dict.py +40 -0
  506. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/ordereddict.py +8 -0
  507. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/pair_set.py +43 -0
  508. wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/version.py +78 -0
  509. wandb/vendor/graphql-core-1.1/wandb_graphql/type/__init__.py +67 -0
  510. wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py +619 -0
  511. wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py +132 -0
  512. wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py +440 -0
  513. wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py +131 -0
  514. wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py +100 -0
  515. wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py +145 -0
  516. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py +0 -0
  517. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py +9 -0
  518. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py +65 -0
  519. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py +49 -0
  520. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py +24 -0
  521. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py +75 -0
  522. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py +291 -0
  523. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_client_schema.py +250 -0
  524. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/concat_ast.py +9 -0
  525. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/extend_schema.py +357 -0
  526. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_field_def.py +27 -0
  527. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/get_operation_ast.py +21 -0
  528. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/introspection_query.py +90 -0
  529. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_literal_value.py +67 -0
  530. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/is_valid_value.py +66 -0
  531. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/quoted_or_list.py +21 -0
  532. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/schema_printer.py +168 -0
  533. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/suggestion_list.py +56 -0
  534. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_comparators.py +69 -0
  535. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_from_ast.py +21 -0
  536. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/type_info.py +149 -0
  537. wandb/vendor/graphql-core-1.1/wandb_graphql/utils/value_from_ast.py +69 -0
  538. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py +4 -0
  539. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py +79 -0
  540. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py +24 -0
  541. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py +8 -0
  542. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py +44 -0
  543. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py +113 -0
  544. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py +33 -0
  545. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py +70 -0
  546. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py +97 -0
  547. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py +19 -0
  548. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py +43 -0
  549. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py +23 -0
  550. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py +59 -0
  551. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py +36 -0
  552. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py +38 -0
  553. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py +37 -0
  554. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py +529 -0
  555. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py +44 -0
  556. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py +46 -0
  557. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py +33 -0
  558. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py +32 -0
  559. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py +28 -0
  560. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py +33 -0
  561. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py +31 -0
  562. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py +27 -0
  563. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py +21 -0
  564. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py +53 -0
  565. wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py +158 -0
  566. wandb/vendor/promise-2.3.0/conftest.py +30 -0
  567. wandb/vendor/promise-2.3.0/setup.py +64 -0
  568. wandb/vendor/promise-2.3.0/tests/__init__.py +0 -0
  569. wandb/vendor/promise-2.3.0/tests/conftest.py +8 -0
  570. wandb/vendor/promise-2.3.0/tests/test_awaitable.py +32 -0
  571. wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py +47 -0
  572. wandb/vendor/promise-2.3.0/tests/test_benchmark.py +116 -0
  573. wandb/vendor/promise-2.3.0/tests/test_complex_threads.py +23 -0
  574. wandb/vendor/promise-2.3.0/tests/test_dataloader.py +452 -0
  575. wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py +99 -0
  576. wandb/vendor/promise-2.3.0/tests/test_dataloader_extra.py +65 -0
  577. wandb/vendor/promise-2.3.0/tests/test_extra.py +670 -0
  578. wandb/vendor/promise-2.3.0/tests/test_issues.py +132 -0
  579. wandb/vendor/promise-2.3.0/tests/test_promise_list.py +70 -0
  580. wandb/vendor/promise-2.3.0/tests/test_spec.py +584 -0
  581. wandb/vendor/promise-2.3.0/tests/test_thread_safety.py +115 -0
  582. wandb/vendor/promise-2.3.0/tests/utils.py +3 -0
  583. wandb/vendor/promise-2.3.0/wandb_promise/__init__.py +38 -0
  584. wandb/vendor/promise-2.3.0/wandb_promise/async_.py +135 -0
  585. wandb/vendor/promise-2.3.0/wandb_promise/compat.py +32 -0
  586. wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py +326 -0
  587. wandb/vendor/promise-2.3.0/wandb_promise/iterate_promise.py +12 -0
  588. wandb/vendor/promise-2.3.0/wandb_promise/promise.py +848 -0
  589. wandb/vendor/promise-2.3.0/wandb_promise/promise_list.py +151 -0
  590. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/__init__.py +0 -0
  591. wandb/vendor/promise-2.3.0/wandb_promise/pyutils/version.py +83 -0
  592. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/__init__.py +0 -0
  593. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/asyncio.py +22 -0
  594. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/gevent.py +21 -0
  595. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/immediate.py +27 -0
  596. wandb/vendor/promise-2.3.0/wandb_promise/schedulers/thread.py +18 -0
  597. wandb/vendor/promise-2.3.0/wandb_promise/utils.py +56 -0
  598. wandb/vendor/pygments/__init__.py +90 -0
  599. wandb/vendor/pygments/cmdline.py +568 -0
  600. wandb/vendor/pygments/console.py +74 -0
  601. wandb/vendor/pygments/filter.py +74 -0
  602. wandb/vendor/pygments/filters/__init__.py +350 -0
  603. wandb/vendor/pygments/formatter.py +95 -0
  604. wandb/vendor/pygments/formatters/__init__.py +153 -0
  605. wandb/vendor/pygments/formatters/_mapping.py +85 -0
  606. wandb/vendor/pygments/formatters/bbcode.py +109 -0
  607. wandb/vendor/pygments/formatters/html.py +851 -0
  608. wandb/vendor/pygments/formatters/img.py +600 -0
  609. wandb/vendor/pygments/formatters/irc.py +182 -0
  610. wandb/vendor/pygments/formatters/latex.py +482 -0
  611. wandb/vendor/pygments/formatters/other.py +160 -0
  612. wandb/vendor/pygments/formatters/rtf.py +147 -0
  613. wandb/vendor/pygments/formatters/svg.py +153 -0
  614. wandb/vendor/pygments/formatters/terminal.py +136 -0
  615. wandb/vendor/pygments/formatters/terminal256.py +309 -0
  616. wandb/vendor/pygments/lexer.py +871 -0
  617. wandb/vendor/pygments/lexers/__init__.py +329 -0
  618. wandb/vendor/pygments/lexers/_asy_builtins.py +1645 -0
  619. wandb/vendor/pygments/lexers/_cl_builtins.py +232 -0
  620. wandb/vendor/pygments/lexers/_cocoa_builtins.py +72 -0
  621. wandb/vendor/pygments/lexers/_csound_builtins.py +1346 -0
  622. wandb/vendor/pygments/lexers/_lasso_builtins.py +5327 -0
  623. wandb/vendor/pygments/lexers/_lua_builtins.py +295 -0
  624. wandb/vendor/pygments/lexers/_mapping.py +500 -0
  625. wandb/vendor/pygments/lexers/_mql_builtins.py +1172 -0
  626. wandb/vendor/pygments/lexers/_openedge_builtins.py +2547 -0
  627. wandb/vendor/pygments/lexers/_php_builtins.py +4756 -0
  628. wandb/vendor/pygments/lexers/_postgres_builtins.py +621 -0
  629. wandb/vendor/pygments/lexers/_scilab_builtins.py +3094 -0
  630. wandb/vendor/pygments/lexers/_sourcemod_builtins.py +1163 -0
  631. wandb/vendor/pygments/lexers/_stan_builtins.py +532 -0
  632. wandb/vendor/pygments/lexers/_stata_builtins.py +419 -0
  633. wandb/vendor/pygments/lexers/_tsql_builtins.py +1004 -0
  634. wandb/vendor/pygments/lexers/_vim_builtins.py +1939 -0
  635. wandb/vendor/pygments/lexers/actionscript.py +240 -0
  636. wandb/vendor/pygments/lexers/agile.py +24 -0
  637. wandb/vendor/pygments/lexers/algebra.py +221 -0
  638. wandb/vendor/pygments/lexers/ambient.py +76 -0
  639. wandb/vendor/pygments/lexers/ampl.py +87 -0
  640. wandb/vendor/pygments/lexers/apl.py +101 -0
  641. wandb/vendor/pygments/lexers/archetype.py +318 -0
  642. wandb/vendor/pygments/lexers/asm.py +641 -0
  643. wandb/vendor/pygments/lexers/automation.py +374 -0
  644. wandb/vendor/pygments/lexers/basic.py +500 -0
  645. wandb/vendor/pygments/lexers/bibtex.py +160 -0
  646. wandb/vendor/pygments/lexers/business.py +612 -0
  647. wandb/vendor/pygments/lexers/c_cpp.py +252 -0
  648. wandb/vendor/pygments/lexers/c_like.py +541 -0
  649. wandb/vendor/pygments/lexers/capnproto.py +78 -0
  650. wandb/vendor/pygments/lexers/chapel.py +102 -0
  651. wandb/vendor/pygments/lexers/clean.py +288 -0
  652. wandb/vendor/pygments/lexers/compiled.py +34 -0
  653. wandb/vendor/pygments/lexers/configs.py +833 -0
  654. wandb/vendor/pygments/lexers/console.py +114 -0
  655. wandb/vendor/pygments/lexers/crystal.py +393 -0
  656. wandb/vendor/pygments/lexers/csound.py +366 -0
  657. wandb/vendor/pygments/lexers/css.py +689 -0
  658. wandb/vendor/pygments/lexers/d.py +251 -0
  659. wandb/vendor/pygments/lexers/dalvik.py +125 -0
  660. wandb/vendor/pygments/lexers/data.py +555 -0
  661. wandb/vendor/pygments/lexers/diff.py +165 -0
  662. wandb/vendor/pygments/lexers/dotnet.py +691 -0
  663. wandb/vendor/pygments/lexers/dsls.py +878 -0
  664. wandb/vendor/pygments/lexers/dylan.py +289 -0
  665. wandb/vendor/pygments/lexers/ecl.py +125 -0
  666. wandb/vendor/pygments/lexers/eiffel.py +65 -0
  667. wandb/vendor/pygments/lexers/elm.py +121 -0
  668. wandb/vendor/pygments/lexers/erlang.py +533 -0
  669. wandb/vendor/pygments/lexers/esoteric.py +277 -0
  670. wandb/vendor/pygments/lexers/ezhil.py +69 -0
  671. wandb/vendor/pygments/lexers/factor.py +344 -0
  672. wandb/vendor/pygments/lexers/fantom.py +250 -0
  673. wandb/vendor/pygments/lexers/felix.py +273 -0
  674. wandb/vendor/pygments/lexers/forth.py +177 -0
  675. wandb/vendor/pygments/lexers/fortran.py +205 -0
  676. wandb/vendor/pygments/lexers/foxpro.py +428 -0
  677. wandb/vendor/pygments/lexers/functional.py +21 -0
  678. wandb/vendor/pygments/lexers/go.py +101 -0
  679. wandb/vendor/pygments/lexers/grammar_notation.py +213 -0
  680. wandb/vendor/pygments/lexers/graph.py +80 -0
  681. wandb/vendor/pygments/lexers/graphics.py +553 -0
  682. wandb/vendor/pygments/lexers/haskell.py +843 -0
  683. wandb/vendor/pygments/lexers/haxe.py +936 -0
  684. wandb/vendor/pygments/lexers/hdl.py +382 -0
  685. wandb/vendor/pygments/lexers/hexdump.py +103 -0
  686. wandb/vendor/pygments/lexers/html.py +602 -0
  687. wandb/vendor/pygments/lexers/idl.py +270 -0
  688. wandb/vendor/pygments/lexers/igor.py +288 -0
  689. wandb/vendor/pygments/lexers/inferno.py +96 -0
  690. wandb/vendor/pygments/lexers/installers.py +322 -0
  691. wandb/vendor/pygments/lexers/int_fiction.py +1343 -0
  692. wandb/vendor/pygments/lexers/iolang.py +63 -0
  693. wandb/vendor/pygments/lexers/j.py +146 -0
  694. wandb/vendor/pygments/lexers/javascript.py +1525 -0
  695. wandb/vendor/pygments/lexers/julia.py +333 -0
  696. wandb/vendor/pygments/lexers/jvm.py +1573 -0
  697. wandb/vendor/pygments/lexers/lisp.py +2621 -0
  698. wandb/vendor/pygments/lexers/make.py +202 -0
  699. wandb/vendor/pygments/lexers/markup.py +595 -0
  700. wandb/vendor/pygments/lexers/math.py +21 -0
  701. wandb/vendor/pygments/lexers/matlab.py +663 -0
  702. wandb/vendor/pygments/lexers/ml.py +769 -0
  703. wandb/vendor/pygments/lexers/modeling.py +358 -0
  704. wandb/vendor/pygments/lexers/modula2.py +1561 -0
  705. wandb/vendor/pygments/lexers/monte.py +204 -0
  706. wandb/vendor/pygments/lexers/ncl.py +894 -0
  707. wandb/vendor/pygments/lexers/nimrod.py +159 -0
  708. wandb/vendor/pygments/lexers/nit.py +64 -0
  709. wandb/vendor/pygments/lexers/nix.py +136 -0
  710. wandb/vendor/pygments/lexers/oberon.py +105 -0
  711. wandb/vendor/pygments/lexers/objective.py +504 -0
  712. wandb/vendor/pygments/lexers/ooc.py +85 -0
  713. wandb/vendor/pygments/lexers/other.py +41 -0
  714. wandb/vendor/pygments/lexers/parasail.py +79 -0
  715. wandb/vendor/pygments/lexers/parsers.py +835 -0
  716. wandb/vendor/pygments/lexers/pascal.py +644 -0
  717. wandb/vendor/pygments/lexers/pawn.py +199 -0
  718. wandb/vendor/pygments/lexers/perl.py +620 -0
  719. wandb/vendor/pygments/lexers/php.py +267 -0
  720. wandb/vendor/pygments/lexers/praat.py +294 -0
  721. wandb/vendor/pygments/lexers/prolog.py +306 -0
  722. wandb/vendor/pygments/lexers/python.py +939 -0
  723. wandb/vendor/pygments/lexers/qvt.py +152 -0
  724. wandb/vendor/pygments/lexers/r.py +453 -0
  725. wandb/vendor/pygments/lexers/rdf.py +270 -0
  726. wandb/vendor/pygments/lexers/rebol.py +431 -0
  727. wandb/vendor/pygments/lexers/resource.py +85 -0
  728. wandb/vendor/pygments/lexers/rnc.py +67 -0
  729. wandb/vendor/pygments/lexers/roboconf.py +82 -0
  730. wandb/vendor/pygments/lexers/robotframework.py +560 -0
  731. wandb/vendor/pygments/lexers/ruby.py +519 -0
  732. wandb/vendor/pygments/lexers/rust.py +220 -0
  733. wandb/vendor/pygments/lexers/sas.py +228 -0
  734. wandb/vendor/pygments/lexers/scripting.py +1222 -0
  735. wandb/vendor/pygments/lexers/shell.py +794 -0
  736. wandb/vendor/pygments/lexers/smalltalk.py +195 -0
  737. wandb/vendor/pygments/lexers/smv.py +79 -0
  738. wandb/vendor/pygments/lexers/snobol.py +83 -0
  739. wandb/vendor/pygments/lexers/special.py +103 -0
  740. wandb/vendor/pygments/lexers/sql.py +681 -0
  741. wandb/vendor/pygments/lexers/stata.py +108 -0
  742. wandb/vendor/pygments/lexers/supercollider.py +90 -0
  743. wandb/vendor/pygments/lexers/tcl.py +145 -0
  744. wandb/vendor/pygments/lexers/templates.py +2283 -0
  745. wandb/vendor/pygments/lexers/testing.py +207 -0
  746. wandb/vendor/pygments/lexers/text.py +25 -0
  747. wandb/vendor/pygments/lexers/textedit.py +169 -0
  748. wandb/vendor/pygments/lexers/textfmts.py +297 -0
  749. wandb/vendor/pygments/lexers/theorem.py +458 -0
  750. wandb/vendor/pygments/lexers/trafficscript.py +54 -0
  751. wandb/vendor/pygments/lexers/typoscript.py +226 -0
  752. wandb/vendor/pygments/lexers/urbi.py +133 -0
  753. wandb/vendor/pygments/lexers/varnish.py +190 -0
  754. wandb/vendor/pygments/lexers/verification.py +111 -0
  755. wandb/vendor/pygments/lexers/web.py +24 -0
  756. wandb/vendor/pygments/lexers/webmisc.py +988 -0
  757. wandb/vendor/pygments/lexers/whiley.py +116 -0
  758. wandb/vendor/pygments/lexers/x10.py +69 -0
  759. wandb/vendor/pygments/modeline.py +44 -0
  760. wandb/vendor/pygments/plugin.py +68 -0
  761. wandb/vendor/pygments/regexopt.py +92 -0
  762. wandb/vendor/pygments/scanner.py +105 -0
  763. wandb/vendor/pygments/sphinxext.py +158 -0
  764. wandb/vendor/pygments/style.py +155 -0
  765. wandb/vendor/pygments/styles/__init__.py +80 -0
  766. wandb/vendor/pygments/styles/abap.py +29 -0
  767. wandb/vendor/pygments/styles/algol.py +63 -0
  768. wandb/vendor/pygments/styles/algol_nu.py +63 -0
  769. wandb/vendor/pygments/styles/arduino.py +98 -0
  770. wandb/vendor/pygments/styles/autumn.py +65 -0
  771. wandb/vendor/pygments/styles/borland.py +51 -0
  772. wandb/vendor/pygments/styles/bw.py +49 -0
  773. wandb/vendor/pygments/styles/colorful.py +81 -0
  774. wandb/vendor/pygments/styles/default.py +73 -0
  775. wandb/vendor/pygments/styles/emacs.py +72 -0
  776. wandb/vendor/pygments/styles/friendly.py +72 -0
  777. wandb/vendor/pygments/styles/fruity.py +42 -0
  778. wandb/vendor/pygments/styles/igor.py +29 -0
  779. wandb/vendor/pygments/styles/lovelace.py +97 -0
  780. wandb/vendor/pygments/styles/manni.py +75 -0
  781. wandb/vendor/pygments/styles/monokai.py +106 -0
  782. wandb/vendor/pygments/styles/murphy.py +80 -0
  783. wandb/vendor/pygments/styles/native.py +65 -0
  784. wandb/vendor/pygments/styles/paraiso_dark.py +125 -0
  785. wandb/vendor/pygments/styles/paraiso_light.py +125 -0
  786. wandb/vendor/pygments/styles/pastie.py +75 -0
  787. wandb/vendor/pygments/styles/perldoc.py +69 -0
  788. wandb/vendor/pygments/styles/rainbow_dash.py +89 -0
  789. wandb/vendor/pygments/styles/rrt.py +33 -0
  790. wandb/vendor/pygments/styles/sas.py +44 -0
  791. wandb/vendor/pygments/styles/stata.py +40 -0
  792. wandb/vendor/pygments/styles/tango.py +141 -0
  793. wandb/vendor/pygments/styles/trac.py +63 -0
  794. wandb/vendor/pygments/styles/vim.py +63 -0
  795. wandb/vendor/pygments/styles/vs.py +38 -0
  796. wandb/vendor/pygments/styles/xcode.py +51 -0
  797. wandb/vendor/pygments/token.py +213 -0
  798. wandb/vendor/pygments/unistring.py +217 -0
  799. wandb/vendor/pygments/util.py +388 -0
  800. wandb/vendor/pynvml/__init__.py +0 -0
  801. wandb/vendor/pynvml/pynvml.py +4779 -0
  802. wandb/vendor/watchdog_0_9_0/wandb_watchdog/__init__.py +17 -0
  803. wandb/vendor/watchdog_0_9_0/wandb_watchdog/events.py +615 -0
  804. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/__init__.py +98 -0
  805. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/api.py +369 -0
  806. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents.py +172 -0
  807. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/fsevents2.py +239 -0
  808. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify.py +218 -0
  809. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_buffer.py +81 -0
  810. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/inotify_c.py +575 -0
  811. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/kqueue.py +730 -0
  812. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/polling.py +145 -0
  813. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/read_directory_changes.py +133 -0
  814. wandb/vendor/watchdog_0_9_0/wandb_watchdog/observers/winapi.py +348 -0
  815. wandb/vendor/watchdog_0_9_0/wandb_watchdog/patterns.py +265 -0
  816. wandb/vendor/watchdog_0_9_0/wandb_watchdog/tricks/__init__.py +174 -0
  817. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/__init__.py +151 -0
  818. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/bricks.py +249 -0
  819. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/compat.py +29 -0
  820. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/decorators.py +198 -0
  821. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/delayed_queue.py +88 -0
  822. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/dirsnapshot.py +293 -0
  823. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/echo.py +157 -0
  824. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/event_backport.py +41 -0
  825. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/importlib2.py +40 -0
  826. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/platform.py +57 -0
  827. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/unicode_paths.py +64 -0
  828. wandb/vendor/watchdog_0_9_0/wandb_watchdog/utils/win32stat.py +123 -0
  829. wandb/vendor/watchdog_0_9_0/wandb_watchdog/version.py +28 -0
  830. wandb/vendor/watchdog_0_9_0/wandb_watchdog/watchmedo.py +577 -0
  831. wandb/viz.py +123 -0
  832. wandb/wandb_agent.py +586 -0
  833. wandb/wandb_controller.py +720 -0
  834. wandb/wandb_run.py +9 -0
  835. wandb/wandb_torch.py +550 -0
  836. wandb/xgboost/__init__.py +9 -0
  837. wandb-0.17.0rc1.dist-info/METADATA +219 -0
  838. wandb-0.17.0rc1.dist-info/RECORD +841 -0
  839. wandb-0.17.0rc1.dist-info/WHEEL +6 -0
  840. wandb-0.17.0rc1.dist-info/entry_points.txt +3 -0
  841. wandb-0.17.0rc1.dist-info/licenses/LICENSE +21 -0
--- /dev/null
+++ wandb/sdk/internal/internal_api.py
@@ -0,0 +1,4202 @@
+import ast
+import asyncio
+import base64
+import datetime
+import functools
+import http.client
+import json
+import logging
+import os
+import re
+import socket
+import sys
+import threading
+from copy import deepcopy
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    MutableMapping,
+    Optional,
+    Sequence,
+    TextIO,
+    Tuple,
+    Union,
+)
+
+import click
+import requests
+import yaml
+from wandb_gql import Client, gql
+from wandb_gql.client import RetryError
+
+import wandb
+from wandb import env, util
+from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
+from wandb.errors import CommError, UnsupportedError, UsageError
+from wandb.integration.sagemaker import parse_sm_secrets
+from wandb.old.settings import Settings
+from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
+from wandb.sdk.lib.gql_request import GraphQLSession
+from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
+
+from ..lib import retry
+from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
+from ..lib.gitlib import GitRepo
+from . import context
+from .progress import AsyncProgress, Progress
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    if sys.version_info >= (3, 8):
+        from typing import Literal, TypedDict
+    else:
+        from typing_extensions import Literal, TypedDict
+
+    from .progress import ProgressFn
+
+    class CreateArtifactFileSpecInput(TypedDict, total=False):
+        """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""
+
+        artifactID: str  # noqa: N815
+        name: str
+        md5: str
+        mimetype: Optional[str]
+        artifactManifestID: Optional[str]  # noqa: N815
+        uploadPartsInput: Optional[List[Dict[str, object]]]  # noqa: N815
+
+    class CreateArtifactFilesResponseFile(TypedDict):
+        id: str
+        name: str
+        displayName: str  # noqa: N815
+        uploadUrl: Optional[str]  # noqa: N815
+        uploadHeaders: Sequence[str]  # noqa: N815
+        uploadMultipartUrls: "UploadPartsResponse"  # noqa: N815
+        storagePath: str  # noqa: N815
+        artifact: "CreateArtifactFilesResponseFileNode"
+
+    class CreateArtifactFilesResponseFileNode(TypedDict):
+        id: str
+
+    class UploadPartsResponse(TypedDict):
+        uploadUrlParts: List["UploadUrlParts"]  # noqa: N815
+        uploadID: str  # noqa: N815
+
+    class UploadUrlParts(TypedDict):
+        partNumber: int  # noqa: N815
+        uploadUrl: str  # noqa: N815
+
+    class CompleteMultipartUploadArtifactInput(TypedDict):
+        """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""
+
+        completeMultipartAction: str  # noqa: N815
+        completedParts: Dict[int, str]  # noqa: N815
+        artifactID: str  # noqa: N815
+        storagePath: str  # noqa: N815
+        uploadID: str  # noqa: N815
+        md5: str
+
+    class CompleteMultipartUploadArtifactResponse(TypedDict):
+        digest: str
+
+    class DefaultSettings(TypedDict):
+        section: str
+        git_remote: str
+        ignore_globs: Optional[List[str]]
+        base_url: Optional[str]
+        root_dir: Optional[str]
+        api_key: Optional[str]
+        entity: Optional[str]
+        project: Optional[str]
+        _extra_http_headers: Optional[Mapping[str, str]]
+        _proxies: Optional[Mapping[str, str]]
+
+    _Response = MutableMapping
+    SweepState = Literal["RUNNING", "PAUSED", "CANCELED", "FINISHED"]
+    Number = Union[int, float]
+
+# This funny if/else construction is the simplest thing I've found that
+# works at runtime, satisfies Mypy, and gives autocomplete in VSCode:
+if TYPE_CHECKING:
+    import httpx
+else:
+    httpx = util.get_module("httpx")
+
+# class _MappingSupportsCopy(Protocol):
+#     def copy(self) -> "_MappingSupportsCopy": ...
+#     def keys(self) -> Iterable: ...
+#     def __getitem__(self, name: str) -> Any: ...
+
+httpclient_logger = logging.getLogger("http.client")
+if os.environ.get("WANDB_DEBUG"):
+    httpclient_logger.setLevel(logging.DEBUG)
+
+
+def check_httpclient_logger_handler() -> None:
+    # Only enable http.client logging if WANDB_DEBUG is set
+    if not os.environ.get("WANDB_DEBUG"):
+        return
+    if httpclient_logger.handlers:
+        return
+
+    # Enable HTTPConnection debug logging to the logging framework
+    level = logging.DEBUG
+
+    def httpclient_log(*args: Any) -> None:
+        httpclient_logger.log(level, " ".join(args))
+
+    # mask the print() built-in in the http.client module to use logging instead
+    http.client.print = httpclient_log  # type: ignore[attr-defined]
+    # enable debugging
+    http.client.HTTPConnection.debuglevel = 1
+
+    root_logger = logging.getLogger("wandb")
+    if root_logger.handlers:
+        httpclient_logger.addHandler(root_logger.handlers[0])
+
+
+def check_httpx_exc_retriable(exc: Exception) -> bool:
+    retriable_codes = (308, 408, 409, 429, 500, 502, 503, 504)
+    return (
+        isinstance(exc, (httpx.TimeoutException, httpx.NetworkError))
+        or (
+            isinstance(exc, httpx.HTTPStatusError)
+            and exc.response.status_code in retriable_codes
+        )
+        or (
+            isinstance(exc, httpx.HTTPStatusError)
+            and exc.response.status_code == 400
+            and "x-amz-meta-md5" in exc.request.headers
+            and "RequestTimeout" in str(exc.response.content)
+        )
+    )
+
+
+class _ThreadLocalData(threading.local):
+    context: Optional[context.Context]
+
+    def __init__(self) -> None:
+        self.context = None
+
+
+class Api:
+    """W&B Internal Api wrapper.
+
+    Note:
+        Settings are automatically overridden by looking for
+        a `wandb/settings` file in the current working directory or its parent
+        directory. If none can be found, we look in the current user's home
+        directory.
+
+    Arguments:
+        default_settings(dict, optional): If you aren't using a settings
+        file, or you wish to override the section to use in the settings file
+        Override the settings here.
+    """
+
+    HTTP_TIMEOUT = env.get_http_timeout(20)
+    FILE_PUSHER_TIMEOUT = env.get_file_pusher_timeout()
+    _global_context: context.Context
+    _local_data: _ThreadLocalData
+
+    def __init__(
+        self,
+        default_settings: Optional[
+            Union[
+                "wandb.sdk.wandb_settings.Settings",
+                "wandb.sdk.internal.settings_static.SettingsStatic",
+                Settings,
+                dict,
+            ]
+        ] = None,
+        load_settings: bool = True,
+        retry_timedelta: datetime.timedelta = datetime.timedelta(  # noqa: B008  # okay because it's immutable
+            days=7
+        ),
+        environ: MutableMapping = os.environ,
+        retry_callback: Optional[Callable[[int, str], Any]] = None,
+    ) -> None:
+        self._environ = environ
+        self._global_context = context.Context()
+        self._local_data = _ThreadLocalData()
+        self.default_settings: DefaultSettings = {
+            "section": "default",
+            "git_remote": "origin",
+            "ignore_globs": [],
+            "base_url": "https://api.wandb.ai",
+            "root_dir": None,
+            "api_key": None,
+            "entity": None,
+            "project": None,
+            "_extra_http_headers": None,
+            "_proxies": None,
+        }
+        self.retry_timedelta = retry_timedelta
+        # todo: Old Settings do not follow the SupportsKeysAndGetItem Protocol
+        default_settings = default_settings or {}
+        self.default_settings.update(default_settings)  # type: ignore
+        self.retry_uploads = 10
+        self._settings = Settings(
+            load_settings=load_settings,
+            root_dir=self.default_settings.get("root_dir"),
+        )
+        self.git = GitRepo(remote=self.settings("git_remote"))
+        # Mutable settings set by the _file_stream_api
+        self.dynamic_settings = {
+            "system_sample_seconds": 2,
+            "system_samples": 15,
+            "heartbeat_seconds": 30,
+        }
+
+        # todo: remove these hacky hacks after settings refactor is complete
+        # keeping this code here to limit scope and so that it is easy to remove later
+        extra_http_headers = self.settings("_extra_http_headers") or json.loads(
+            self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
+        )
+        proxies = self.settings("_proxies") or json.loads(
+            self._environ.get("WANDB__PROXIES", "{}")
+        )
+
+        auth = None
+        if _thread_local_api_settings.cookies is None:
+            auth = ("api", self.api_key or "")
+        extra_http_headers.update(_thread_local_api_settings.headers or {})
+        self.client = Client(
+            transport=GraphQLSession(
+                headers={
+                    "User-Agent": self.user_agent,
+                    "X-WANDB-USERNAME": env.get_username(env=self._environ),
+                    "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
+                    **extra_http_headers,
+                },
+                use_json=True,
+                # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
+                # https://bugs.python.org/issue22889
+                timeout=self.HTTP_TIMEOUT,
+                auth=auth,
+                url=f"{self.settings('base_url')}/graphql",
+                cookies=_thread_local_api_settings.cookies,
+                proxies=proxies,
+            )
+        )
+
+        # httpx is an optional dependency, so we lazily instantiate the client
+        # only when we need it
+        self._async_httpx_client: Optional[httpx.AsyncClient] = None
+
+        self.retry_callback = retry_callback
+        self._retry_gql = retry.Retry(
+            self.execute,
+            retry_timedelta=retry_timedelta,
+            check_retry_fn=util.no_retry_auth,
+            retryable_exceptions=(RetryError, requests.RequestException),
+            retry_callback=retry_callback,
+        )
+        self._current_run_id: Optional[str] = None
+        self._file_stream_api = None
+        self._upload_file_session = requests.Session()
+        if self.FILE_PUSHER_TIMEOUT:
+            self._upload_file_session.put = functools.partial(  # type: ignore
+                self._upload_file_session.put,
+                timeout=self.FILE_PUSHER_TIMEOUT,
+            )
+        if proxies:
+            self._upload_file_session.proxies.update(proxies)
+        # This Retry class is initialized once for each Api instance, so this
+        # defaults to retrying 1 million times per process or 7 days
+        self.upload_file_retry = normalize_exceptions(
+            retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file)
+        )
+        self.upload_multipart_file_chunk_retry = normalize_exceptions(
+            retry.retriable(retry_timedelta=retry_timedelta)(
+                self.upload_multipart_file_chunk
+            )
+        )
+        self._client_id_mapping: Dict[str, str] = {}
+        # Large file uploads to azure can optionally use their SDK
+        self._azure_blob_module = util.get_module("azure.storage.blob")
+
+        self.query_types: Optional[List[str]] = None
+        self.mutation_types: Optional[List[str]] = None
+        self.server_info_types: Optional[List[str]] = None
+        self.server_use_artifact_input_info: Optional[List[str]] = None
+        self.server_create_artifact_input_info: Optional[List[str]] = None
+        self.server_artifact_fields_info: Optional[List[str]] = None
+        self._max_cli_version: Optional[str] = None
+        self._server_settings_type: Optional[List[str]] = None
+        self.fail_run_queue_item_input_info: Optional[List[str]] = None
+        self.create_launch_agent_input_info: Optional[List[str]] = None
+        self.server_create_run_queue_supports_drc: Optional[bool] = None
+        self.server_create_run_queue_supports_priority: Optional[bool] = None
+        self.server_supports_template_variables: Optional[bool] = None
+        self.server_push_to_run_queue_supports_priority: Optional[bool] = None
+
+    def gql(self, *args: Any, **kwargs: Any) -> Any:
+        ret = self._retry_gql(
+            *args,
+            retry_cancel_event=self.context.cancel_event,
+            **kwargs,
+        )
+        return ret
+
+    def set_local_context(self, api_context: Optional[context.Context]) -> None:
+        self._local_data.context = api_context
+
+    def clear_local_context(self) -> None:
+        self._local_data.context = None
+
+    @property
+    def context(self) -> context.Context:
+        return self._local_data.context or self._global_context
+
+    def reauth(self) -> None:
+        """Ensure the current api key is set in the transport."""
+        self.client.transport.session.auth = ("api", self.api_key or "")
+
+    def relocate(self) -> None:
+        """Ensure the current api points to the right server."""
+        self.client.transport.url = "%s/graphql" % self.settings("base_url")
+
+    def execute(self, *args: Any, **kwargs: Any) -> "_Response":
+        """Wrapper around execute that logs in cases of failure."""
+        try:
+            return self.client.execute(*args, **kwargs)  # type: ignore
+        except requests.exceptions.HTTPError as err:
+            response = err.response
372
+ assert response is not None
373
+ logger.error(f"{response.status_code} response executing GraphQL.")
374
+ logger.error(response.text)
375
+ for error in parse_backend_error_messages(response):
376
+ wandb.termerror(f"Error while calling W&B API: {error} ({response})")
377
+ raise
378
+
379
+ def disabled(self) -> Union[str, bool]:
380
+ return self._settings.get(Settings.DEFAULT_SECTION, "disabled", fallback=False) # type: ignore
381
+
382
+ def set_current_run_id(self, run_id: str) -> None:
383
+ self._current_run_id = run_id
384
+
385
+ @property
386
+ def current_run_id(self) -> Optional[str]:
387
+ return self._current_run_id
388
+
389
+ @property
390
+ def user_agent(self) -> str:
391
+ return f"W&B Internal Client {wandb.__version__}"
392
+
393
+ @property
394
+ def api_key(self) -> Optional[str]:
395
+ if _thread_local_api_settings.api_key:
396
+ return _thread_local_api_settings.api_key
397
+ auth = requests.utils.get_netrc_auth(self.api_url)
398
+ key = None
399
+ if auth:
400
+ key = auth[-1]
401
+
402
+ # Environment should take precedence
403
+ env_key: Optional[str] = self._environ.get(env.API_KEY)
404
+ sagemaker_key: Optional[str] = parse_sm_secrets().get(env.API_KEY)
405
+ default_key: Optional[str] = self.default_settings.get("api_key")
406
+ return env_key or key or sagemaker_key or default_key
407
+
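# Rough standalone sketch of the api_key precedence implemented above (assumes
# env.API_KEY refers to the WANDB_API_KEY environment variable; all names here
# are placeholders, not part of this file):
#
#   def resolve_api_key(thread_local_key, environ, netrc_auth, sm_secrets, defaults):
#       if thread_local_key:
#           return thread_local_key
#       netrc_key = netrc_auth[-1] if netrc_auth else None
#       return (
#           environ.get("WANDB_API_KEY")    # environment wins over netrc
#           or netrc_key
#           or sm_secrets.get("WANDB_API_KEY")
#           or defaults.get("api_key")
#       )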
408
+ @property
409
+ def api_url(self) -> str:
410
+ return self.settings("base_url") # type: ignore
411
+
412
+ @property
413
+ def app_url(self) -> str:
414
+ return wandb.util.app_url(self.api_url)
415
+
416
+ @property
417
+ def default_entity(self) -> str:
418
+ return self.viewer().get("entity") # type: ignore
419
+
420
+ def settings(self, key: Optional[str] = None, section: Optional[str] = None) -> Any:
421
+ """The settings overridden from the wandb/settings file.
422
+
423
+ Arguments:
424
+ key (str, optional): If provided only this setting is returned
425
+ section (str, optional): If provided this section of the setting file is
426
+ used, defaults to "default"
427
+
428
+ Returns:
429
+ A dict with the current settings
430
+
431
+ {
432
+ "entity": "models",
433
+ "base_url": "https://api.wandb.ai",
434
+ "project": None
435
+ }
436
+ """
437
+ result = self.default_settings.copy()
438
+ result.update(self._settings.items(section=section)) # type: ignore
439
+ result.update(
440
+ {
441
+ "entity": env.get_entity(
442
+ self._settings.get(
443
+ Settings.DEFAULT_SECTION,
444
+ "entity",
445
+ fallback=result.get("entity"),
446
+ ),
447
+ env=self._environ,
448
+ ),
449
+ "project": env.get_project(
450
+ self._settings.get(
451
+ Settings.DEFAULT_SECTION,
452
+ "project",
453
+ fallback=result.get("project"),
454
+ ),
455
+ env=self._environ,
456
+ ),
457
+ "base_url": env.get_base_url(
458
+ self._settings.get(
459
+ Settings.DEFAULT_SECTION,
460
+ "base_url",
461
+ fallback=result.get("base_url"),
462
+ ),
463
+ env=self._environ,
464
+ ),
465
+ "ignore_globs": env.get_ignore(
466
+ self._settings.get(
467
+ Settings.DEFAULT_SECTION,
468
+ "ignore_globs",
469
+ fallback=result.get("ignore_globs"),
470
+ ),
471
+ env=self._environ,
472
+ ),
473
+ }
474
+ )
475
+
476
+ return result if key is None else result[key] # type: ignore
477
+
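# Example calls against settings() above (values are hypothetical; `api` is an
# instance of this class):
#
#   api.settings()              # full dict: entity, project, base_url, ignore_globs, ...
#   api.settings("base_url")    # -> "https://api.wandb.ai" unless overridden
#   api.settings("entity")      # env (e.g. WANDB_ENTITY) wins over the settings file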
478
+ def clear_setting(
479
+ self, key: str, globally: bool = False, persist: bool = False
480
+ ) -> None:
481
+ self._settings.clear(
482
+ Settings.DEFAULT_SECTION, key, globally=globally, persist=persist
483
+ )
484
+
485
+ def set_setting(
486
+ self, key: str, value: Any, globally: bool = False, persist: bool = False
487
+ ) -> None:
488
+ self._settings.set(
489
+ Settings.DEFAULT_SECTION, key, value, globally=globally, persist=persist
490
+ )
491
+ if key == "entity":
492
+ env.set_entity(value, env=self._environ)
493
+ elif key == "project":
494
+ env.set_project(value, env=self._environ)
495
+ elif key == "base_url":
496
+ self.relocate()
497
+
498
+ def parse_slug(
499
+ self, slug: str, project: Optional[str] = None, run: Optional[str] = None
500
+ ) -> Tuple[str, str]:
501
+ """Parse a slug into a project and run.
502
+
503
+ Arguments:
504
+ slug (str): The slug to parse
505
+ project (str, optional): The project to use, if not provided it will be
506
+ inferred from the slug
507
+ run (str, optional): The run to use, if not provided it will be inferred
508
+ from the slug
509
+
510
+ Returns:
511
+ A tuple of (project, run).
512
+ """
513
+ if slug and "/" in slug:
514
+ parts = slug.split("/")
515
+ project = parts[0]
516
+ run = parts[1]
517
+ else:
518
+ project = project or self.settings().get("project")
519
+ if project is None:
520
+ raise CommError("No default project configured.")
521
+ run = run or slug or self.current_run_id or env.get_run(env=self._environ)
522
+ assert run, "run must be specified"
523
+ return project, run
524
+
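# Illustrative parse_slug() behavior (`api` is an instance of this class;
# slugs and names are made up):
#
#   api.parse_slug("my-project/abc123")          # -> ("my-project", "abc123")
#   api.parse_slug("abc123", project="my-proj")  # -> ("my-proj", "abc123")
#   # With no "/" in the slug and no project configured anywhere, CommError is raised.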
525
+ @normalize_exceptions
526
+ def server_info_introspection(self) -> Tuple[List[str], List[str], List[str]]:
527
+ query_string = """
528
+ query ProbeServerCapabilities {
529
+ QueryType: __type(name: "Query") {
530
+ ...fieldData
531
+ }
532
+ MutationType: __type(name: "Mutation") {
533
+ ...fieldData
534
+ }
535
+ ServerInfoType: __type(name: "ServerInfo") {
536
+ ...fieldData
537
+ }
538
+ }
539
+
540
+ fragment fieldData on __Type {
541
+ fields {
542
+ name
543
+ }
544
+ }
545
+ """
546
+ if (
547
+ self.query_types is None
548
+ or self.mutation_types is None
549
+ or self.server_info_types is None
550
+ ):
551
+ query = gql(query_string)
552
+ res = self.gql(query)
553
+
554
+ self.query_types = [
555
+ field.get("name", "")
556
+ for field in res.get("QueryType", {}).get("fields", [{}])
557
+ ]
558
+ self.mutation_types = [
559
+ field.get("name", "")
560
+ for field in res.get("MutationType", {}).get("fields", [{}])
561
+ ]
562
+ self.server_info_types = [
563
+ field.get("name", "")
564
+ for field in res.get("ServerInfoType", {}).get("fields", [{}])
565
+ ]
566
+ return self.query_types, self.server_info_types, self.mutation_types
567
+
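# The introspection above probes the GraphQL schema once and caches the field
# names so later calls can gate features on server capabilities, e.g.:
#
#   query_types, server_info_types, mutation_types = api.server_info_introspection()
#   if "createRunQueue" in mutation_types:
#       pass  # the newer mutation is available
#   else:
#       pass  # fall back to older behavior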
568
+ @normalize_exceptions
569
+ def server_settings_introspection(self) -> None:
570
+ query_string = """
571
+ query ProbeServerSettings {
572
+ ServerSettingsType: __type(name: "ServerSettings") {
573
+ ...fieldData
574
+ }
575
+ }
576
+
577
+ fragment fieldData on __Type {
578
+ fields {
579
+ name
580
+ }
581
+ }
582
+ """
583
+ if self._server_settings_type is None:
584
+ query = gql(query_string)
585
+ res = self.gql(query)
586
+ self._server_settings_type = (
587
+ [
588
+ field.get("name", "")
589
+ for field in res.get("ServerSettingsType", {}).get("fields", [{}])
590
+ ]
591
+ if res
592
+ else []
593
+ )
594
+
595
+ def server_use_artifact_input_introspection(self) -> List:
596
+ query_string = """
597
+ query ProbeServerUseArtifactInput {
598
+ UseArtifactInputInfoType: __type(name: "UseArtifactInput") {
599
+ name
600
+ inputFields {
601
+ name
602
+ }
603
+ }
604
+ }
605
+ """
606
+
607
+ if self.server_use_artifact_input_info is None:
608
+ query = gql(query_string)
609
+ res = self.gql(query)
610
+ self.server_use_artifact_input_info = [
611
+ field.get("name", "")
612
+ for field in res.get("UseArtifactInputInfoType", {}).get(
613
+ "inputFields", [{}]
614
+ )
615
+ ]
616
+ return self.server_use_artifact_input_info
617
+
618
+ @normalize_exceptions
619
+ def launch_agent_introspection(self) -> Optional[str]:
620
+ query = gql(
621
+ """
622
+ query LaunchAgentIntrospection {
623
+ LaunchAgentType: __type(name: "LaunchAgent") {
624
+ name
625
+ }
626
+ }
627
+ """
628
+ )
629
+
630
+ res = self.gql(query)
631
+ return res.get("LaunchAgentType") or None
632
+
633
+ @normalize_exceptions
634
+ def create_run_queue_introspection(self) -> Tuple[bool, bool, bool]:
635
+ _, _, mutations = self.server_info_introspection()
636
+ query_string = """
637
+ query ProbeCreateRunQueueInput {
638
+ CreateRunQueueInputType: __type(name: "CreateRunQueueInput") {
639
+ name
640
+ inputFields {
641
+ name
642
+ }
643
+ }
644
+ }
645
+ """
646
+ if (
647
+ self.server_create_run_queue_supports_drc is None
648
+ or self.server_create_run_queue_supports_priority is None
649
+ ):
650
+ query = gql(query_string)
651
+ res = self.gql(query)
652
+ if res is None:
653
+ raise CommError("Could not get CreateRunQueue input from GQL.")
654
+ self.server_create_run_queue_supports_drc = "defaultResourceConfigID" in [
655
+ x["name"]
656
+ for x in (
657
+ res.get("CreateRunQueueInputType", {}).get("inputFields", [{}])
658
+ )
659
+ ]
660
+ self.server_create_run_queue_supports_priority = "prioritizationMode" in [
661
+ x["name"]
662
+ for x in (
663
+ res.get("CreateRunQueueInputType", {}).get("inputFields", [{}])
664
+ )
665
+ ]
666
+ return (
667
+ "createRunQueue" in mutations,
668
+ self.server_create_run_queue_supports_drc,
669
+ self.server_create_run_queue_supports_priority,
670
+ )
671
+
672
+ @normalize_exceptions
673
+ def push_to_run_queue_introspection(self) -> Tuple[bool, bool]:
674
+ query_string = """
675
+ query ProbePushToRunQueueInput {
676
+ PushToRunQueueInputType: __type(name: "PushToRunQueueInput") {
677
+ name
678
+ inputFields {
679
+ name
680
+ }
681
+ }
682
+ }
683
+ """
684
+
685
+ if (
686
+ self.server_supports_template_variables is None
687
+ or self.server_push_to_run_queue_supports_priority is None
688
+ ):
689
+ query = gql(query_string)
690
+ res = self.gql(query)
691
+ self.server_supports_template_variables = "templateVariableValues" in [
692
+ x["name"]
693
+ for x in (
694
+ res.get("PushToRunQueueInputType", {}).get("inputFields", [{}])
695
+ )
696
+ ]
697
+ self.server_push_to_run_queue_supports_priority = "priority" in [
698
+ x["name"]
699
+ for x in (
700
+ res.get("PushToRunQueueInputType", {}).get("inputFields", [{}])
701
+ )
702
+ ]
703
+
704
+ return (
705
+ self.server_supports_template_variables,
706
+ self.server_push_to_run_queue_supports_priority,
707
+ )
708
+
709
+ @normalize_exceptions
710
+ def create_default_resource_config_introspection(self) -> bool:
711
+ _, _, mutations = self.server_info_introspection()
712
+ return "createDefaultResourceConfig" in mutations
713
+
714
+ @normalize_exceptions
715
+ def fail_run_queue_item_introspection(self) -> bool:
716
+ _, _, mutations = self.server_info_introspection()
717
+ return "failRunQueueItem" in mutations
718
+
719
+ @normalize_exceptions
720
+ def fail_run_queue_item_fields_introspection(self) -> List:
721
+ if self.fail_run_queue_item_input_info:
722
+ return self.fail_run_queue_item_input_info
723
+ query_string = """
724
+ query ProbeServerFailRunQueueItemInput {
725
+ FailRunQueueItemInputInfoType: __type(name:"FailRunQueueItemInput") {
726
+ inputFields{
727
+ name
728
+ }
729
+ }
730
+ }
731
+ """
732
+
733
+ query = gql(query_string)
734
+ res = self.gql(query)
735
+
736
+ self.fail_run_queue_item_input_info = [
737
+ field.get("name", "")
738
+ for field in res.get("FailRunQueueItemInputInfoType", {}).get(
739
+ "inputFields", [{}]
740
+ )
741
+ ]
742
+ return self.fail_run_queue_item_input_info
743
+
744
+ @normalize_exceptions
745
+ def fail_run_queue_item(
746
+ self,
747
+ run_queue_item_id: str,
748
+ message: str,
749
+ stage: str,
750
+ file_paths: Optional[List[str]] = None,
751
+ ) -> bool:
752
+ if not self.fail_run_queue_item_introspection():
753
+ return False
754
+ variable_values: Dict[str, Union[str, Optional[List[str]]]] = {
755
+ "runQueueItemId": run_queue_item_id,
756
+ }
757
+ if "message" in self.fail_run_queue_item_fields_introspection():
758
+ variable_values.update({"message": message, "stage": stage})
759
+ if file_paths is not None:
760
+ variable_values["filePaths"] = file_paths
761
+ mutation_string = """
762
+ mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
763
+ failRunQueueItem(
764
+ input: {
765
+ runQueueItemId: $runQueueItemId
766
+ message: $message
767
+ stage: $stage
768
+ filePaths: $filePaths
769
+ }
770
+ ) {
771
+ success
772
+ }
773
+ }
774
+ """
775
+ else:
776
+ mutation_string = """
777
+ mutation failRunQueueItem($runQueueItemId: ID!) {
778
+ failRunQueueItem(
779
+ input: {
780
+ runQueueItemId: $runQueueItemId
781
+ }
782
+ ) {
783
+ success
784
+ }
785
+ }
786
+ """
787
+
788
+ mutation = gql(mutation_string)
789
+ response = self.gql(mutation, variable_values=variable_values)
790
+ result: bool = response["failRunQueueItem"]["success"]
791
+ return result
792
+
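# Hypothetical usage of fail_run_queue_item() above; the ID, message, and file
# paths are made up for illustration:
#
#   ok = api.fail_run_queue_item(
#       run_queue_item_id="UnVuUXVldWVJdGVtOjEyMw==",
#       message="container failed to start",
#       stage="run",
#       file_paths=["logs/error.log"],
#   )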
793
+ @normalize_exceptions
794
+ def update_run_queue_item_warning_introspection(self) -> bool:
795
+ _, _, mutations = self.server_info_introspection()
796
+ return "updateRunQueueItemWarning" in mutations
797
+
798
+ @normalize_exceptions
799
+ def update_run_queue_item_warning(
800
+ self,
801
+ run_queue_item_id: str,
802
+ message: str,
803
+ stage: str,
804
+ file_paths: Optional[List[str]] = None,
805
+ ) -> bool:
806
+ if not self.update_run_queue_item_warning_introspection():
807
+ return False
808
+ mutation = gql(
809
+ """
810
+ mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) {
811
+ updateRunQueueItemWarning(
812
+ input: {
813
+ runQueueItemId: $runQueueItemId
814
+ message: $message
815
+ stage: $stage
816
+ filePaths: $filePaths
817
+ }
818
+ ) {
819
+ success
820
+ }
821
+ }
822
+ """
823
+ )
824
+ response = self.gql(
825
+ mutation,
826
+ variable_values={
827
+ "runQueueItemId": run_queue_item_id,
828
+ "message": message,
829
+ "stage": stage,
830
+ "filePaths": file_paths,
831
+ },
832
+ )
833
+ result: bool = response["updateRunQueueItemWarning"]["success"]
834
+ return result
835
+
836
+ @normalize_exceptions
837
+ def viewer(self) -> Dict[str, Any]:
838
+ query = gql(
839
+ """
840
+ query Viewer{
841
+ viewer {
842
+ id
843
+ entity
844
+ username
845
+ flags
846
+ teams {
847
+ edges {
848
+ node {
849
+ name
850
+ }
851
+ }
852
+ }
853
+ }
854
+ }
855
+ """
856
+ )
857
+ res = self.gql(query)
858
+ return res.get("viewer") or {}
859
+
860
+ @normalize_exceptions
861
+ def max_cli_version(self) -> Optional[str]:
862
+ if self._max_cli_version is not None:
863
+ return self._max_cli_version
864
+
865
+ query_types, server_info_types, _ = self.server_info_introspection()
866
+ cli_version_exists = (
867
+ "serverInfo" in query_types and "cliVersionInfo" in server_info_types
868
+ )
869
+ if not cli_version_exists:
870
+ return None
871
+
872
+ _, server_info = self.viewer_server_info()
873
+ self._max_cli_version = server_info.get("cliVersionInfo", {}).get(
874
+ "max_cli_version"
875
+ )
876
+ return self._max_cli_version
877
+
878
+ @normalize_exceptions
879
+ def viewer_server_info(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
880
+ local_query = """
881
+ latestLocalVersionInfo {
882
+ outOfDate
883
+ latestVersionString
884
+ versionOnThisInstanceString
885
+ }
886
+ """
887
+ cli_query = """
888
+ serverInfo {
889
+ cliVersionInfo
890
+ _LOCAL_QUERY_
891
+ }
892
+ """
893
+ query_template = """
894
+ query Viewer{
895
+ viewer {
896
+ id
897
+ entity
898
+ username
899
+ email
900
+ flags
901
+ teams {
902
+ edges {
903
+ node {
904
+ name
905
+ }
906
+ }
907
+ }
908
+ }
909
+ _CLI_QUERY_
910
+ }
911
+ """
912
+ query_types, server_info_types, _ = self.server_info_introspection()
913
+
914
+ cli_version_exists = (
915
+ "serverInfo" in query_types and "cliVersionInfo" in server_info_types
916
+ )
917
+
918
+ local_version_exists = (
919
+ "serverInfo" in query_types
920
+ and "latestLocalVersionInfo" in server_info_types
921
+ )
922
+
923
+ cli_query_string = "" if not cli_version_exists else cli_query
924
+ local_query_string = "" if not local_version_exists else local_query
925
+
926
+ query_string = query_template.replace("_CLI_QUERY_", cli_query_string).replace(
927
+ "_LOCAL_QUERY_", local_query_string
928
+ )
929
+ query = gql(query_string)
930
+ res = self.gql(query)
931
+ return res.get("viewer") or {}, res.get("serverInfo") or {}
932
+
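# Typical use of viewer_server_info() above (this mirrors the max_cli_version
# helper defined earlier; values are illustrative):
#
#   viewer, server_info = api.viewer_server_info()
#   max_cli = (server_info.get("cliVersionInfo") or {}).get("max_cli_version")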
933
+ @normalize_exceptions
934
+ def list_projects(self, entity: Optional[str] = None) -> List[Dict[str, str]]:
935
+ """List projects in W&B scoped by entity.
936
+
937
+ Arguments:
938
+ entity (str, optional): The entity to scope this project to.
939
+
940
+ Returns:
941
+ [{"id","name","description"}]
942
+ """
943
+ query = gql(
944
+ """
945
+ query EntityProjects($entity: String) {
946
+ models(first: 10, entityName: $entity) {
947
+ edges {
948
+ node {
949
+ id
950
+ name
951
+ description
952
+ }
953
+ }
954
+ }
955
+ }
956
+ """
957
+ )
958
+ project_list: List[Dict[str, str]] = self._flatten_edges(
959
+ self.gql(
960
+ query, variable_values={"entity": entity or self.settings("entity")}
961
+ )["models"]
962
+ )
963
+ return project_list
964
+
965
+ @normalize_exceptions
966
+ def project(self, project: str, entity: Optional[str] = None) -> "_Response":
967
+ """Retrieve project.
968
+
969
+ Arguments:
970
+ project (str): The project to get details for
971
+ entity (str, optional): The entity to scope this project to.
972
+
973
+ Returns:
974
+ [{"id","name","repo","dockerImage","description"}]
975
+ """
976
+ query = gql(
977
+ """
978
+ query ProjectDetails($entity: String, $project: String) {
979
+ model(name: $project, entityName: $entity) {
980
+ id
981
+ name
982
+ repo
983
+ dockerImage
984
+ description
985
+ }
986
+ }
987
+ """
988
+ )
989
+ response: _Response = self.gql(
990
+ query, variable_values={"entity": entity, "project": project}
991
+ )["model"]
992
+ return response
993
+
994
+ @normalize_exceptions
995
+ def sweep(
996
+ self,
997
+ sweep: str,
998
+ specs: str,
999
+ project: Optional[str] = None,
1000
+ entity: Optional[str] = None,
1001
+ ) -> Dict[str, Any]:
1002
+ """Retrieve sweep.
1003
+
1004
+ Arguments:
1005
+ sweep (str): The sweep to get details for
1006
+ specs (str): history specs
1007
+ project (str, optional): The project to scope this sweep to.
1008
+ entity (str, optional): The entity to scope this sweep to.
1009
+
1010
+ Returns:
1011
+ [{"id","name","repo","dockerImage","description"}]
1012
+ """
1013
+ query = gql(
1014
+ """
1015
+ query SweepWithRuns($entity: String, $project: String, $sweep: String!, $specs: [JSONString!]!) {
1016
+ project(name: $project, entityName: $entity) {
1017
+ sweep(sweepName: $sweep) {
1018
+ id
1019
+ name
1020
+ method
1021
+ state
1022
+ description
1023
+ config
1024
+ createdAt
1025
+ heartbeatAt
1026
+ updatedAt
1027
+ earlyStopJobRunning
1028
+ bestLoss
1029
+ controller
1030
+ scheduler
1031
+ runs {
1032
+ edges {
1033
+ node {
1034
+ name
1035
+ state
1036
+ config
1037
+ exitcode
1038
+ heartbeatAt
1039
+ shouldStop
1040
+ failed
1041
+ stopped
1042
+ running
1043
+ summaryMetrics
1044
+ sampledHistory(specs: $specs)
1045
+ }
1046
+ }
1047
+ }
1048
+ }
1049
+ }
1050
+ }
1051
+ """
1052
+ )
1053
+ entity = entity or self.settings("entity")
1054
+ project = project or self.settings("project")
1055
+ response = self.gql(
1056
+ query,
1057
+ variable_values={
1058
+ "entity": entity,
1059
+ "project": project,
1060
+ "sweep": sweep,
1061
+ "specs": specs,
1062
+ },
1063
+ )
1064
+ if response["project"] is None or response["project"]["sweep"] is None:
1065
+ raise ValueError(f"Sweep {entity}/{project}/{sweep} not found")
1066
+ data: Dict[str, Any] = response["project"]["sweep"]
1067
+ if data:
1068
+ data["runs"] = self._flatten_edges(data["runs"])
1069
+ return data
1070
+
1071
+ @normalize_exceptions
1072
+ def list_runs(
1073
+ self, project: str, entity: Optional[str] = None
1074
+ ) -> List[Dict[str, str]]:
1075
+ """List runs in W&B scoped by project.
1076
+
1077
+ Arguments:
1078
+ project (str): The project to scope the runs to
1079
+ entity (str, optional): The entity to scope this project to. Defaults to public models
1080
+
1081
+ Returns:
1082
+ [{"id","name","description"}]
1083
+ """
1084
+ query = gql(
1085
+ """
1086
+ query ProjectRuns($model: String!, $entity: String) {
1087
+ model(name: $model, entityName: $entity) {
1088
+ buckets(first: 10) {
1089
+ edges {
1090
+ node {
1091
+ id
1092
+ name
1093
+ displayName
1094
+ description
1095
+ }
1096
+ }
1097
+ }
1098
+ }
1099
+ }
1100
+ """
1101
+ )
1102
+ return self._flatten_edges(
1103
+ self.gql(
1104
+ query,
1105
+ variable_values={
1106
+ "entity": entity or self.settings("entity"),
1107
+ "model": project or self.settings("project"),
1108
+ },
1109
+ )["model"]["buckets"]
1110
+ )
1111
+
1112
+ @normalize_exceptions
1113
+ def run_config(
1114
+ self, project: str, run: Optional[str] = None, entity: Optional[str] = None
1115
+ ) -> Tuple[str, Dict[str, Any], Optional[str], Dict[str, Any]]:
1116
+ """Get the relevant configs for a run.
1117
+
1118
+ Arguments:
1119
+ project (str): The project to download (can include the bucket/run)
1120
+ run (str, optional): The run to download
1121
+ entity (str, optional): The entity to scope this project to.
1122
+ """
1123
+ check_httpclient_logger_handler()
1124
+
1125
+ query = gql(
1126
+ """
1127
+ query RunConfigs(
1128
+ $name: String!,
1129
+ $entity: String,
1130
+ $run: String!,
1131
+ $pattern: String!,
1132
+ $includeConfig: Boolean!,
1133
+ ) {
1134
+ model(name: $name, entityName: $entity) {
1135
+ bucket(name: $run) {
1136
+ config @include(if: $includeConfig)
1137
+ commit @include(if: $includeConfig)
1138
+ files(pattern: $pattern) {
1139
+ pageInfo {
1140
+ hasNextPage
1141
+ endCursor
1142
+ }
1143
+ edges {
1144
+ node {
1145
+ name
1146
+ directUrl
1147
+ }
1148
+ }
1149
+ }
1150
+ }
1151
+ }
1152
+ }
1153
+ """
1154
+ )
1155
+
1156
+ variable_values = {
1157
+ "name": project,
1158
+ "run": run,
1159
+ "entity": entity,
1160
+ "includeConfig": True,
1161
+ }
1162
+
1163
+ commit: str = ""
1164
+ config: Dict[str, Any] = {}
1165
+ patch: Optional[str] = None
1166
+ metadata: Dict[str, Any] = {}
1167
+
1168
+ # If we use the `names` parameter on the `files` node, then the server
1169
+ # will helpfully give us an 'open' file handle to the files that don't
1170
+ # exist. This is so that we can upload data to it. However, in this
1171
+ # case, we just want to download that file and not upload to it, so
1172
+ # let's instead query for the files that do exist using `pattern`
1173
+ # (with no wildcards).
1174
+ #
1175
+ # Unfortunately we're unable to construct a single pattern that matches
1176
+ # our two files; we would need something like regex for that.
1177
+ for filename in [DIFF_FNAME, METADATA_FNAME]:
1178
+ variable_values["pattern"] = filename
1179
+ response = self.gql(query, variable_values=variable_values)
1180
+ if response["model"] is None:
1181
+ raise CommError(f"Run {entity}/{project}/{run} not found")
1182
+ run_obj: Dict = response["model"]["bucket"]
1183
+ # we only need to fetch this config once
1184
+ if variable_values["includeConfig"]:
1185
+ commit = run_obj["commit"]
1186
+ config = json.loads(run_obj["config"] or "{}")
1187
+ variable_values["includeConfig"] = False
1188
+ if run_obj["files"] is not None:
1189
+ for file_edge in run_obj["files"]["edges"]:
1190
+ name = file_edge["node"]["name"]
1191
+ url = file_edge["node"]["directUrl"]
1192
+ res = requests.get(url)
1193
+ res.raise_for_status()
1194
+ if name == METADATA_FNAME:
1195
+ metadata = res.json()
1196
+ elif name == DIFF_FNAME:
1197
+ patch = res.text
1198
+
1199
+ return commit, config, patch, metadata
1200
+
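# Hypothetical call to run_config() above; the project, run, and entity names
# are made up:
#
#   commit, config, patch, metadata = api.run_config(
#       "my-project", run="abc123", entity="my-team"
#   )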
1201
+ @normalize_exceptions
1202
+ def run_resume_status(
1203
+ self, entity: str, project_name: str, name: str
1204
+ ) -> Optional[Dict[str, Any]]:
1205
+ """Check if a run exists and get resume information.
1206
+
1207
+ Arguments:
1208
+ entity (str): The entity to scope this project to.
1209
+ project_name (str): The project to download (can include the bucket/run)
1210
+ name (str): The run to download
1211
+ """
1212
+ query = gql(
1213
+ """
1214
+ query RunResumeStatus($project: String, $entity: String, $name: String!) {
1215
+ model(name: $project, entityName: $entity) {
1216
+ id
1217
+ name
1218
+ entity {
1219
+ id
1220
+ name
1221
+ }
1222
+
1223
+ bucket(name: $name, missingOk: true) {
1224
+ id
1225
+ name
1226
+ summaryMetrics
1227
+ displayName
1228
+ logLineCount
1229
+ historyLineCount
1230
+ eventsLineCount
1231
+ historyTail
1232
+ eventsTail
1233
+ config
1234
+ tags
1235
+ }
1236
+ }
1237
+ }
1238
+ """
1239
+ )
1240
+
1241
+ response = self.gql(
1242
+ query,
1243
+ variable_values={
1244
+ "entity": entity,
1245
+ "project": project_name,
1246
+ "name": name,
1247
+ },
1248
+ )
1249
+
1250
+ if "model" not in response or "bucket" not in (response["model"] or {}):
1251
+ return None
1252
+
1253
+ project = response["model"]
1254
+ self.set_setting("project", project_name)
1255
+ if "entity" in project:
1256
+ self.set_setting("entity", project["entity"]["name"])
1257
+
1258
+ result: Dict[str, Any] = project["bucket"]
1259
+
1260
+ return result
1261
+
1262
+ @normalize_exceptions
1263
+ def check_stop_requested(
1264
+ self, project_name: str, entity_name: str, run_id: str
1265
+ ) -> bool:
1266
+ query = gql(
1267
+ """
1268
+ query RunStoppedStatus($projectName: String, $entityName: String, $runId: String!) {
1269
+ project(name:$projectName, entityName:$entityName) {
1270
+ run(name:$runId) {
1271
+ stopped
1272
+ }
1273
+ }
1274
+ }
1275
+ """
1276
+ )
1277
+
1278
+ response = self.gql(
1279
+ query,
1280
+ variable_values={
1281
+ "projectName": project_name,
1282
+ "entityName": entity_name,
1283
+ "runId": run_id,
1284
+ },
1285
+ )
1286
+
1287
+ project = response.get("project", None)
1288
+ if not project:
1289
+ return False
1290
+ run = project.get("run", None)
1291
+ if not run:
1292
+ return False
1293
+
1294
+ status: bool = run["stopped"]
1295
+ return status
1296
+
1297
+ def format_project(self, project: str) -> str:
1298
+ return re.sub(r"\W+", "-", project.lower()).strip("-_")
1299
+
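# format_project() slugifies a project name before it is sent to the server:
#
#   api.format_project("My Cool Project!")   # -> "my-cool-project"
#   api.format_project("  hello__world  ")   # -> "hello__world"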
1300
+ @normalize_exceptions
1301
+ def upsert_project(
1302
+ self,
1303
+ project: str,
1304
+ id: Optional[str] = None,
1305
+ description: Optional[str] = None,
1306
+ entity: Optional[str] = None,
1307
+ ) -> Dict[str, Any]:
1308
+ """Create a new project.
1309
+
1310
+ Arguments:
1311
+ project (str): The project to create
1312
+ description (str, optional): A description of this project
1313
+ entity (str, optional): The entity to scope this project to.
1314
+ """
1315
+ mutation = gql(
1316
+ """
1317
+ mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) {
1318
+ upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
1319
+ model {
1320
+ name
1321
+ description
1322
+ }
1323
+ }
1324
+ }
1325
+ """
1326
+ )
1327
+ response = self.gql(
1328
+ mutation,
1329
+ variable_values={
1330
+ "name": self.format_project(project),
1331
+ "entity": entity or self.settings("entity"),
1332
+ "description": description,
1333
+ "id": id,
1334
+ },
1335
+ )
1336
+ # TODO(jhr): Commenting out 'repo' field for cling, add back
1337
+ # 'description': description, 'repo': self.git.remote_url, 'id': id})
1338
+ result: Dict[str, Any] = response["upsertModel"]["model"]
1339
+ return result
1340
+
1341
+ @normalize_exceptions
1342
+ def entity_is_team(self, entity: str) -> bool:
1343
+ query = gql(
1344
+ """
1345
+ query EntityIsTeam($entity: String!) {
1346
+ entity(name: $entity) {
1347
+ id
1348
+ isTeam
1349
+ }
1350
+ }
1351
+ """
1352
+ )
1353
+ variable_values = {
1354
+ "entity": entity,
1355
+ }
1356
+
1357
+ res = self.gql(query, variable_values)
1358
+ if res.get("entity") is None:
1359
+ raise Exception(
1360
+ f"Error fetching entity {entity} "
1361
+ "check that you have access to this entity"
1362
+ )
1363
+
1364
+ is_team: bool = res["entity"]["isTeam"]
1365
+ return is_team
1366
+
1367
+ @normalize_exceptions
1368
+ def get_project_run_queues(self, entity: str, project: str) -> List[Dict[str, str]]:
1369
+ query = gql(
1370
+ """
1371
+ query ProjectRunQueues($entity: String!, $projectName: String!){
1372
+ project(entityName: $entity, name: $projectName) {
1373
+ runQueues {
1374
+ id
1375
+ name
1376
+ createdBy
1377
+ access
1378
+ }
1379
+ }
1380
+ }
1381
+ """
1382
+ )
1383
+ variable_values = {
1384
+ "projectName": project,
1385
+ "entity": entity,
1386
+ }
1387
+
1388
+ res = self.gql(query, variable_values)
1389
+ if res.get("project") is None:
1390
+ # circular dependency: (LAUNCH_DEFAULT_PROJECT = model-registry)
1391
+ if project == "model-registry":
1392
+ msg = (
1393
+ f"Error fetching run queues for {entity} "
1394
+ "check that you have access to this entity and project"
1395
+ )
1396
+ else:
1397
+ msg = (
1398
+ f"Error fetching run queues for {entity}/{project} "
1399
+ "check that you have access to this entity and project"
1400
+ )
1401
+
1402
+ raise Exception(msg)
1403
+
1404
+ project_run_queues: List[Dict[str, str]] = res["project"]["runQueues"]
1405
+ return project_run_queues
1406
+
1407
+ @normalize_exceptions
1408
+ def create_default_resource_config(
1409
+ self,
1410
+ entity: str,
1411
+ resource: str,
1412
+ config: str,
1413
+ template_variables: Optional[Dict[str, Union[float, int, str]]],
1414
+ ) -> Optional[Dict[str, Any]]:
1415
+ if not self.create_default_resource_config_introspection():
1416
+ raise Exception()
1417
+ supports_template_vars, _ = self.push_to_run_queue_introspection()
1418
+
1419
+ mutation_params = """
1420
+ $entityName: String!,
1421
+ $resource: String!,
1422
+ $config: JSONString!
1423
+ """
1424
+ mutation_inputs = """
1425
+ entityName: $entityName,
1426
+ resource: $resource,
1427
+ config: $config
1428
+ """
1429
+
1430
+ if supports_template_vars:
1431
+ mutation_params += ", $templateVariables: JSONString"
1432
+ mutation_inputs += ", templateVariables: $templateVariables"
1433
+ else:
1434
+ if template_variables is not None:
1435
+ raise UnsupportedError(
1436
+ "server does not support template variables, please update server instance to >=0.46"
1437
+ )
1438
+
1439
+ variable_values = {
1440
+ "entityName": entity,
1441
+ "resource": resource,
1442
+ "config": config,
1443
+ }
1444
+ if supports_template_vars:
1445
+ if template_variables is not None:
1446
+ variable_values["templateVariables"] = json.dumps(template_variables)
1447
+ else:
1448
+ variable_values["templateVariables"] = "{}"
1449
+
1450
+ query = gql(
1451
+ f"""
1452
+ mutation createDefaultResourceConfig(
1453
+ {mutation_params}
1454
+ ) {{
1455
+ createDefaultResourceConfig(
1456
+ input: {{
1457
+ {mutation_inputs}
1458
+ }}
1459
+ ) {{
1460
+ defaultResourceConfigID
1461
+ success
1462
+ }}
1463
+ }}
1464
+ """
1465
+ )
1466
+
1467
+ result: Optional[Dict[str, Any]] = self.gql(query, variable_values)[
1468
+ "createDefaultResourceConfig"
1469
+ ]
1470
+ return result
1471
+
1472
+ @normalize_exceptions
1473
+ def create_run_queue(
1474
+ self,
1475
+ entity: str,
1476
+ project: str,
1477
+ queue_name: str,
1478
+ access: str,
1479
+ prioritization_mode: Optional[str] = None,
1480
+ config_id: Optional[str] = None,
1481
+ ) -> Optional[Dict[str, Any]]:
1482
+ (
1483
+ create_run_queue,
1484
+ supports_drc,
1485
+ supports_prioritization,
1486
+ ) = self.create_run_queue_introspection()
1487
+ if not create_run_queue:
1488
+ raise UnsupportedError(
1489
+ "run queue creation is not supported by this version of "
1490
+ "wandb server. Consider updating to the latest version."
1491
+ )
1492
+ if not supports_drc and config_id is not None:
1493
+ raise UnsupportedError(
1494
+ "default resource configurations are not supported by this version "
1495
+ "of wandb server. Consider updating to the latest version."
1496
+ )
1497
+ if not supports_prioritization and prioritization_mode is not None:
1498
+ raise UnsupportedError(
1499
+ "launch prioritization is not supported by this version of "
1500
+ "wandb server. Consider updating to the latest version."
1501
+ )
1502
+
1503
+ if supports_prioritization:
1504
+ query = gql(
1505
+ """
1506
+ mutation createRunQueue(
1507
+ $entity: String!,
1508
+ $project: String!,
1509
+ $queueName: String!,
1510
+ $access: RunQueueAccessType!,
1511
+ $prioritizationMode: RunQueuePrioritizationMode,
1512
+ $defaultResourceConfigID: ID,
1513
+ ) {
1514
+ createRunQueue(
1515
+ input: {
1516
+ entityName: $entity,
1517
+ projectName: $project,
1518
+ queueName: $queueName,
1519
+ access: $access,
1520
+ prioritizationMode: $prioritizationMode
1521
+ defaultResourceConfigID: $defaultResourceConfigID
1522
+ }
1523
+ ) {
1524
+ success
1525
+ queueID
1526
+ }
1527
+ }
1528
+ """
1529
+ )
1530
+ variable_values = {
1531
+ "entity": entity,
1532
+ "project": project,
1533
+ "queueName": queue_name,
1534
+ "access": access,
1535
+ "prioritizationMode": prioritization_mode,
1536
+ "defaultResourceConfigID": config_id,
1537
+ }
1538
+ else:
1539
+ query = gql(
1540
+ """
1541
+ mutation createRunQueue(
1542
+ $entity: String!,
1543
+ $project: String!,
1544
+ $queueName: String!,
1545
+ $access: RunQueueAccessType!,
1546
+ $defaultResourceConfigID: ID,
1547
+ ) {
1548
+ createRunQueue(
1549
+ input: {
1550
+ entityName: $entity,
1551
+ projectName: $project,
1552
+ queueName: $queueName,
1553
+ access: $access,
1554
+ defaultResourceConfigID: $defaultResourceConfigID
1555
+ }
1556
+ ) {
1557
+ success
1558
+ queueID
1559
+ }
1560
+ }
1561
+ """
1562
+ )
1563
+ variable_values = {
1564
+ "entity": entity,
1565
+ "project": project,
1566
+ "queueName": queue_name,
1567
+ "access": access,
1568
+ "defaultResourceConfigID": config_id,
1569
+ }
1570
+
1571
+ result: Optional[Dict[str, Any]] = self.gql(query, variable_values)[
1572
+ "createRunQueue"
1573
+ ]
1574
+ return result
1575
+
1576
+ @normalize_exceptions
1577
+ def push_to_run_queue_by_name(
1578
+ self,
1579
+ entity: str,
1580
+ project: str,
1581
+ queue_name: str,
1582
+ run_spec: str,
1583
+ template_variables: Optional[Dict[str, Union[int, float, str]]],
1584
+ priority: Optional[int] = None,
1585
+ ) -> Optional[Dict[str, Any]]:
1586
+ """Queryless mutation; should be used before the legacy fallback method."""
+ self.push_to_run_queue_introspection()
1588
+
1589
+ mutation_params = """
1590
+ $entityName: String!,
1591
+ $projectName: String!,
1592
+ $queueName: String!,
1593
+ $runSpec: JSONString!
1594
+ """
1595
+
1596
+ mutation_input = """
1597
+ entityName: $entityName,
1598
+ projectName: $projectName,
1599
+ queueName: $queueName,
1600
+ runSpec: $runSpec
1601
+ """
1602
+
1603
+ variables: Dict[str, Any] = {
1604
+ "entityName": entity,
1605
+ "projectName": project,
1606
+ "queueName": queue_name,
1607
+ "runSpec": run_spec,
1608
+ }
1609
+ if self.server_push_to_run_queue_supports_priority:
1610
+ if priority is not None:
1611
+ variables["priority"] = priority
1612
+ mutation_params += ", $priority: Int"
1613
+ mutation_input += ", priority: $priority"
1614
+ else:
1615
+ if priority is not None:
1616
+ raise UnsupportedError(
1617
+ "server does not support priority, please update server instance to >=0.46"
1618
+ )
1619
+
1620
+ if self.server_supports_template_variables:
1621
+ if template_variables is not None:
1622
+ variables.update(
1623
+ {"templateVariableValues": json.dumps(template_variables)}
1624
+ )
1625
+ mutation_params += ", $templateVariableValues: JSONString"
1626
+ mutation_input += ", templateVariableValues: $templateVariableValues"
1627
+ else:
1628
+ if template_variables is not None:
1629
+ raise UnsupportedError(
1630
+ "server does not support template variables, please update server instance to >=0.46"
1631
+ )
1632
+
1633
+ mutation = gql(
1634
+ f"""
1635
+ mutation pushToRunQueueByName(
1636
+ {mutation_params}
1637
+ ) {{
1638
+ pushToRunQueueByName(
1639
+ input: {{
1640
+ {mutation_input}
1641
+ }}
1642
+ ) {{
1643
+ runQueueItemId
1644
+ runSpec
1645
+ }}
1646
+ }}
1647
+ """
1648
+ )
1649
+
1650
+ try:
1651
+ result: Optional[Dict[str, Any]] = self.gql(
1652
+ mutation, variables, check_retry_fn=util.no_retry_4xx
1653
+ ).get("pushToRunQueueByName")
1654
+ if not result:
1655
+ return None
1656
+
1657
+ if result.get("runSpec"):
1658
+ run_spec = json.loads(str(result["runSpec"]))
1659
+ result["runSpec"] = run_spec
1660
+
1661
+ return result
1662
+ except Exception as e:
1663
+ if (
1664
+ 'Cannot query field "runSpec" on type "PushToRunQueueByNamePayload"'
1665
+ not in str(e)
1666
+ ):
1667
+ return None
1668
+
1669
+ mutation_no_runspec = gql(
1670
+ """
1671
+ mutation pushToRunQueueByName(
1672
+ $entityName: String!,
1673
+ $projectName: String!,
1674
+ $queueName: String!,
1675
+ $runSpec: JSONString!,
1676
+ ) {
1677
+ pushToRunQueueByName(
1678
+ input: {
1679
+ entityName: $entityName,
1680
+ projectName: $projectName,
1681
+ queueName: $queueName,
1682
+ runSpec: $runSpec
1683
+ }
1684
+ ) {
1685
+ runQueueItemId
1686
+ }
1687
+ }
1688
+ """
1689
+ )
1690
+
1691
+ try:
1692
+ result = self.gql(
1693
+ mutation_no_runspec, variables, check_retry_fn=util.no_retry_4xx
1694
+ ).get("pushToRunQueueByName")
1695
+ except Exception:
1696
+ result = None
1697
+
1698
+ return result
1699
+
1700
+ @normalize_exceptions
1701
+ def push_to_run_queue(
1702
+ self,
1703
+ queue_name: str,
1704
+ launch_spec: Dict[str, str],
1705
+ template_variables: Optional[dict],
1706
+ project_queue: str,
1707
+ priority: Optional[int] = None,
1708
+ ) -> Optional[Dict[str, Any]]:
1709
+ self.push_to_run_queue_introspection()
1710
+ entity = launch_spec.get("queue_entity") or launch_spec["entity"]
1711
+ run_spec = json.dumps(launch_spec)
1712
+
1713
+ push_result = self.push_to_run_queue_by_name(
1714
+ entity, project_queue, queue_name, run_spec, template_variables, priority
1715
+ )
1716
+
1717
+ if push_result:
1718
+ return push_result
1719
+
1720
+ if priority is not None:
1721
+ # Cannot proceed with legacy method if priority is set
1722
+ return None
1723
+
1724
+ """ Legacy Method """
1725
+ queues_found = self.get_project_run_queues(entity, project_queue)
1726
+ matching_queues = [
1727
+ q
1728
+ for q in queues_found
1729
+ if q["name"] == queue_name
1730
+ # ensure user has access to queue
1731
+ and (
1732
+ # TODO: User created queues in the UI have USER access
1733
+ q["access"] in ["PROJECT", "USER"]
1734
+ or q["createdBy"] == self.default_entity
1735
+ )
1736
+ ]
1737
+ if not matching_queues:
1738
+ # in the case of a missing default queue. create it
1739
+ if queue_name == "default":
1740
+ wandb.termlog(
1741
+ f"No default queue existing for entity: {entity} in project: {project_queue}, creating one."
1742
+ )
1743
+ res = self.create_run_queue(
1744
+ launch_spec["entity"],
1745
+ project_queue,
1746
+ queue_name,
1747
+ access="PROJECT",
1748
+ )
1749
+
1750
+ if res is None or res.get("queueID") is None:
1751
+ wandb.termerror(
1752
+ f"Unable to create default queue for entity: {entity} on project: {project_queue}. Run could not be added to a queue"
1753
+ )
1754
+ return None
1755
+ queue_id = res["queueID"]
1756
+
1757
+ else:
1758
+ if project_queue == "model-registry":
1759
+ _msg = f"Unable to push to run queue {queue_name}. Queue not found."
1760
+ else:
1761
+ _msg = f"Unable to push to run queue {project_queue}/{queue_name}. Queue not found."
1762
+ wandb.termwarn(_msg)
1763
+ return None
1764
+ elif len(matching_queues) > 1:
1765
+ wandb.termerror(
1766
+ f"Unable to push to run queue {queue_name}. More than one queue found with this name."
1767
+ )
1768
+ return None
1769
+ else:
1770
+ queue_id = matching_queues[0]["id"]
1771
+ spec_json = json.dumps(launch_spec)
1772
+ variables = {"queueID": queue_id, "runSpec": spec_json}
1773
+
1774
+ mutation_params = """
1775
+ $queueID: ID!,
1776
+ $runSpec: JSONString!
1777
+ """
1778
+ mutation_input = """
1779
+ queueID: $queueID,
1780
+ runSpec: $runSpec
1781
+ """
1782
+ if self.server_supports_template_variables:
1783
+ if template_variables is not None:
1784
+ mutation_params += ", $templateVariableValues: JSONString"
1785
+ mutation_input += ", templateVariableValues: $templateVariableValues"
1786
+ variables.update(
1787
+ {"templateVariableValues": json.dumps(template_variables)}
1788
+ )
1789
+ else:
1790
+ if template_variables is not None:
1791
+ raise UnsupportedError(
1792
+ "server does not support template variables, please update server instance to >=0.46"
1793
+ )
1794
+
1795
+ mutation = gql(
1796
+ f"""
1797
+ mutation pushToRunQueue(
1798
+ {mutation_params}
1799
+ ) {{
1800
+ pushToRunQueue(
1801
+ input: {{{mutation_input}}}
1802
+ ) {{
1803
+ runQueueItemId
1804
+ }}
1805
+ }}
1806
+ """
1807
+ )
1808
+
1809
+ response = self.gql(mutation, variable_values=variables)
1810
+ if not response.get("pushToRunQueue"):
1811
+ raise CommError(f"Error pushing run queue item to queue {queue_name}.")
1812
+
1813
+ result: Optional[Dict[str, Any]] = response["pushToRunQueue"]
1814
+ return result
1815
+
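# Hypothetical call to push_to_run_queue() above; the launch spec here is a
# minimal stand-in, real specs carry more fields:
#
#   result = api.push_to_run_queue(
#       queue_name="default",
#       launch_spec={"entity": "my-team", "project": "my-project"},
#       template_variables=None,
#       project_queue="my-project",
#   )
#   item_id = result and result.get("runQueueItemId")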
1816
+ @normalize_exceptions
1817
+ def pop_from_run_queue(
1818
+ self,
1819
+ queue_name: str,
1820
+ entity: Optional[str] = None,
1821
+ project: Optional[str] = None,
1822
+ agent_id: Optional[str] = None,
1823
+ ) -> Optional[Dict[str, Any]]:
1824
+ mutation = gql(
1825
+ """
1826
+ mutation popFromRunQueue($entity: String!, $project: String!, $queueName: String!, $launchAgentId: ID) {
1827
+ popFromRunQueue(input: {
1828
+ entityName: $entity,
1829
+ projectName: $project,
1830
+ queueName: $queueName,
1831
+ launchAgentId: $launchAgentId
1832
+ }) {
1833
+ runQueueItemId
1834
+ runSpec
1835
+ }
1836
+ }
1837
+ """
1838
+ )
1839
+ response = self.gql(
1840
+ mutation,
1841
+ variable_values={
1842
+ "entity": entity,
1843
+ "project": project,
1844
+ "queueName": queue_name,
1845
+ "launchAgentId": agent_id,
1846
+ },
1847
+ )
1848
+ result: Optional[Dict[str, Any]] = response["popFromRunQueue"]
1849
+ return result
1850
+
1851
+ @normalize_exceptions
1852
+ def ack_run_queue_item(self, item_id: str, run_id: Optional[str] = None) -> bool:
1853
+ mutation = gql(
1854
+ """
1855
+ mutation ackRunQueueItem($itemId: ID!, $runId: String!) {
1856
+ ackRunQueueItem(input: { runQueueItemId: $itemId, runName: $runId }) {
1857
+ success
1858
+ }
1859
+ }
1860
+ """
1861
+ )
1862
+ response = self.gql(
1863
+ mutation, variable_values={"itemId": item_id, "runId": str(run_id)}
1864
+ )
1865
+ if not response["ackRunQueueItem"]["success"]:
1866
+ raise CommError(
1867
+ "Error acking run queue item. Item may have already been acknowledged by another process"
1868
+ )
1869
+ result: bool = response["ackRunQueueItem"]["success"]
1870
+ return result
1871
+
1872
+ @normalize_exceptions
1873
+ def create_launch_agent_fields_introspection(self) -> List:
1874
+ if self.create_launch_agent_input_info:
1875
+ return self.create_launch_agent_input_info
1876
+ query_string = """
1877
+ query ProbeServerCreateLaunchAgentInput {
1878
+ CreateLaunchAgentInputInfoType: __type(name:"CreateLaunchAgentInput") {
1879
+ inputFields{
1880
+ name
1881
+ }
1882
+ }
1883
+ }
1884
+ """
1885
+
1886
+ query = gql(query_string)
1887
+ res = self.gql(query)
1888
+
1889
+ self.create_launch_agent_input_info = [
1890
+ field.get("name", "")
1891
+ for field in res.get("CreateLaunchAgentInputInfoType", {}).get(
1892
+ "inputFields", [{}]
1893
+ )
1894
+ ]
1895
+ return self.create_launch_agent_input_info
1896
+
1897
+ @normalize_exceptions
1898
+ def create_launch_agent(
1899
+ self,
1900
+ entity: str,
1901
+ project: str,
1902
+ queues: List[str],
1903
+ agent_config: Dict[str, Any],
1904
+ version: str,
1905
+ gorilla_agent_support: bool,
1906
+ ) -> dict:
1907
+ project_queues = self.get_project_run_queues(entity, project)
1908
+ if not project_queues:
1909
+ # create default queue if it doesn't already exist
1910
+ default = self.create_run_queue(
1911
+ entity, project, "default", access="PROJECT"
1912
+ )
1913
+ if default is None or default.get("queueID") is None:
1914
+ raise CommError(
1915
+ "Unable to create default queue for {}/{}. No queues for agent to poll".format(
1916
+ entity, project
1917
+ )
1918
+ )
1919
+ project_queues = [{"id": default["queueID"], "name": "default"}]
1920
+ polling_queue_ids = [
1921
+ q["id"] for q in project_queues if q["name"] in queues
1922
+ ] # filter to poll specified queues
1923
+ if len(polling_queue_ids) != len(queues):
1924
+ raise CommError(
1925
+ f"Could not start launch agent: Not all of requested queues ({', '.join(queues)}) found. "
1926
+ f"Available queues for this project: {','.join([q['name'] for q in project_queues])}"
1927
+ )
1928
+
1929
+ if not gorilla_agent_support:
1930
+ # if gorilla doesn't support launch agents, return a client-generated id
1931
+ return {
1932
+ "success": True,
1933
+ "launchAgentId": None,
1934
+ }
1935
+
1936
+ hostname = socket.gethostname()
1937
+
1938
+ variable_values = {
1939
+ "entity": entity,
1940
+ "project": project,
1941
+ "queues": polling_queue_ids,
1942
+ "hostname": hostname,
1943
+ }
1944
+
1945
+ mutation_params = """
1946
+ $entity: String!,
1947
+ $project: String!,
1948
+ $queues: [ID!]!,
1949
+ $hostname: String!
1950
+ """
1951
+
1952
+ mutation_input = """
1953
+ entityName: $entity,
1954
+ projectName: $project,
1955
+ runQueues: $queues,
1956
+ hostname: $hostname
1957
+ """
1958
+
1959
+ if "agentConfig" in self.create_launch_agent_fields_introspection():
1960
+ variable_values["agentConfig"] = json.dumps(agent_config)
1961
+ mutation_params += ", $agentConfig: JSONString"
1962
+ mutation_input += ", agentConfig: $agentConfig"
1963
+ if "version" in self.create_launch_agent_fields_introspection():
1964
+ variable_values["version"] = version
1965
+ mutation_params += ", $version: String"
1966
+ mutation_input += ", version: $version"
1967
+
1968
+ mutation = gql(
1969
+ f"""
1970
+ mutation createLaunchAgent(
1971
+ {mutation_params}
1972
+ ) {{
1973
+ createLaunchAgent(
1974
+ input: {{
1975
+ {mutation_input}
1976
+ }}
1977
+ ) {{
1978
+ launchAgentId
1979
+ }}
1980
+ }}
1981
+ """
1982
+ )
1983
+ result: dict = self.gql(mutation, variable_values)["createLaunchAgent"]
1984
+ return result
1985
+
1986
+ @normalize_exceptions
1987
+ def update_launch_agent_status(
1988
+ self,
1989
+ agent_id: str,
1990
+ status: str,
1991
+ gorilla_agent_support: bool,
1992
+ ) -> dict:
1993
+ if not gorilla_agent_support:
1994
+ # if gorilla doesn't support launch agents, this is a no-op
1995
+ return {
1996
+ "success": True,
1997
+ }
1998
+
1999
+ mutation = gql(
2000
+ """
2001
+ mutation updateLaunchAgent($agentId: ID!, $agentStatus: String){
2002
+ updateLaunchAgent(
2003
+ input: {
2004
+ launchAgentId: $agentId
2005
+ agentStatus: $agentStatus
2006
+ }
2007
+ ) {
2008
+ success
2009
+ }
2010
+ }
2011
+ """
2012
+ )
2013
+ variable_values = {
2014
+ "agentId": agent_id,
2015
+ "agentStatus": status,
2016
+ }
2017
+ result: dict = self.gql(mutation, variable_values)["updateLaunchAgent"]
2018
+ return result
2019
+
2020
+ @normalize_exceptions
2021
+ def get_launch_agent(self, agent_id: str, gorilla_agent_support: bool) -> dict:
2022
+ if not gorilla_agent_support:
2023
+ return {
2024
+ "id": None,
2025
+ "name": "",
2026
+ "stopPolling": False,
2027
+ }
2028
+ query = gql(
2029
+ """
2030
+ query LaunchAgent($agentId: ID!) {
2031
+ launchAgent(id: $agentId) {
2032
+ id
2033
+ name
2034
+ runQueues
2035
+ hostname
2036
+ agentStatus
2037
+ stopPolling
2038
+ heartbeatAt
2039
+ }
2040
+ }
2041
+ """
2042
+ )
2043
+ variable_values = {
2044
+ "agentId": agent_id,
2045
+ }
2046
+ result: dict = self.gql(query, variable_values)["launchAgent"]
2047
+ return result
2048
+
2049
+ @normalize_exceptions
2050
+ def upsert_run(
2051
+ self,
2052
+ id: Optional[str] = None,
2053
+ name: Optional[str] = None,
2054
+ project: Optional[str] = None,
2055
+ host: Optional[str] = None,
2056
+ group: Optional[str] = None,
2057
+ tags: Optional[List[str]] = None,
2058
+ config: Optional[dict] = None,
2059
+ description: Optional[str] = None,
2060
+ entity: Optional[str] = None,
2061
+ state: Optional[str] = None,
2062
+ display_name: Optional[str] = None,
2063
+ notes: Optional[str] = None,
2064
+ repo: Optional[str] = None,
2065
+ job_type: Optional[str] = None,
2066
+ program_path: Optional[str] = None,
2067
+ commit: Optional[str] = None,
2068
+ sweep_name: Optional[str] = None,
2069
+ summary_metrics: Optional[str] = None,
2070
+ num_retries: Optional[int] = None,
2071
+ ) -> Tuple[dict, bool, Optional[List]]:
2072
+ """Update a run.
2073
+
2074
+ Arguments:
2075
+ id (str, optional): The existing run to update
2076
+ name (str, optional): The name of the run to create
2077
+ group (str, optional): Name of the group this run is a part of
2078
+ project (str, optional): The name of the project
2079
+ host (str, optional): The name of the host
2080
+ tags (list, optional): A list of tags to apply to the run
2081
+ config (dict, optional): The latest config params
2082
+ description (str, optional): A description of this run
+ entity (str, optional): The entity to scope this project to.
+ display_name (str, optional): The display name of this run
2085
+ notes (str, optional): Notes about this run
2086
+ repo (str, optional): Url of the program's repository.
2087
+ state (str, optional): State of the program.
2088
+ job_type (str, optional): Type of job, e.g. 'train'.
2089
+ program_path (str, optional): Path to the program.
2090
+ commit (str, optional): The Git SHA to associate the run with
2091
+ sweep_name (str, optional): The name of the sweep this run is a part of
2092
+ summary_metrics (str, optional): The JSON summary metrics
2093
+ num_retries (int, optional): Number of retries
2094
+ """
2095
+ query_string = """
2096
+ mutation UpsertBucket(
2097
+ $id: String,
2098
+ $name: String,
2099
+ $project: String,
2100
+ $entity: String,
2101
+ $groupName: String,
2102
+ $description: String,
2103
+ $displayName: String,
2104
+ $notes: String,
2105
+ $commit: String,
2106
+ $config: JSONString,
2107
+ $host: String,
2108
+ $debug: Boolean,
2109
+ $program: String,
2110
+ $repo: String,
2111
+ $jobType: String,
2112
+ $state: String,
2113
+ $sweep: String,
2114
+ $tags: [String!],
2115
+ $summaryMetrics: JSONString,
2116
+ ) {
2117
+ upsertBucket(input: {
2118
+ id: $id,
2119
+ name: $name,
2120
+ groupName: $groupName,
2121
+ modelName: $project,
2122
+ entityName: $entity,
2123
+ description: $description,
2124
+ displayName: $displayName,
2125
+ notes: $notes,
2126
+ config: $config,
2127
+ commit: $commit,
2128
+ host: $host,
2129
+ debug: $debug,
2130
+ jobProgram: $program,
2131
+ jobRepo: $repo,
2132
+ jobType: $jobType,
2133
+ state: $state,
2134
+ sweep: $sweep,
2135
+ tags: $tags,
2136
+ summaryMetrics: $summaryMetrics,
2137
+ }) {
2138
+ bucket {
2139
+ id
2140
+ name
2141
+ displayName
2142
+ description
2143
+ config
2144
+ sweepName
2145
+ project {
2146
+ id
2147
+ name
2148
+ entity {
2149
+ id
2150
+ name
2151
+ }
2152
+ }
2153
+ historyLineCount
2154
+ }
2155
+ inserted
2156
+ _Server_Settings_
2157
+ }
2158
+ }
2159
+ """
2160
+ self.server_settings_introspection()
2161
+
2162
+ server_settings_string = (
2163
+ """
2164
+ serverSettings {
2165
+ serverMessages{
2166
+ utfText
2167
+ plainText
2168
+ htmlText
2169
+ messageType
2170
+ messageLevel
2171
+ }
2172
+ }
2173
+ """
2174
+ if self._server_settings_type
2175
+ else ""
2176
+ )
2177
+
2178
+ query_string = query_string.replace("_Server_Settings_", server_settings_string)
2179
+ mutation = gql(query_string)
2180
+ config_str = json.dumps(config) if config else None
2181
+ if not description or description.isspace():
2182
+ description = None
2183
+
2184
+ kwargs = {}
2185
+ if num_retries is not None:
2186
+ kwargs["num_retries"] = num_retries
2187
+
2188
+ variable_values = {
2189
+ "id": id,
2190
+ "entity": entity or self.settings("entity"),
2191
+ "name": name,
2192
+ "project": project or util.auto_project_name(program_path),
2193
+ "groupName": group,
2194
+ "tags": tags,
2195
+ "description": description,
2196
+ "config": config_str,
2197
+ "commit": commit,
2198
+ "displayName": display_name,
2199
+ "notes": notes,
2200
+ "host": None if self.settings().get("anonymous") == "true" else host,
2201
+ "debug": env.is_debug(env=self._environ),
2202
+ "repo": repo,
2203
+ "program": program_path,
2204
+ "jobType": job_type,
2205
+ "state": state,
2206
+ "sweep": sweep_name,
2207
+ "summaryMetrics": summary_metrics,
2208
+ }
2209
+
2210
+ # retry conflict errors for 2 minutes, default to no_auth_retry
2211
+ check_retry_fn = util.make_check_retry_fn(
2212
+ check_fn=util.check_retry_conflict_or_gone,
2213
+ check_timedelta=datetime.timedelta(minutes=2),
2214
+ fallback_retry_fn=util.no_retry_auth,
2215
+ )
2216
+
2217
+ response = self.gql(
2218
+ mutation,
2219
+ variable_values=variable_values,
2220
+ check_retry_fn=check_retry_fn,
2221
+ **kwargs,
2222
+ )
2223
+
2224
+ run_obj: Dict[str, Dict[str, Dict[str, str]]] = response["upsertBucket"][
2225
+ "bucket"
2226
+ ]
2227
+ project_obj: Dict[str, Dict[str, str]] = run_obj.get("project", {})
2228
+ if project_obj:
2229
+ self.set_setting("project", project_obj["name"])
2230
+ entity_obj = project_obj.get("entity", {})
2231
+ if entity_obj:
2232
+ self.set_setting("entity", entity_obj["name"])
2233
+
2234
+ server_messages = None
2235
+ if self._server_settings_type:
2236
+ server_messages = (
2237
+ response["upsertBucket"]
2238
+ .get("serverSettings", {})
2239
+ .get("serverMessages", [])
2240
+ )
2241
+
2242
+ return (
2243
+ response["upsertBucket"]["bucket"],
2244
+ response["upsertBucket"]["inserted"],
2245
+ server_messages,
2246
+ )
2247
+
2248
+ @normalize_exceptions
2249
+ def get_run_info(
2250
+ self,
2251
+ entity: str,
2252
+ project: str,
2253
+ name: str,
2254
+ ) -> dict:
2255
+ query = gql(
2256
+ """
2257
+ query RunInfo($project: String!, $entity: String!, $name: String!) {
2258
+ project(name: $project, entityName: $entity) {
2259
+ run(name: $name) {
2260
+ runInfo {
2261
+ program
2262
+ args
2263
+ os
2264
+ python
2265
+ colab
2266
+ executable
2267
+ codeSaved
2268
+ cpuCount
2269
+ gpuCount
2270
+ gpu
2271
+ git {
2272
+ remote
2273
+ commit
2274
+ }
2275
+ }
2276
+ }
2277
+ }
2278
+ }
2279
+ """
2280
+ )
2281
+ variable_values = {"project": project, "entity": entity, "name": name}
2282
+ res = self.gql(query, variable_values)
2283
+ if res.get("project") is None:
2284
+ raise CommError(
2285
+ "Error fetching run info for {}/{}/{}. Check that this project exists and you have access to this entity and project".format(
2286
+ entity, project, name
2287
+ )
2288
+ )
2289
+ elif res["project"].get("run") is None:
2290
+ raise CommError(
2291
+ "Error fetching run info for {}/{}/{}. Check that this run id exists".format(
2292
+ entity, project, name
2293
+ )
2294
+ )
2295
+ run_info: dict = res["project"]["run"]["runInfo"]
2296
+ return run_info
2297
+
2298
+ @normalize_exceptions
2299
+ def get_run_state(self, entity: str, project: str, name: str) -> str:
2300
+ query = gql(
2301
+ """
2302
+ query RunState(
2303
+ $project: String!,
2304
+ $entity: String!,
2305
+ $name: String!) {
2306
+ project(name: $project, entityName: $entity) {
2307
+ run(name: $name) {
2308
+ state
2309
+ }
2310
+ }
2311
+ }
2312
+ """
2313
+ )
2314
+ variable_values = {
2315
+ "project": project,
2316
+ "entity": entity,
2317
+ "name": name,
2318
+ }
2319
+ res = self.gql(query, variable_values)
2320
+ if res.get("project") is None or res["project"].get("run") is None:
2321
+ raise CommError(f"Error fetching run state for {entity}/{project}/{name}.")
2322
+ run_state: str = res["project"]["run"]["state"]
2323
+ return run_state
2324
+
2325
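
A minimal usage sketch (not part of the diff) for the two run-query helpers above. The import path, class name, entity, project, and run id are assumptions for illustration; adjust them to your install and workspace.

```python
from wandb.sdk.internal.internal_api import Api  # assumed module path

api = Api()  # assumed to pick up credentials/settings from the environment
info = api.get_run_info(entity="my-team", project="my-project", name="abc123")
print(info.get("python"), info.get("os"))

state = api.get_run_state(entity="my-team", project="my-project", name="abc123")
print(state)  # e.g. "running" or "finished"
```
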
+ @normalize_exceptions
2326
+ def create_run_files_introspection(self) -> bool:
2327
+ _, _, mutations = self.server_info_introspection()
2328
+ return "createRunFiles" in mutations
2329
+
2330
+ @normalize_exceptions
2331
+ def upload_urls(
2332
+ self,
2333
+ project: str,
2334
+ files: Union[List[str], Dict[str, IO]],
2335
+ run: Optional[str] = None,
2336
+ entity: Optional[str] = None,
2337
+ description: Optional[str] = None,
2338
+ ) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
2339
+ """Generate temporary resumable upload urls.
2340
+
2341
+ Arguments:
2342
+             project (str): The project to upload to
2343
+ files (list or dict): The filenames to upload
2344
+ run (str, optional): The run to upload to
2345
+ entity (str, optional): The entity to scope this project to.
2346
+ description (str, optional): description
2347
+
2348
+ Returns:
2349
+ (run_id, upload_headers, file_info)
2350
+ run_id: id of run we uploaded files to
2351
+ upload_headers: A list of headers to use when uploading files.
2352
+ file_info: A dict of filenames and urls.
2353
+ {
2354
+ "run_id": "run_id",
2355
+ "upload_headers": [""],
2356
+ "file_info": [
2357
+ { "weights.h5": { "uploadUrl": "https://weights.url" } },
2358
+ { "model.json": { "uploadUrl": "https://model.json" } }
2359
+ ]
2360
+ }
2361
+ """
2362
+ run_name = run or self.current_run_id
2363
+ assert run_name, "run must be specified"
2364
+ entity = entity or self.settings("entity")
2365
+ assert entity, "entity must be specified"
2366
+
2367
+ has_create_run_files_mutation = self.create_run_files_introspection()
2368
+ if not has_create_run_files_mutation:
2369
+ return self.legacy_upload_urls(project, files, run, entity, description)
2370
+
2371
+ query = gql(
2372
+ """
2373
+ mutation CreateRunFiles($entity: String!, $project: String!, $run: String!, $files: [String!]!) {
2374
+ createRunFiles(input: {entityName: $entity, projectName: $project, runName: $run, files: $files}) {
2375
+ runID
2376
+ uploadHeaders
2377
+ files {
2378
+ name
2379
+ uploadUrl
2380
+ }
2381
+ }
2382
+ }
2383
+ """
2384
+ )
2385
+
2386
+ query_result = self.gql(
2387
+ query,
2388
+ variable_values={
2389
+ "project": project,
2390
+ "run": run_name,
2391
+ "entity": entity,
2392
+ "files": [file for file in files],
2393
+ },
2394
+ )
2395
+
2396
+ result = query_result["createRunFiles"]
2397
+ run_id = result["runID"]
2398
+ if not run_id:
2399
+ raise CommError(
2400
+ f"Error uploading files to {entity}/{project}/{run_name}. Check that this project exists and you have access to this entity and project"
2401
+ )
2402
+ file_name_urls = {file["name"]: file for file in result["files"]}
2403
+ return run_id, result["uploadHeaders"], file_name_urls
2404
+
2405
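
A hedged sketch of how `upload_urls()` might be called, reusing the `api` instance from the earlier sketch; the entity, project, run id, and file names are placeholders.

```python
# Request presigned upload urls for two files attached to an existing run.
run_id, upload_headers, file_info = api.upload_urls(
    project="my-project",
    files=["weights.h5", "model.json"],
    run="abc123",
    entity="my-team",
)
print(run_id)
print(file_info["weights.h5"]["uploadUrl"])  # presigned url for one file
```
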
+ def legacy_upload_urls(
2406
+ self,
2407
+ project: str,
2408
+ files: Union[List[str], Dict[str, IO]],
2409
+ run: Optional[str] = None,
2410
+ entity: Optional[str] = None,
2411
+ description: Optional[str] = None,
2412
+ ) -> Tuple[str, List[str], Dict[str, Dict[str, Any]]]:
2413
+ """Generate temporary resumable upload urls.
2414
+
2415
+         The createRunFiles mutation was introduced after version 0.15.4.
2416
+         This function supports older server versions that lack it.
2417
+ """
2418
+ query = gql(
2419
+ """
2420
+ query RunUploadUrls($name: String!, $files: [String]!, $entity: String, $run: String!, $description: String) {
2421
+ model(name: $name, entityName: $entity) {
2422
+ bucket(name: $run, desc: $description) {
2423
+ id
2424
+ files(names: $files) {
2425
+ uploadHeaders
2426
+ edges {
2427
+ node {
2428
+ name
2429
+ url(upload: true)
2430
+ updatedAt
2431
+ }
2432
+ }
2433
+ }
2434
+ }
2435
+ }
2436
+ }
2437
+ """
2438
+ )
2439
+ run_id = run or self.current_run_id
2440
+ assert run_id, "run must be specified"
2441
+ entity = entity or self.settings("entity")
2442
+ query_result = self.gql(
2443
+ query,
2444
+ variable_values={
2445
+ "name": project,
2446
+ "run": run_id,
2447
+ "entity": entity,
2448
+ "files": [file for file in files],
2449
+ "description": description,
2450
+ },
2451
+ )
2452
+
2453
+ run_obj = query_result["model"]["bucket"]
2454
+ if run_obj:
2455
+ for file_node in run_obj["files"]["edges"]:
2456
+ file = file_node["node"]
2457
+ # we previously used "url" field but now use "uploadUrl"
2458
+                 # replace the "url" field with "uploadUrl" for downstream compatibility
2459
+ if "url" in file and "uploadUrl" not in file:
2460
+ file["uploadUrl"] = file.pop("url")
2461
+
2462
+ result = {
2463
+ file["name"]: file for file in self._flatten_edges(run_obj["files"])
2464
+ }
2465
+ return run_obj["id"], run_obj["files"]["uploadHeaders"], result
2466
+ else:
2467
+ raise CommError(f"Run does not exist {entity}/{project}/{run_id}.")
2468
+
2469
+ @normalize_exceptions
2470
+ def download_urls(
2471
+ self,
2472
+ project: str,
2473
+ run: Optional[str] = None,
2474
+ entity: Optional[str] = None,
2475
+ ) -> Dict[str, Dict[str, str]]:
2476
+ """Generate download urls.
2477
+
2478
+ Arguments:
2479
+             project (str): The project to download from
2480
+             run (str): The run to download files from
2481
+ entity (str, optional): The entity to scope this project to. Defaults to wandb models
2482
+
2483
+ Returns:
2484
+ A dict of extensions and urls
2485
+
2486
+ {
2487
+ 'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
2488
+ 'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
2489
+ }
2490
+ """
2491
+ query = gql(
2492
+ """
2493
+ query RunDownloadUrls($name: String!, $entity: String, $run: String!) {
2494
+ model(name: $name, entityName: $entity) {
2495
+ bucket(name: $run) {
2496
+ files {
2497
+ edges {
2498
+ node {
2499
+ name
2500
+ url
2501
+ md5
2502
+ updatedAt
2503
+ }
2504
+ }
2505
+ }
2506
+ }
2507
+ }
2508
+ }
2509
+ """
2510
+ )
2511
+ run = run or self.current_run_id
2512
+ assert run, "run must be specified"
2513
+ entity = entity or self.settings("entity")
2514
+ query_result = self.gql(
2515
+ query,
2516
+ variable_values={
2517
+ "name": project,
2518
+ "run": run,
2519
+ "entity": entity,
2520
+ },
2521
+ )
2522
+ if query_result["model"] is None:
2523
+ raise CommError(f"Run does not exist {entity}/{project}/{run}.")
2524
+ files = self._flatten_edges(query_result["model"]["bucket"]["files"])
2525
+ return {file["name"]: file for file in files if file}
2526
+
2527
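
A hedged sketch of listing a run's files with `download_urls()`, reusing the `api` instance from the first sketch; names below are placeholders.

```python
# Map each file name to its download metadata (url, md5, updatedAt).
urls = api.download_urls(project="my-project", run="abc123", entity="my-team")
for name, meta in urls.items():
    print(name, meta["md5"], meta["updatedAt"])
```
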
+ @normalize_exceptions
2528
+ def download_url(
2529
+ self,
2530
+ project: str,
2531
+ file_name: str,
2532
+ run: Optional[str] = None,
2533
+ entity: Optional[str] = None,
2534
+ ) -> Optional[Dict[str, str]]:
2535
+ """Generate download urls.
2536
+
2537
+ Arguments:
2538
+             project (str): The project to download from
2539
+ file_name (str): The name of the file to download
2540
+             run (str): The run to download the file from
2541
+ entity (str, optional): The entity to scope this project to. Defaults to wandb models
2542
+
2543
+ Returns:
2544
+ A dict of extensions and urls
2545
+
2546
+ { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
2547
+
2548
+ """
2549
+ query = gql(
2550
+ """
2551
+ query RunDownloadUrl($name: String!, $fileName: String!, $entity: String, $run: String!) {
2552
+ model(name: $name, entityName: $entity) {
2553
+ bucket(name: $run) {
2554
+ files(names: [$fileName]) {
2555
+ edges {
2556
+ node {
2557
+ name
2558
+ url
2559
+ md5
2560
+ updatedAt
2561
+ }
2562
+ }
2563
+ }
2564
+ }
2565
+ }
2566
+ }
2567
+ """
2568
+ )
2569
+ run = run or self.current_run_id
2570
+ assert run, "run must be specified"
2571
+ query_result = self.gql(
2572
+ query,
2573
+ variable_values={
2574
+ "name": project,
2575
+ "run": run,
2576
+ "fileName": file_name,
2577
+ "entity": entity or self.settings("entity"),
2578
+ },
2579
+ )
2580
+ if query_result["model"]:
2581
+ files = self._flatten_edges(query_result["model"]["bucket"]["files"])
2582
+ return files[0] if len(files) > 0 and files[0].get("updatedAt") else None
2583
+ else:
2584
+ return None
2585
+
2586
+ @normalize_exceptions
2587
+ def download_file(self, url: str) -> Tuple[int, requests.Response]:
2588
+ """Initiate a streaming download.
2589
+
2590
+ Arguments:
2591
+ url (str): The url to download
2592
+
2593
+ Returns:
2594
+ A tuple of the content length and the streaming response
2595
+ """
2596
+ check_httpclient_logger_handler()
2597
+ auth = None
2598
+ if _thread_local_api_settings.cookies is None:
2599
+ auth = ("user", self.api_key or "")
2600
+ response = requests.get(
2601
+ url,
2602
+ auth=auth,
2603
+ cookies=_thread_local_api_settings.cookies or {},
2604
+ headers=_thread_local_api_settings.headers or {},
2605
+ stream=True,
2606
+ )
2607
+ response.raise_for_status()
2608
+ return int(response.headers.get("content-length", 0)), response
2609
+
2610
+ @normalize_exceptions
2611
+ def download_write_file(
2612
+ self,
2613
+ metadata: Dict[str, str],
2614
+ out_dir: Optional[str] = None,
2615
+ ) -> Tuple[str, Optional[requests.Response]]:
2616
+ """Download a file from a run and write it to wandb/.
2617
+
2618
+ Arguments:
2619
+ metadata (obj): The metadata object for the file to download. Comes from Api.download_urls().
2620
+ out_dir (str, optional): The directory to write the file to. Defaults to wandb/
2621
+
2622
+ Returns:
2623
+ A tuple of the file's local path and the streaming response. The streaming response is None if the file
2624
+ already existed and was up-to-date.
2625
+ """
2626
+ filename = metadata["name"]
2627
+ path = os.path.join(out_dir or self.settings("wandb_dir"), filename)
2628
+ if self.file_current(filename, B64MD5(metadata["md5"])):
2629
+ return path, None
2630
+
2631
+ size, response = self.download_file(metadata["url"])
2632
+
2633
+ with util.fsync_open(path, "wb") as file:
2634
+ for data in response.iter_content(chunk_size=1024):
2635
+ file.write(data)
2636
+
2637
+ return path, response
2638
+
2639
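
A hedged sketch of writing one of those files to disk with `download_write_file()`. It assumes the `urls` dict from the previous sketch; the file name and output directory are placeholders.

```python
# Write a single file locally; response is None if the local copy was current.
meta = urls["weights.h5"]  # metadata dict with "name", "url", "md5"
path, response = api.download_write_file(meta, out_dir="./downloaded")
print(path, "already up to date" if response is None else "downloaded")
```
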
+ def upload_file_azure(
2640
+ self, url: str, file: Any, extra_headers: Dict[str, str]
2641
+ ) -> None:
2642
+ """Upload a file to azure."""
2643
+ from azure.core.exceptions import AzureError # type: ignore
2644
+
2645
+ # Configure the client without retries so our existing logic can handle them
2646
+ client = self._azure_blob_module.BlobClient.from_blob_url(
2647
+ url, retry_policy=self._azure_blob_module.LinearRetry(retry_total=0)
2648
+ )
2649
+ try:
2650
+ if extra_headers.get("Content-MD5") is not None:
2651
+ md5: Optional[bytes] = base64.b64decode(extra_headers["Content-MD5"])
2652
+ else:
2653
+ md5 = None
2654
+ content_settings = self._azure_blob_module.ContentSettings(
2655
+ content_md5=md5,
2656
+ content_type=extra_headers.get("Content-Type"),
2657
+ )
2658
+ client.upload_blob(
2659
+ file,
2660
+ max_concurrency=4,
2661
+ length=len(file),
2662
+ overwrite=True,
2663
+ content_settings=content_settings,
2664
+ )
2665
+ except AzureError as e:
2666
+ if hasattr(e, "response"):
2667
+ response = requests.models.Response()
2668
+ response.status_code = e.response.status_code
2669
+ response.headers = e.response.headers
2670
+ raise requests.exceptions.RequestException(e.message, response=response)
2671
+ else:
2672
+ raise requests.exceptions.ConnectionError(e.message)
2673
+
2674
+ def upload_multipart_file_chunk(
2675
+ self,
2676
+ url: str,
2677
+ upload_chunk: bytes,
2678
+ extra_headers: Optional[Dict[str, str]] = None,
2679
+ ) -> Optional[requests.Response]:
2680
+ """Upload a file chunk to S3 with failure resumption.
2681
+
2682
+ Arguments:
2683
+             url: The presigned url to upload the chunk to
2684
+             upload_chunk: The bytes of the chunk to upload
2685
+ extra_headers: A dictionary of extra headers to send with the request
2686
+
2687
+ Returns:
2688
+ The `requests` library response object
2689
+ """
2690
+ check_httpclient_logger_handler()
2691
+ try:
2692
+ if env.is_debug(env=self._environ):
2693
+ logger.debug("upload_file: %s", url)
2694
+ response = self._upload_file_session.put(
2695
+ url, data=upload_chunk, headers=extra_headers
2696
+ )
2697
+ if env.is_debug(env=self._environ):
2698
+ logger.debug("upload_file: %s complete", url)
2699
+ response.raise_for_status()
2700
+ except requests.exceptions.RequestException as e:
2701
+ logger.error(f"upload_file exception {url}: {e}")
2702
+ request_headers = e.request.headers if e.request is not None else ""
2703
+ logger.error(f"upload_file request headers: {request_headers}")
2704
+ response_content = e.response.content if e.response is not None else ""
2705
+ logger.error(f"upload_file response body: {response_content}")
2706
+ status_code = e.response.status_code if e.response is not None else 0
2707
+ # S3 reports retryable request timeouts out-of-band
2708
+ is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
2709
+ response_content
2710
+ )
2711
+ # Retry errors from cloud storage or local network issues
2712
+ if (
2713
+ status_code in (308, 408, 409, 429, 500, 502, 503, 504)
2714
+ or isinstance(
2715
+ e,
2716
+ (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
2717
+ )
2718
+ or is_aws_retryable
2719
+ ):
2720
+ _e = retry.TransientError(exc=e)
2721
+ raise _e.with_traceback(sys.exc_info()[2])
2722
+ else:
2723
+ wandb._sentry.reraise(e)
2724
+ return response
2725
+
2726
+ def upload_file(
2727
+ self,
2728
+ url: str,
2729
+ file: IO[bytes],
2730
+ callback: Optional["ProgressFn"] = None,
2731
+ extra_headers: Optional[Dict[str, str]] = None,
2732
+ ) -> Optional[requests.Response]:
2733
+ """Upload a file to W&B with failure resumption.
2734
+
2735
+ Arguments:
2736
+             url: The url to upload to
2737
+             file: An open binary file object to upload
2738
+ callback: A callback which is passed the number of
2739
+ bytes uploaded since the last time it was called, used to report progress
2740
+ extra_headers: A dictionary of extra headers to send with the request
2741
+
2742
+ Returns:
2743
+ The `requests` library response object
2744
+ """
2745
+ check_httpclient_logger_handler()
2746
+ extra_headers = extra_headers.copy() if extra_headers else {}
2747
+ response: Optional[requests.Response] = None
2748
+ progress = Progress(file, callback=callback)
2749
+ try:
2750
+ if "x-ms-blob-type" in extra_headers and self._azure_blob_module:
2751
+ self.upload_file_azure(url, progress, extra_headers)
2752
+ else:
2753
+ if "x-ms-blob-type" in extra_headers:
2754
+ wandb.termwarn(
2755
+ "Azure uploads over 256MB require the azure SDK, install with pip install wandb[azure]",
2756
+ repeat=False,
2757
+ )
2758
+ if env.is_debug(env=self._environ):
2759
+ logger.debug("upload_file: %s", url)
2760
+ response = self._upload_file_session.put(
2761
+ url, data=progress, headers=extra_headers
2762
+ )
2763
+ if env.is_debug(env=self._environ):
2764
+ logger.debug("upload_file: %s complete", url)
2765
+ response.raise_for_status()
2766
+ except requests.exceptions.RequestException as e:
2767
+ logger.error(f"upload_file exception {url}: {e}")
2768
+ request_headers = e.request.headers if e.request is not None else ""
2769
+ logger.error(f"upload_file request headers: {request_headers}")
2770
+ response_content = e.response.content if e.response is not None else ""
2771
+ logger.error(f"upload_file response body: {response_content}")
2772
+ status_code = e.response.status_code if e.response is not None else 0
2773
+ # S3 reports retryable request timeouts out-of-band
2774
+ is_aws_retryable = (
2775
+ "x-amz-meta-md5" in extra_headers
2776
+ and status_code == 400
2777
+ and "RequestTimeout" in str(response_content)
2778
+ )
2779
+             # Rewind so a retry re-uploads from the beginning (seeks the passed-in file back to 0)
2780
+ progress.rewind()
2781
+ # Retry errors from cloud storage or local network issues
2782
+ if (
2783
+ status_code in (308, 408, 409, 429, 500, 502, 503, 504)
2784
+ or isinstance(
2785
+ e,
2786
+ (requests.exceptions.Timeout, requests.exceptions.ConnectionError),
2787
+ )
2788
+ or is_aws_retryable
2789
+ ):
2790
+ _e = retry.TransientError(exc=e)
2791
+ raise _e.with_traceback(sys.exc_info()[2])
2792
+ else:
2793
+ wandb._sentry.reraise(e)
2794
+
2795
+ return response
2796
+
2797
+ async def upload_file_async(
2798
+ self,
2799
+ url: str,
2800
+ file: IO[bytes],
2801
+ callback: Optional["ProgressFn"] = None,
2802
+ extra_headers: Optional[Dict[str, str]] = None,
2803
+ ) -> None:
2804
+ """An async not-quite-equivalent version of `upload_file`.
2805
+
2806
+ Differences from `upload_file`:
2807
+ - This method doesn't implement Azure uploads. (The Azure SDK supports
2808
+ async, but it's nontrivial to use it here.) If the upload looks like
2809
+ it's destined for Azure, this method will delegate to the sync impl.
2810
+ - Consequently, this method doesn't return the response object.
2811
+ (Because it might fall back to the sync impl, it would sometimes
2812
+ return a `requests.Response` and sometimes an `httpx.Response`.)
2813
+ - This method doesn't wrap retryable errors in `TransientError`.
2814
+ It leaves that determination to the caller.
2815
+ """
2816
+ check_httpclient_logger_handler()
2817
+ must_delegate = False
2818
+
2819
+ if httpx is None:
2820
+ wandb.termwarn( # type: ignore[unreachable]
2821
+ "async file-uploads require `pip install wandb[async]`; falling back to sync implementation",
2822
+ repeat=False,
2823
+ )
2824
+ must_delegate = True
2825
+
2826
+ if extra_headers is not None and "x-ms-blob-type" in extra_headers:
2827
+ wandb.termwarn(
2828
+ "async file-uploads don't support Azure; falling back to sync implementation",
2829
+ repeat=False,
2830
+ )
2831
+ must_delegate = True
2832
+
2833
+ if must_delegate:
2834
+ await asyncio.get_event_loop().run_in_executor(
2835
+ None,
2836
+ lambda: self.upload_file_retry(
2837
+ url=url,
2838
+ file=file,
2839
+ callback=callback,
2840
+ extra_headers=extra_headers,
2841
+ ),
2842
+ )
2843
+ return
2844
+
2845
+ if self._async_httpx_client is None:
2846
+ self._async_httpx_client = httpx.AsyncClient()
2847
+
2848
+ progress = AsyncProgress(Progress(file, callback=callback))
2849
+
2850
+ try:
2851
+ response = await self._async_httpx_client.put(
2852
+ url=url,
2853
+ content=progress,
2854
+ headers={
2855
+ "Content-Length": str(len(progress)),
2856
+ **(extra_headers if extra_headers is not None else {}),
2857
+ },
2858
+ )
2859
+ response.raise_for_status()
2860
+ except Exception as e:
2861
+ progress.rewind()
2862
+ logger.error(f"upload_file_async exception {url}: {e}")
2863
+ if isinstance(e, httpx.RequestError):
2864
+ logger.error(f"upload_file_async request headers: {e.request.headers}")
2865
+ if isinstance(e, httpx.HTTPStatusError):
2866
+ logger.error(f"upload_file_async response body: {e.response.content!r}")
2867
+ raise
2868
+
2869
+ async def upload_file_retry_async(
2870
+ self,
2871
+ url: str,
2872
+ file: IO[bytes],
2873
+ callback: Optional["ProgressFn"] = None,
2874
+ extra_headers: Optional[Dict[str, str]] = None,
2875
+ num_retries: int = 100,
2876
+ ) -> None:
2877
+ backoff = retry.FilteredBackoff(
2878
+ filter=check_httpx_exc_retriable,
2879
+ wrapped=retry.ExponentialBackoff(
2880
+ initial_sleep=datetime.timedelta(seconds=1),
2881
+ max_sleep=datetime.timedelta(seconds=60),
2882
+ max_retries=num_retries,
2883
+ timeout_at=datetime.datetime.now() + datetime.timedelta(days=7),
2884
+ ),
2885
+ )
2886
+
2887
+ await retry.retry_async(
2888
+ backoff=backoff,
2889
+ fn=self.upload_file_async,
2890
+ url=url,
2891
+ file=file,
2892
+ callback=callback,
2893
+ extra_headers=extra_headers,
2894
+ )
2895
+
2896
+ @normalize_exceptions
2897
+ def register_agent(
2898
+ self,
2899
+ host: str,
2900
+ sweep_id: Optional[str] = None,
2901
+ project_name: Optional[str] = None,
2902
+ entity: Optional[str] = None,
2903
+ ) -> dict:
2904
+ """Register a new agent.
2905
+
2906
+ Arguments:
2907
+ host (str): hostname
2908
+ sweep_id (str): sweep id
2909
+             project_name (str): project that contains the sweep
2910
+             entity (str): entity that contains the sweep
2911
+ """
2912
+ mutation = gql(
2913
+ """
2914
+ mutation CreateAgent(
2915
+ $host: String!
2916
+ $projectName: String,
2917
+ $entityName: String,
2918
+ $sweep: String!
2919
+ ) {
2920
+ createAgent(input: {
2921
+ host: $host,
2922
+ projectName: $projectName,
2923
+ entityName: $entityName,
2924
+ sweep: $sweep,
2925
+ }) {
2926
+ agent {
2927
+ id
2928
+ }
2929
+ }
2930
+ }
2931
+ """
2932
+ )
2933
+ if entity is None:
2934
+ entity = self.settings("entity")
2935
+ if project_name is None:
2936
+ project_name = self.settings("project")
2937
+
2938
+ response = self.gql(
2939
+ mutation,
2940
+ variable_values={
2941
+ "host": host,
2942
+ "entityName": entity,
2943
+ "projectName": project_name,
2944
+ "sweep": sweep_id,
2945
+ },
2946
+ check_retry_fn=util.no_retry_4xx,
2947
+ )
2948
+ result: dict = response["createAgent"]["agent"]
2949
+ return result
2950
+
2951
+ def agent_heartbeat(
2952
+ self, agent_id: str, metrics: dict, run_states: dict
2953
+ ) -> List[Dict[str, Any]]:
2954
+ """Notify server about agent state, receive commands.
2955
+
2956
+ Arguments:
2957
+ agent_id (str): agent_id
2958
+ metrics (dict): system metrics
2959
+             run_states (dict): mapping of run_id to run state
2960
+ Returns:
2961
+ List of commands to execute.
2962
+ """
2963
+ mutation = gql(
2964
+ """
2965
+ mutation Heartbeat(
2966
+ $id: ID!,
2967
+ $metrics: JSONString,
2968
+ $runState: JSONString
2969
+ ) {
2970
+ agentHeartbeat(input: {
2971
+ id: $id,
2972
+ metrics: $metrics,
2973
+ runState: $runState
2974
+ }) {
2975
+ agent {
2976
+ id
2977
+ }
2978
+ commands
2979
+ }
2980
+ }
2981
+ """
2982
+ )
2983
+
2984
+ if agent_id is None:
2985
+ raise ValueError("Cannot call heartbeat with an unregistered agent.")
2986
+
2987
+ try:
2988
+ response = self.gql(
2989
+ mutation,
2990
+ variable_values={
2991
+ "id": agent_id,
2992
+ "metrics": json.dumps(metrics),
2993
+ "runState": json.dumps(run_states),
2994
+ },
2995
+ timeout=60,
2996
+ )
2997
+ except Exception as e:
2998
+ # GQL raises exceptions with stringified python dictionaries :/
2999
+ message = ast.literal_eval(e.args[0])["message"]
3000
+ logger.error("Error communicating with W&B: %s", message)
3001
+ return []
3002
+ else:
3003
+ result: List[Dict[str, Any]] = json.loads(
3004
+ response["agentHeartbeat"]["commands"]
3005
+ )
3006
+ return result
3007
+
3008
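
A hedged sketch of the agent registration/heartbeat pair above, reusing the `api` instance from the first sketch; the host name, sweep id, entity, and project are placeholders.

```python
# Register a sweep agent, then poll once for commands to execute.
agent = api.register_agent(
    host="worker-1",
    sweep_id="sweep123",
    project_name="my-project",
    entity="my-team",
)
commands = api.agent_heartbeat(agent["id"], metrics={}, run_states={})
for command in commands:
    print(command)  # each command is a dict describing work for the agent
```
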
+ @staticmethod
3009
+ def _validate_config_and_fill_distribution(config: dict) -> dict:
3010
+ # verify that parameters are well specified.
3011
+ # TODO(dag): deprecate this in favor of jsonschema validation once
3012
+ # apiVersion 2 is released and local controller is integrated with
3013
+ # wandb/client.
3014
+
3015
+ # avoid modifying the original config dict in
3016
+ # case it is reused outside the calling func
3017
+ config = deepcopy(config)
3018
+
3019
+ # explicitly cast to dict in case config was passed as a sweepconfig
3020
+ # sweepconfig does not serialize cleanly to yaml and breaks graphql,
3021
+ # but it is a subclass of dict, so this conversion is clean
3022
+ config = dict(config)
3023
+
3024
+ if "parameters" not in config:
3025
+ # still shows an anaconda warning, but doesn't error
3026
+ return config
3027
+
3028
+ for parameter_name in config["parameters"]:
3029
+ parameter = config["parameters"][parameter_name]
3030
+ if "min" in parameter and "max" in parameter:
3031
+ if "distribution" not in parameter:
3032
+ if isinstance(parameter["min"], int) and isinstance(
3033
+ parameter["max"], int
3034
+ ):
3035
+ parameter["distribution"] = "int_uniform"
3036
+ elif isinstance(parameter["min"], float) and isinstance(
3037
+ parameter["max"], float
3038
+ ):
3039
+ parameter["distribution"] = "uniform"
3040
+ else:
3041
+ raise ValueError(
3042
+                         "Parameter %s is ambiguous: specify both bounds as floats (for a "
3043
+                         "uniform distribution) or both as ints (for an int_uniform distribution)."
3044
+ % parameter_name
3045
+ )
3046
+ return config
3047
+
3048
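
A small sketch of the distribution-filling behavior above. It assumes the class name/import from the first sketch (the method is a staticmethod); the parameter names and bounds are placeholders.

```python
# min/max bounds without an explicit distribution get one inferred from their types.
cfg = {
    "parameters": {
        "lr": {"min": 1e-4, "max": 1e-1},  # floats -> "uniform"
        "layers": {"min": 1, "max": 4},    # ints   -> "int_uniform"
    }
}
filled = Api._validate_config_and_fill_distribution(cfg)
print(filled["parameters"]["lr"]["distribution"])      # uniform
print(filled["parameters"]["layers"]["distribution"])  # int_uniform
```
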
+ @normalize_exceptions
3049
+ def upsert_sweep(
3050
+ self,
3051
+ config: dict,
3052
+ controller: Optional[str] = None,
3053
+ launch_scheduler: Optional[str] = None,
3054
+ scheduler: Optional[str] = None,
3055
+ obj_id: Optional[str] = None,
3056
+ project: Optional[str] = None,
3057
+ entity: Optional[str] = None,
3058
+ state: Optional[str] = None,
3059
+ ) -> Tuple[str, List[str]]:
3060
+ """Upsert a sweep object.
3061
+
3062
+ Arguments:
3063
+ config (dict): sweep config (will be converted to yaml)
3064
+ controller (str): controller to use
3065
+ launch_scheduler (str): launch scheduler to use
3066
+ scheduler (str): scheduler to use
3067
+ obj_id (str): object id
3068
+ project (str): project to use
3069
+ entity (str): entity to use
3070
+ state (str): state
3071
+ """
3072
+ project_query = """
3073
+ project {
3074
+ id
3075
+ name
3076
+ entity {
3077
+ id
3078
+ name
3079
+ }
3080
+ }
3081
+ """
3082
+ mutation_str = """
3083
+ mutation UpsertSweep(
3084
+ $id: ID,
3085
+ $config: String,
3086
+ $description: String,
3087
+ $entityName: String,
3088
+ $projectName: String,
3089
+ $controller: JSONString,
3090
+ $scheduler: JSONString,
3091
+ $state: String
3092
+ ) {
3093
+ upsertSweep(input: {
3094
+ id: $id,
3095
+ config: $config,
3096
+ description: $description,
3097
+ entityName: $entityName,
3098
+ projectName: $projectName,
3099
+ controller: $controller,
3100
+ scheduler: $scheduler,
3101
+ state: $state
3102
+ }) {
3103
+ sweep {
3104
+ name
3105
+ _PROJECT_QUERY_
3106
+ }
3107
+ configValidationWarnings
3108
+ }
3109
+ }
3110
+ """
3111
+ # TODO(jhr): we need protocol versioning to know schema is not supported
3112
+         # for now we will just try both the new and old queries
3113
+
3114
+ # launchScheduler was introduced in core v0.14.0
3115
+ mutation_4 = gql(
3116
+ mutation_str.replace(
3117
+ "$controller: JSONString,",
3118
+ "$controller: JSONString,$launchScheduler: JSONString,",
3119
+ )
3120
+ .replace(
3121
+ "controller: $controller,",
3122
+ "controller: $controller,launchScheduler: $launchScheduler,",
3123
+ )
3124
+ .replace("_PROJECT_QUERY_", project_query)
3125
+ )
3126
+
3127
+         # mutation 3 maps to a backend that supports CLI versions of at least 0.10.31
3128
+ mutation_3 = gql(mutation_str.replace("_PROJECT_QUERY_", project_query))
3129
+ mutation_2 = gql(
3130
+ mutation_str.replace("_PROJECT_QUERY_", project_query).replace(
3131
+ "configValidationWarnings", ""
3132
+ )
3133
+ )
3134
+ mutation_1 = gql(
3135
+ mutation_str.replace("_PROJECT_QUERY_", "").replace(
3136
+ "configValidationWarnings", ""
3137
+ )
3138
+ )
3139
+
3140
+ # TODO(dag): replace this with a query for protocol versioning
3141
+ mutations = [mutation_4, mutation_3, mutation_2, mutation_1]
3142
+
3143
+ config = self._validate_config_and_fill_distribution(config)
3144
+
3145
+ # Silly, but attr-dicts like EasyDicts don't serialize correctly to yaml.
3146
+ # This sanitizes them with a round trip pass through json to get a regular dict.
3147
+ config_str = yaml.dump(json.loads(json.dumps(config)))
3148
+
3149
+ err: Optional[Exception] = None
3150
+ for mutation in mutations:
3151
+ try:
3152
+ variables = {
3153
+ "id": obj_id,
3154
+ "config": config_str,
3155
+ "description": config.get("description"),
3156
+ "entityName": entity or self.settings("entity"),
3157
+ "projectName": project or self.settings("project"),
3158
+ "controller": controller,
3159
+ "launchScheduler": launch_scheduler,
3160
+ "scheduler": scheduler,
3161
+ }
3162
+ if state:
3163
+ variables["state"] = state
3164
+
3165
+ response = self.gql(
3166
+ mutation,
3167
+ variable_values=variables,
3168
+ check_retry_fn=util.no_retry_4xx,
3169
+ )
3170
+ except UsageError as e:
3171
+ raise e
3172
+ except Exception as e:
3173
+ # graphql schema exception is generic
3174
+ err = e
3175
+ continue
3176
+ err = None
3177
+ break
3178
+ if err:
3179
+ raise err
3180
+
3181
+ sweep: Dict[str, Dict[str, Dict]] = response["upsertSweep"]["sweep"]
3182
+ project_obj: Dict[str, Dict] = sweep.get("project", {})
3183
+ if project_obj:
3184
+ self.set_setting("project", project_obj["name"])
3185
+ entity_obj: dict = project_obj.get("entity", {})
3186
+ if entity_obj:
3187
+ self.set_setting("entity", entity_obj["name"])
3188
+
3189
+ warnings = response["upsertSweep"].get("configValidationWarnings", [])
3190
+ return response["upsertSweep"]["sweep"]["name"], warnings
3191
+
3192
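
A hedged sketch of creating a sweep with `upsert_sweep()`, reusing the `api` instance from the first sketch; the sweep config, entity, and project are placeholders.

```python
# Upsert a sweep from a plain config dict; returns the sweep name plus any
# config validation warnings reported by newer backends.
sweep_config = {
    "method": "random",
    "parameters": {"lr": {"min": 1e-4, "max": 1e-1}},
}
sweep_name, warnings = api.upsert_sweep(
    sweep_config, project="my-project", entity="my-team"
)
print(sweep_name, warnings)
```
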
+ @normalize_exceptions
3193
+ def create_anonymous_api_key(self) -> str:
3194
+ """Create a new API key belonging to a new anonymous user."""
3195
+ mutation = gql(
3196
+ """
3197
+ mutation CreateAnonymousApiKey {
3198
+ createAnonymousEntity(input: {}) {
3199
+ apiKey {
3200
+ name
3201
+ }
3202
+ }
3203
+ }
3204
+ """
3205
+ )
3206
+
3207
+ response = self.gql(mutation, variable_values={})
3208
+ key: str = str(response["createAnonymousEntity"]["apiKey"]["name"])
3209
+ return key
3210
+
3211
+ @staticmethod
3212
+ def file_current(fname: str, md5: B64MD5) -> bool:
3213
+ """Checksum a file and compare the md5 with the known md5."""
3214
+ return os.path.isfile(fname) and md5_file_b64(fname) == md5
3215
+
3216
+ @normalize_exceptions
3217
+ def pull(
3218
+ self, project: str, run: Optional[str] = None, entity: Optional[str] = None
3219
+ ) -> "List[requests.Response]":
3220
+ """Download files from W&B.
3221
+
3222
+ Arguments:
3223
+ project (str): The project to download
3224
+ run (str, optional): The run to upload to
3225
+ entity (str, optional): The entity to scope this project to. Defaults to wandb models
3226
+
3227
+ Returns:
3228
+ The `requests` library response object
3229
+ """
3230
+ project, run = self.parse_slug(project, run=run)
3231
+ urls = self.download_urls(project, run, entity)
3232
+ responses = []
3233
+ for filename in urls:
3234
+ _, response = self.download_write_file(urls[filename])
3235
+ if response:
3236
+ responses.append(response)
3237
+
3238
+ return responses
3239
+
3240
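
A hedged sketch of `pull()`, which chains `download_urls()` and `download_write_file()` for every file of a run; arguments are placeholders and `api` comes from the first sketch.

```python
# Download every file of a run into the local wandb directory.
responses = api.pull("my-project", run="abc123", entity="my-team")
print(f"downloaded {len(responses)} file(s)")
```
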
+ def get_project(self) -> str:
3241
+ project: str = self.default_settings.get("project") or self.settings("project")
3242
+ return project
3243
+
3244
+ @normalize_exceptions
3245
+ def push(
3246
+ self,
3247
+ files: Union[List[str], Dict[str, IO]],
3248
+ run: Optional[str] = None,
3249
+ entity: Optional[str] = None,
3250
+ project: Optional[str] = None,
3251
+ description: Optional[str] = None,
3252
+ force: bool = True,
3253
+ progress: Union[TextIO, bool] = False,
3254
+ ) -> "List[Optional[requests.Response]]":
3255
+         """Upload multiple files to W&B.
3256
+
3257
+ Arguments:
3258
+             files (list or dict): The filenames to upload; if a dict, the values are open file objects
3259
+ run (str, optional): The run to upload to
3260
+ entity (str, optional): The entity to scope this project to. Defaults to wandb models
3261
+ project (str, optional): The name of the project to upload to. Defaults to the one in settings.
3262
+ description (str, optional): The description of the changes
3263
+ force (bool, optional): Whether to prevent push if git has uncommitted changes
3264
+             progress (callable or stream): If callable, it is called with (chunk_bytes,
3265
+                 total_bytes) for each uploaded chunk; if a stream is given, a progress bar is rendered to it.
3266
+
3267
+ Returns:
3268
+ A list of `requests.Response` objects
3269
+ """
3270
+ if project is None:
3271
+ project = self.get_project()
3272
+ if project is None:
3273
+ raise CommError("No project configured.")
3274
+ if run is None:
3275
+ run = self.current_run_id
3276
+
3277
+ # TODO(adrian): we use a retriable version of self.upload_file() so
3278
+ # will never retry self.upload_urls() here. Instead, maybe we should
3279
+ # make push itself retriable.
3280
+ _, upload_headers, result = self.upload_urls(
3281
+ project,
3282
+ files,
3283
+ run,
3284
+ entity,
3285
+ )
3286
+ extra_headers = {}
3287
+ for upload_header in upload_headers:
3288
+ key, val = upload_header.split(":", 1)
3289
+ extra_headers[key] = val
3290
+ responses = []
3291
+ for file_name, file_info in result.items():
3292
+ file_url = file_info["uploadUrl"]
3293
+
3294
+ # If the upload URL is relative, fill it in with the base URL,
3295
+ # since it's a proxied file store like the on-prem VM.
3296
+ if file_url.startswith("/"):
3297
+ file_url = f"{self.api_url}{file_url}"
3298
+
3299
+ try:
3300
+ # To handle Windows paths
3301
+ # TODO: this doesn't handle absolute paths...
3302
+ normal_name = os.path.join(*file_name.split("/"))
3303
+ open_file = (
3304
+ files[file_name]
3305
+ if isinstance(files, dict)
3306
+ else open(normal_name, "rb")
3307
+ )
3308
+ except OSError:
3309
+ print(f"{file_name} does not exist")
3310
+ continue
3311
+ if progress is False:
3312
+ responses.append(
3313
+ self.upload_file_retry(
3314
+ file_info["uploadUrl"], open_file, extra_headers=extra_headers
3315
+ )
3316
+ )
3317
+ else:
3318
+ if callable(progress):
3319
+ responses.append( # type: ignore
3320
+ self.upload_file_retry(
3321
+ file_url, open_file, progress, extra_headers=extra_headers
3322
+ )
3323
+ )
3324
+ else:
3325
+ length = os.fstat(open_file.fileno()).st_size
3326
+ with click.progressbar(
3327
+ file=progress, # type: ignore
3328
+ length=length,
3329
+ label=f"Uploading file: {file_name}",
3330
+ fill_char=click.style("&", fg="green"),
3331
+ ) as bar:
3332
+ responses.append(
3333
+ self.upload_file_retry(
3334
+ file_url,
3335
+ open_file,
3336
+ lambda bites, _: bar.update(bites),
3337
+ extra_headers=extra_headers,
3338
+ )
3339
+ )
3340
+ open_file.close()
3341
+ return responses
3342
+
3343
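
A hedged sketch of `push()`, which requests upload urls and uploads local files to a run; file, run, entity, and project names are placeholders and `api` comes from the first sketch.

```python
# Upload local files to an existing run; each element is a requests.Response
# (or None if a file was skipped).
responses = api.push(
    files=["weights.h5", "model.json"],
    run="abc123",
    entity="my-team",
    project="my-project",
)
print([r.status_code for r in responses if r is not None])
```
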
+ def link_artifact(
3344
+ self,
3345
+ client_id: str,
3346
+ server_id: str,
3347
+ portfolio_name: str,
3348
+ entity: str,
3349
+ project: str,
3350
+ aliases: Sequence[str],
3351
+ ) -> Dict[str, Any]:
3352
+ template = """
3353
+ mutation LinkArtifact(
3354
+ $artifactPortfolioName: String!,
3355
+ $entityName: String!,
3356
+ $projectName: String!,
3357
+ $aliases: [ArtifactAliasInput!],
3358
+ ID_TYPE
3359
+ ) {
3360
+ linkArtifact(input: {
3361
+ artifactPortfolioName: $artifactPortfolioName,
3362
+ entityName: $entityName,
3363
+ projectName: $projectName,
3364
+ aliases: $aliases,
3365
+ ID_VALUE
3366
+ }) {
3367
+ versionIndex
3368
+ }
3369
+ }
3370
+ """
3371
+
3372
+ def replace(a: str, b: str) -> None:
3373
+ nonlocal template
3374
+ template = template.replace(a, b)
3375
+
3376
+ if server_id:
3377
+ replace("ID_TYPE", "$artifactID: ID")
3378
+ replace("ID_VALUE", "artifactID: $artifactID")
3379
+ elif client_id:
3380
+ replace("ID_TYPE", "$clientID: ID")
3381
+ replace("ID_VALUE", "clientID: $clientID")
3382
+
3383
+ variable_values = {
3384
+ "clientID": client_id,
3385
+ "artifactID": server_id,
3386
+ "artifactPortfolioName": portfolio_name,
3387
+ "entityName": entity,
3388
+ "projectName": project,
3389
+ "aliases": [
3390
+ {"alias": alias, "artifactCollectionName": portfolio_name}
3391
+ for alias in aliases
3392
+ ],
3393
+ }
3394
+
3395
+ mutation = gql(template)
3396
+ response = self.gql(mutation, variable_values=variable_values)
3397
+ link_artifact: Dict[str, Any] = response["linkArtifact"]
3398
+ return link_artifact
3399
+
3400
+ def use_artifact(
3401
+ self,
3402
+ artifact_id: str,
3403
+ entity_name: Optional[str] = None,
3404
+ project_name: Optional[str] = None,
3405
+ run_name: Optional[str] = None,
3406
+ use_as: Optional[str] = None,
3407
+ ) -> Optional[Dict[str, Any]]:
3408
+ query_template = """
3409
+ mutation UseArtifact(
3410
+ $entityName: String!,
3411
+ $projectName: String!,
3412
+ $runName: String!,
3413
+ $artifactID: ID!,
3414
+ _USED_AS_TYPE_
3415
+ ) {
3416
+ useArtifact(input: {
3417
+ entityName: $entityName,
3418
+ projectName: $projectName,
3419
+ runName: $runName,
3420
+ artifactID: $artifactID,
3421
+ _USED_AS_VALUE_
3422
+ }) {
3423
+ artifact {
3424
+ id
3425
+ digest
3426
+ description
3427
+ state
3428
+ createdAt
3429
+ metadata
3430
+ }
3431
+ }
3432
+ }
3433
+ """
3434
+
3435
+ artifact_types = self.server_use_artifact_input_introspection()
3436
+ if "usedAs" in artifact_types:
3437
+ query_template = query_template.replace(
3438
+ "_USED_AS_TYPE_", "$usedAs: String"
3439
+ ).replace("_USED_AS_VALUE_", "usedAs: $usedAs")
3440
+ else:
3441
+ query_template = query_template.replace("_USED_AS_TYPE_", "").replace(
3442
+ "_USED_AS_VALUE_", ""
3443
+ )
3444
+
3445
+ query = gql(query_template)
3446
+
3447
+ entity_name = entity_name or self.settings("entity")
3448
+ project_name = project_name or self.settings("project")
3449
+ run_name = run_name or self.current_run_id
3450
+
3451
+ response = self.gql(
3452
+ query,
3453
+ variable_values={
3454
+ "entityName": entity_name,
3455
+ "projectName": project_name,
3456
+ "runName": run_name,
3457
+ "artifactID": artifact_id,
3458
+ "usedAs": use_as,
3459
+ },
3460
+ )
3461
+
3462
+ if response["useArtifact"]["artifact"]:
3463
+ artifact: Dict[str, Any] = response["useArtifact"]["artifact"]
3464
+ return artifact
3465
+ return None
3466
+
3467
+ def create_artifact_type(
3468
+ self,
3469
+ artifact_type_name: str,
3470
+ entity_name: Optional[str] = None,
3471
+ project_name: Optional[str] = None,
3472
+ description: Optional[str] = None,
3473
+ ) -> Optional[str]:
3474
+ mutation = gql(
3475
+ """
3476
+ mutation CreateArtifactType(
3477
+ $entityName: String!,
3478
+ $projectName: String!,
3479
+ $artifactTypeName: String!,
3480
+ $description: String
3481
+ ) {
3482
+ createArtifactType(input: {
3483
+ entityName: $entityName,
3484
+ projectName: $projectName,
3485
+ name: $artifactTypeName,
3486
+ description: $description
3487
+ }) {
3488
+ artifactType {
3489
+ id
3490
+ }
3491
+ }
3492
+ }
3493
+ """
3494
+ )
3495
+ entity_name = entity_name or self.settings("entity")
3496
+ project_name = project_name or self.settings("project")
3497
+ response = self.gql(
3498
+ mutation,
3499
+ variable_values={
3500
+ "entityName": entity_name,
3501
+ "projectName": project_name,
3502
+ "artifactTypeName": artifact_type_name,
3503
+ "description": description,
3504
+ },
3505
+ )
3506
+ _id: Optional[str] = response["createArtifactType"]["artifactType"]["id"]
3507
+ return _id
3508
+
3509
+ def server_artifact_introspection(self) -> List:
3510
+ query_string = """
3511
+ query ProbeServerArtifact {
3512
+ ArtifactInfoType: __type(name:"Artifact") {
3513
+ fields {
3514
+ name
3515
+ }
3516
+ }
3517
+ }
3518
+ """
3519
+
3520
+ if self.server_artifact_fields_info is None:
3521
+ query = gql(query_string)
3522
+ res = self.gql(query)
3523
+ input_fields = res.get("ArtifactInfoType", {}).get("fields", [{}])
3524
+ self.server_artifact_fields_info = [
3525
+ field["name"] for field in input_fields if "name" in field
3526
+ ]
3527
+
3528
+ return self.server_artifact_fields_info
3529
+
3530
+ def server_create_artifact_introspection(self) -> List:
3531
+ query_string = """
3532
+ query ProbeServerCreateArtifactInput {
3533
+ CreateArtifactInputInfoType: __type(name:"CreateArtifactInput") {
3534
+ inputFields{
3535
+ name
3536
+ }
3537
+ }
3538
+ }
3539
+ """
3540
+
3541
+ if self.server_create_artifact_input_info is None:
3542
+ query = gql(query_string)
3543
+ res = self.gql(query)
3544
+ input_fields = res.get("CreateArtifactInputInfoType", {}).get(
3545
+ "inputFields", [{}]
3546
+ )
3547
+ self.server_create_artifact_input_info = [
3548
+ field["name"] for field in input_fields if "name" in field
3549
+ ]
3550
+
3551
+ return self.server_create_artifact_input_info
3552
+
3553
+ def _get_create_artifact_mutation(
3554
+ self,
3555
+ fields: List,
3556
+ history_step: Optional[int],
3557
+ distributed_id: Optional[str],
3558
+ ) -> str:
3559
+ types = ""
3560
+ values = ""
3561
+
3562
+ if "historyStep" in fields and history_step not in [0, None]:
3563
+ types += "$historyStep: Int64!,"
3564
+ values += "historyStep: $historyStep,"
3565
+
3566
+ if distributed_id:
3567
+ types += "$distributedID: String,"
3568
+ values += "distributedID: $distributedID,"
3569
+
3570
+ if "clientID" in fields:
3571
+ types += "$clientID: ID,"
3572
+ values += "clientID: $clientID,"
3573
+
3574
+ if "sequenceClientID" in fields:
3575
+ types += "$sequenceClientID: ID,"
3576
+ values += "sequenceClientID: $sequenceClientID,"
3577
+
3578
+ if "enableDigestDeduplication" in fields:
3579
+ values += "enableDigestDeduplication: true,"
3580
+
3581
+ if "ttlDurationSeconds" in fields:
3582
+ types += "$ttlDurationSeconds: Int64,"
3583
+ values += "ttlDurationSeconds: $ttlDurationSeconds,"
3584
+
3585
+ query_template = """
3586
+ mutation CreateArtifact(
3587
+ $artifactTypeName: String!,
3588
+ $artifactCollectionNames: [String!],
3589
+ $entityName: String!,
3590
+ $projectName: String!,
3591
+ $runName: String,
3592
+ $description: String,
3593
+ $digest: String!,
3594
+ $aliases: [ArtifactAliasInput!],
3595
+ $metadata: JSONString,
3596
+ _CREATE_ARTIFACT_ADDITIONAL_TYPE_
3597
+ ) {
3598
+ createArtifact(input: {
3599
+ artifactTypeName: $artifactTypeName,
3600
+ artifactCollectionNames: $artifactCollectionNames,
3601
+ entityName: $entityName,
3602
+ projectName: $projectName,
3603
+ runName: $runName,
3604
+ description: $description,
3605
+ digest: $digest,
3606
+ digestAlgorithm: MANIFEST_MD5,
3607
+ aliases: $aliases,
3608
+ metadata: $metadata,
3609
+ _CREATE_ARTIFACT_ADDITIONAL_VALUE_
3610
+ }) {
3611
+ artifact {
3612
+ id
3613
+ state
3614
+ artifactSequence {
3615
+ id
3616
+ latestArtifact {
3617
+ id
3618
+ versionIndex
3619
+ }
3620
+ }
3621
+ }
3622
+ }
3623
+ }
3624
+ """
3625
+
3626
+ return query_template.replace(
3627
+ "_CREATE_ARTIFACT_ADDITIONAL_TYPE_", types
3628
+ ).replace("_CREATE_ARTIFACT_ADDITIONAL_VALUE_", values)
3629
+
3630
+ def create_artifact(
3631
+ self,
3632
+ artifact_type_name: str,
3633
+ artifact_collection_name: str,
3634
+ digest: str,
3635
+ client_id: Optional[str] = None,
3636
+ sequence_client_id: Optional[str] = None,
3637
+ entity_name: Optional[str] = None,
3638
+ project_name: Optional[str] = None,
3639
+ run_name: Optional[str] = None,
3640
+ description: Optional[str] = None,
3641
+ metadata: Optional[Dict] = None,
3642
+ ttl_duration_seconds: Optional[int] = None,
3643
+ aliases: Optional[List[Dict[str, str]]] = None,
3644
+ distributed_id: Optional[str] = None,
3645
+ is_user_created: Optional[bool] = False,
3646
+ history_step: Optional[int] = None,
3647
+ ) -> Tuple[Dict, Dict]:
3648
+ fields = self.server_create_artifact_introspection()
3649
+ artifact_fields = self.server_artifact_introspection()
3650
+ if "ttlIsInherited" not in artifact_fields and ttl_duration_seconds:
3651
+ wandb.termwarn(
3652
+ "Server not compatible with setting Artifact TTLs, please upgrade the server to use Artifact TTL"
3653
+ )
3654
+ # ttlDurationSeconds is only usable if ttlIsInherited is also present
3655
+ ttl_duration_seconds = None
3656
+ query_template = self._get_create_artifact_mutation(
3657
+ fields, history_step, distributed_id
3658
+ )
3659
+
3660
+ entity_name = entity_name or self.settings("entity")
3661
+ project_name = project_name or self.settings("project")
3662
+ if not is_user_created:
3663
+ run_name = run_name or self.current_run_id
3664
+ if aliases is None:
3665
+ aliases = []
3666
+
3667
+ mutation = gql(query_template)
3668
+ response = self.gql(
3669
+ mutation,
3670
+ variable_values={
3671
+ "entityName": entity_name,
3672
+ "projectName": project_name,
3673
+ "runName": run_name,
3674
+ "artifactTypeName": artifact_type_name,
3675
+ "artifactCollectionNames": [artifact_collection_name],
3676
+ "clientID": client_id,
3677
+ "sequenceClientID": sequence_client_id,
3678
+ "digest": digest,
3679
+ "description": description,
3680
+ "aliases": [alias for alias in aliases],
3681
+ "metadata": json.dumps(util.make_safe_for_json(metadata))
3682
+ if metadata
3683
+ else None,
3684
+ "ttlDurationSeconds": ttl_duration_seconds,
3685
+ "distributedID": distributed_id,
3686
+ "historyStep": history_step,
3687
+ },
3688
+ )
3689
+ av = response["createArtifact"]["artifact"]
3690
+ latest = response["createArtifact"]["artifact"]["artifactSequence"].get(
3691
+ "latestArtifact"
3692
+ )
3693
+ return av, latest
3694
+
3695
+ def commit_artifact(self, artifact_id: str) -> "_Response":
3696
+ mutation = gql(
3697
+ """
3698
+ mutation CommitArtifact(
3699
+ $artifactID: ID!,
3700
+ ) {
3701
+ commitArtifact(input: {
3702
+ artifactID: $artifactID,
3703
+ }) {
3704
+ artifact {
3705
+ id
3706
+ digest
3707
+ }
3708
+ }
3709
+ }
3710
+ """
3711
+ )
3712
+
3713
+ response: _Response = self.gql(
3714
+ mutation,
3715
+ variable_values={"artifactID": artifact_id},
3716
+ timeout=60,
3717
+ )
3718
+ return response
3719
+
3720
+ def complete_multipart_upload_artifact(
3721
+ self,
3722
+ artifact_id: str,
3723
+ storage_path: str,
3724
+ completed_parts: List[Dict[str, Any]],
3725
+ upload_id: Optional[str],
3726
+ complete_multipart_action: str = "Complete",
3727
+ ) -> Optional[str]:
3728
+ mutation = gql(
3729
+ """
3730
+ mutation CompleteMultipartUploadArtifact(
3731
+ $completeMultipartAction: CompleteMultipartAction!,
3732
+ $completedParts: [UploadPartsInput!]!,
3733
+ $artifactID: ID!
3734
+ $storagePath: String!
3735
+ $uploadID: String!
3736
+ ) {
3737
+ completeMultipartUploadArtifact(
3738
+ input: {
3739
+ completeMultipartAction: $completeMultipartAction,
3740
+ completedParts: $completedParts,
3741
+ artifactID: $artifactID,
3742
+ storagePath: $storagePath
3743
+ uploadID: $uploadID
3744
+ }
3745
+ ) {
3746
+ digest
3747
+ }
3748
+ }
3749
+ """
3750
+ )
3751
+ response = self.gql(
3752
+ mutation,
3753
+ variable_values={
3754
+ "completeMultipartAction": complete_multipart_action,
3755
+ "artifactID": artifact_id,
3756
+ "storagePath": storage_path,
3757
+ "completedParts": completed_parts,
3758
+ "uploadID": upload_id,
3759
+ },
3760
+ )
3761
+ digest: Optional[str] = response["completeMultipartUploadArtifact"]["digest"]
3762
+ return digest
3763
+
3764
+ def create_artifact_manifest(
3765
+ self,
3766
+ name: str,
3767
+ digest: str,
3768
+ artifact_id: Optional[str],
3769
+ base_artifact_id: Optional[str] = None,
3770
+ entity: Optional[str] = None,
3771
+ project: Optional[str] = None,
3772
+ run: Optional[str] = None,
3773
+ include_upload: bool = True,
3774
+ type: str = "FULL",
3775
+ ) -> Tuple[str, Dict[str, Any]]:
3776
+ mutation = gql(
3777
+ """
3778
+ mutation CreateArtifactManifest(
3779
+ $name: String!,
3780
+ $digest: String!,
3781
+ $artifactID: ID!,
3782
+ $baseArtifactID: ID,
3783
+ $entityName: String!,
3784
+ $projectName: String!,
3785
+ $runName: String!,
3786
+ $includeUpload: Boolean!,
3787
+ {}
3788
+ ) {{
3789
+ createArtifactManifest(input: {{
3790
+ name: $name,
3791
+ digest: $digest,
3792
+ artifactID: $artifactID,
3793
+ baseArtifactID: $baseArtifactID,
3794
+ entityName: $entityName,
3795
+ projectName: $projectName,
3796
+ runName: $runName,
3797
+ {}
3798
+ }}) {{
3799
+ artifactManifest {{
3800
+ id
3801
+ file {{
3802
+ id
3803
+ name
3804
+ displayName
3805
+ uploadUrl @include(if: $includeUpload)
3806
+ uploadHeaders @include(if: $includeUpload)
3807
+ }}
3808
+ }}
3809
+ }}
3810
+ }}
3811
+ """.format(
3812
+ "$type: ArtifactManifestType = FULL" if type != "FULL" else "",
3813
+ "type: $type" if type != "FULL" else "",
3814
+ )
3815
+ )
3816
+
3817
+ entity_name = entity or self.settings("entity")
3818
+ project_name = project or self.settings("project")
3819
+ run_name = run or self.current_run_id
3820
+
3821
+ response = self.gql(
3822
+ mutation,
3823
+ variable_values={
3824
+ "name": name,
3825
+ "digest": digest,
3826
+ "artifactID": artifact_id,
3827
+ "baseArtifactID": base_artifact_id,
3828
+ "entityName": entity_name,
3829
+ "projectName": project_name,
3830
+ "runName": run_name,
3831
+ "includeUpload": include_upload,
3832
+ "type": type,
3833
+ },
3834
+ )
3835
+ return (
3836
+ response["createArtifactManifest"]["artifactManifest"]["id"],
3837
+ response["createArtifactManifest"]["artifactManifest"]["file"],
3838
+ )
3839
+
3840
+ def update_artifact_manifest(
3841
+ self,
3842
+ artifact_manifest_id: str,
3843
+ base_artifact_id: Optional[str] = None,
3844
+ digest: Optional[str] = None,
3845
+ include_upload: Optional[bool] = True,
3846
+ ) -> Tuple[str, Dict[str, Any]]:
3847
+ mutation = gql(
3848
+ """
3849
+ mutation UpdateArtifactManifest(
3850
+ $artifactManifestID: ID!,
3851
+ $digest: String,
3852
+ $baseArtifactID: ID,
3853
+ $includeUpload: Boolean!,
3854
+ ) {
3855
+ updateArtifactManifest(input: {
3856
+ artifactManifestID: $artifactManifestID,
3857
+ digest: $digest,
3858
+ baseArtifactID: $baseArtifactID,
3859
+ }) {
3860
+ artifactManifest {
3861
+ id
3862
+ file {
3863
+ id
3864
+ name
3865
+ displayName
3866
+ uploadUrl @include(if: $includeUpload)
3867
+ uploadHeaders @include(if: $includeUpload)
3868
+ }
3869
+ }
3870
+ }
3871
+ }
3872
+ """
3873
+ )
3874
+
3875
+ response = self.gql(
3876
+ mutation,
3877
+ variable_values={
3878
+ "artifactManifestID": artifact_manifest_id,
3879
+ "digest": digest,
3880
+ "baseArtifactID": base_artifact_id,
3881
+ "includeUpload": include_upload,
3882
+ },
3883
+ )
3884
+
3885
+ return (
3886
+ response["updateArtifactManifest"]["artifactManifest"]["id"],
3887
+ response["updateArtifactManifest"]["artifactManifest"]["file"],
3888
+ )
3889
+
3890
+ def _resolve_client_id(
3891
+ self,
3892
+ client_id: str,
3893
+ ) -> Optional[str]:
3894
+ if client_id in self._client_id_mapping:
3895
+ return self._client_id_mapping[client_id]
3896
+
3897
+ query = gql(
3898
+ """
3899
+ query ClientIDMapping($clientID: ID!) {
3900
+ clientIDMapping(clientID: $clientID) {
3901
+ serverID
3902
+ }
3903
+ }
3904
+ """
3905
+ )
3906
+ response = self.gql(
3907
+ query,
3908
+ variable_values={
3909
+ "clientID": client_id,
3910
+ },
3911
+ )
3912
+ server_id = None
3913
+ if response is not None:
3914
+ client_id_mapping = response.get("clientIDMapping")
3915
+ if client_id_mapping is not None:
3916
+ server_id = client_id_mapping.get("serverID")
3917
+ if server_id is not None:
3918
+ self._client_id_mapping[client_id] = server_id
3919
+ return server_id
3920
+
3921
+ def server_create_artifact_file_spec_input_introspection(self) -> List:
3922
+ query_string = """
3923
+ query ProbeServerCreateArtifactFileSpecInput {
3924
+ CreateArtifactFileSpecInputInfoType: __type(name:"CreateArtifactFileSpecInput") {
3925
+ inputFields{
3926
+ name
3927
+ }
3928
+ }
3929
+ }
3930
+ """
3931
+
3932
+ query = gql(query_string)
3933
+ res = self.gql(query)
3934
+ create_artifact_file_spec_input_info = [
3935
+ field.get("name", "")
3936
+ for field in res.get("CreateArtifactFileSpecInputInfoType", {}).get(
3937
+ "inputFields", [{}]
3938
+ )
3939
+ ]
3940
+ return create_artifact_file_spec_input_info
3941
+
3942
+ @normalize_exceptions
3943
+ def create_artifact_files(
3944
+ self, artifact_files: Iterable["CreateArtifactFileSpecInput"]
3945
+ ) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
3946
+ query_template = """
3947
+ mutation CreateArtifactFiles(
3948
+ $storageLayout: ArtifactStorageLayout!
3949
+ $artifactFiles: [CreateArtifactFileSpecInput!]!
3950
+ ) {
3951
+ createArtifactFiles(input: {
3952
+ artifactFiles: $artifactFiles,
3953
+ storageLayout: $storageLayout,
3954
+ }) {
3955
+ files {
3956
+ edges {
3957
+ node {
3958
+ id
3959
+ name
3960
+ displayName
3961
+ uploadUrl
3962
+ uploadHeaders
3963
+ _MULTIPART_UPLOAD_FIELDS_
3964
+ artifact {
3965
+ id
3966
+ }
3967
+ }
3968
+ }
3969
+ }
3970
+ }
3971
+ }
3972
+ """
3973
+ multipart_upload_url_query = """
3974
+ storagePath
3975
+ uploadMultipartUrls {
3976
+ uploadID
3977
+ uploadUrlParts {
3978
+ partNumber
3979
+ uploadUrl
3980
+ }
3981
+ }
3982
+ """
3983
+
3984
+ # TODO: we should use constants here from interface/artifacts.py
3985
+ # but probably don't want the dependency. We're going to remove
3986
+ # this setting in a future release, so I'm just hard-coding the strings.
3987
+ storage_layout = "V2"
3988
+ if env.get_use_v1_artifacts():
3989
+ storage_layout = "V1"
3990
+
3991
+ create_artifact_file_spec_input_fields = (
3992
+ self.server_create_artifact_file_spec_input_introspection()
3993
+ )
3994
+ if "uploadPartsInput" in create_artifact_file_spec_input_fields:
3995
+ query_template = query_template.replace(
3996
+ "_MULTIPART_UPLOAD_FIELDS_", multipart_upload_url_query
3997
+ )
3998
+ else:
3999
+ query_template = query_template.replace("_MULTIPART_UPLOAD_FIELDS_", "")
4000
+
4001
+ mutation = gql(query_template)
4002
+ response = self.gql(
4003
+ mutation,
4004
+ variable_values={
4005
+ "storageLayout": storage_layout,
4006
+ "artifactFiles": [af for af in artifact_files],
4007
+ },
4008
+ )
4009
+
4010
+ result = {}
4011
+ for edge in response["createArtifactFiles"]["files"]["edges"]:
4012
+ node = edge["node"]
4013
+ result[node["displayName"]] = node
4014
+ return result
4015
+
4016
+ @normalize_exceptions
4017
+ def notify_scriptable_run_alert(
4018
+ self,
4019
+ title: str,
4020
+ text: str,
4021
+ level: Optional[str] = None,
4022
+ wait_duration: Optional["Number"] = None,
4023
+ ) -> bool:
4024
+ mutation = gql(
4025
+ """
4026
+ mutation NotifyScriptableRunAlert(
4027
+ $entityName: String!,
4028
+ $projectName: String!,
4029
+ $runName: String!,
4030
+ $title: String!,
4031
+ $text: String!,
4032
+ $severity: AlertSeverity = INFO,
4033
+ $waitDuration: Duration
4034
+ ) {
4035
+ notifyScriptableRunAlert(input: {
4036
+ entityName: $entityName,
4037
+ projectName: $projectName,
4038
+ runName: $runName,
4039
+ title: $title,
4040
+ text: $text,
4041
+ severity: $severity,
4042
+ waitDuration: $waitDuration
4043
+ }) {
4044
+ success
4045
+ }
4046
+ }
4047
+ """
4048
+ )
4049
+
4050
+ response = self.gql(
4051
+ mutation,
4052
+ variable_values={
4053
+ "entityName": self.settings("entity"),
4054
+ "projectName": self.settings("project"),
4055
+ "runName": self.current_run_id,
4056
+ "title": title,
4057
+ "text": text,
4058
+ "severity": level,
4059
+ "waitDuration": wait_duration,
4060
+ },
4061
+ )
4062
+ success: bool = response["notifyScriptableRunAlert"]["success"]
4063
+ return success
4064
+
4065
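
A hedged sketch of `notify_scriptable_run_alert()`. It relies on the entity, project, and current run id being resolvable from the Api settings, which is an assumption; the title, text, and severity value are placeholders.

```python
# Send an alert attached to the current run; severity follows the server's
# AlertSeverity enum (INFO is the documented default here).
ok = api.notify_scriptable_run_alert(
    title="Loss diverged",
    text="Loss exceeded 10.0 at step 5000",
    level="INFO",
)
print("alert sent" if ok else "alert failed")
```
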
+ def get_sweep_state(
4066
+ self, sweep: str, entity: Optional[str] = None, project: Optional[str] = None
4067
+ ) -> "SweepState":
4068
+ state: SweepState = self.sweep(
4069
+ sweep=sweep, entity=entity, project=project, specs="{}"
4070
+ )["state"]
4071
+ return state
4072
+
4073
+ def set_sweep_state(
4074
+ self,
4075
+ sweep: str,
4076
+ state: "SweepState",
4077
+ entity: Optional[str] = None,
4078
+ project: Optional[str] = None,
4079
+ ) -> None:
4080
+ assert state in ("RUNNING", "PAUSED", "CANCELED", "FINISHED")
4081
+ s = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
4082
+ curr_state = s["state"].upper()
4083
+ if state == "PAUSED" and curr_state not in ("PAUSED", "RUNNING"):
4084
+ raise Exception("Cannot pause %s sweep." % curr_state.lower())
4085
+ elif state != "RUNNING" and curr_state not in ("RUNNING", "PAUSED", "PENDING"):
4086
+ raise Exception("Sweep already %s." % curr_state.lower())
4087
+ sweep_id = s["id"]
4088
+ mutation = gql(
4089
+ """
4090
+ mutation UpsertSweep(
4091
+ $id: ID,
4092
+ $state: String,
4093
+ $entityName: String,
4094
+ $projectName: String
4095
+ ) {
4096
+ upsertSweep(input: {
4097
+ id: $id,
4098
+ state: $state,
4099
+ entityName: $entityName,
4100
+ projectName: $projectName
4101
+ }){
4102
+ sweep {
4103
+ name
4104
+ }
4105
+ }
4106
+ }
4107
+ """
4108
+ )
4109
+ self.gql(
4110
+ mutation,
4111
+ variable_values={
4112
+ "id": sweep_id,
4113
+ "state": state,
4114
+ "entityName": entity or self.settings("entity"),
4115
+ "projectName": project or self.settings("project"),
4116
+ },
4117
+ )
4118
+
4119
+ def stop_sweep(
4120
+ self,
4121
+ sweep: str,
4122
+ entity: Optional[str] = None,
4123
+ project: Optional[str] = None,
4124
+ ) -> None:
4125
+ """Finish the sweep to stop running new runs and let currently running runs finish."""
4126
+ self.set_sweep_state(
4127
+ sweep=sweep, state="FINISHED", entity=entity, project=project
4128
+ )
4129
+
4130
+ def cancel_sweep(
4131
+ self,
4132
+ sweep: str,
4133
+ entity: Optional[str] = None,
4134
+ project: Optional[str] = None,
4135
+ ) -> None:
4136
+ """Cancel the sweep to kill all running runs and stop running new runs."""
4137
+ self.set_sweep_state(
4138
+ sweep=sweep, state="CANCELED", entity=entity, project=project
4139
+ )
4140
+
4141
+ def pause_sweep(
4142
+ self,
4143
+ sweep: str,
4144
+ entity: Optional[str] = None,
4145
+ project: Optional[str] = None,
4146
+ ) -> None:
4147
+ """Pause the sweep to temporarily stop running new runs."""
4148
+ self.set_sweep_state(
4149
+ sweep=sweep, state="PAUSED", entity=entity, project=project
4150
+ )
4151
+
4152
+ def resume_sweep(
4153
+ self,
4154
+ sweep: str,
4155
+ entity: Optional[str] = None,
4156
+ project: Optional[str] = None,
4157
+ ) -> None:
4158
+ """Resume the sweep to continue running new runs."""
4159
+ self.set_sweep_state(
4160
+ sweep=sweep, state="RUNNING", entity=entity, project=project
4161
+ )
4162
+
4163
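
A hedged sketch of the sweep-state helpers above, reusing the `api` instance from the first sketch; the sweep name, entity, and project are placeholders.

```python
# Pause a sweep, then resume it later; stop_sweep() lets running runs finish,
# while cancel_sweep() also kills them.
api.pause_sweep("sweep123", entity="my-team", project="my-project")
...
api.resume_sweep("sweep123", entity="my-team", project="my-project")
```
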
+ def _status_request(self, url: str, length: int) -> requests.Response:
4164
+         """Ask Google Cloud Storage how much of the file has been uploaded."""
4165
+ check_httpclient_logger_handler()
4166
+ return requests.put(
4167
+ url=url,
4168
+ headers={"Content-Length": "0", "Content-Range": "bytes */%i" % length},
4169
+ )
4170
+
4171
+ def _flatten_edges(self, response: "_Response") -> List[Dict]:
4172
+ """Return an array from the nested graphql relay structure."""
4173
+ return [node["node"] for node in response["edges"]]
4174
+
4175
+ @normalize_exceptions
4176
+ def stop_run(
4177
+ self,
4178
+ run_id: str,
4179
+ ) -> bool:
4180
+ mutation = gql(
4181
+ """
4182
+ mutation stopRun($id: ID!) {
4183
+ stopRun(input: {
4184
+ id: $id
4185
+ }) {
4186
+ clientMutationId
4187
+ success
4188
+ }
4189
+ }
4190
+ """
4191
+ )
4192
+
4193
+ response = self.gql(
4194
+ mutation,
4195
+ variable_values={
4196
+ "id": run_id,
4197
+ },
4198
+ )
4199
+
4200
+ success: bool = response["stopRun"].get("success")
4201
+
4202
+ return success
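
A hedged sketch of `stop_run()`; the argument appears to be the run's server-side GraphQL ID rather than its short name, and the id below is a made-up placeholder. `api` is the instance from the first sketch.

```python
# Request that the backend stop a run, identified by its GraphQL id.
stopped = api.stop_run(run_id="placeholder-graphql-run-id")
print("stop requested" if stopped else "stop failed")
```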